11 #include "ade/ade.hpp" 25 ade::iTensor*,std::unordered_set<size_t>>;
28 using PruneFuncT = std::function<ade::TensptrT(ade::iFunctor*,\
29 std::unordered_set<size_t>,ade::ArgsT)>;
40 void visit (ade::iLeaf* leaf)
override 49 void visit (ade::iFunctor* func)
override 53 auto& children = func->get_children();
54 size_t n = children.size();
55 std::unordered_set<size_t> path;
56 for (
size_t i = 0; i < n; ++i)
58 ade::TensptrT tens = children[i].get_tensor();
66 if (
false == path.empty())
80 std::unordered_set<ade::iTensor*>
founds_;
97 ade::TensptrT
prune (ade::TensptrT root)
102 auto& pathmap =
finder_.parents_;
110 std::list<ade::iFunctor*> parents;
111 std::transform(pathmap.begin(), pathmap.end(),
112 std::back_inserter(parents),
113 [](std::pair<ade::iTensor*,std::unordered_set<size_t>> parent)
115 return static_cast<ade::iFunctor*
>(parent.first);
118 [&](ade::iTensor* a, ade::iTensor* b)
120 return stat.graphsize_[a] < stat.graphsize_[b];
125 std::unordered_set<ade::iTensor*> targets =
finder_.founds_;
127 std::unordered_map<ade::iTensor*,ade::TensptrT> mapping;
128 std::unordered_map<ade::iTensor*,bool> targetmap;
129 for (ade::iFunctor* func : parents)
131 ade::ArgsT children = func->get_children();
133 std::unordered_set<size_t> indices = pathmap[func];
134 for (
auto it = indices.begin(), et = indices.end(); it != et;)
136 ade::MappedTensor& child = children[*it];
137 ade::iTensor* tens = child.get_tensor().get();
139 auto zit = targets.find(tens);
140 if (targets.end() == zit)
142 it = indices.erase(it);
149 auto fwd =
pruner_(func, indices, children);
150 mapping.emplace(func, fwd);
152 if (ade::iLeaf* fwdleaf = dynamic_cast<ade::iLeaf*>(fwd.get()))
156 targets.emplace(func);
160 auto it = mapping.find(root.get());
161 if (mapping.end() == it)
164 "GraphStat failed to identify children of root subgraph");
179 #endif // OPT_SHEAR_HPP void visit(ade::iLeaf *leaf) override
Implementation of iTraveler.
Definition: shear.hpp:40
PruneFuncT pruner_
Prune functor defining how to prune a given graph.
Definition: shear.hpp:174
std::function< T(ade::iLeaf *)> GetLeafValT
Functor for getting leaf values.
Definition: shear.hpp:21
GetLeafValT< T > get_leaf_
Leaf value getter.
Definition: shear.hpp:77
std::unordered_map< ade::iTensor *, std::unordered_set< size_t > > ParentMapT
Type for mapping function nodes in path to boolean vector.
Definition: shear.hpp:25
void visit(ade::iFunctor *func) override
Implementation of iTraveler.
Definition: shear.hpp:49
TargetPruner(T target, GetLeafValT< T > find_target, PruneFuncT pruner)
Definition: shear.hpp:93
LeafFinder(T target, GetLeafValT< T > get_leaf)
Definition: shear.hpp:36
LeafFinder< T > finder_
Target finding traveler.
Definition: shear.hpp:171
T target_
Target of label all paths are travelling to.
Definition: shear.hpp:74
std::function< ade::TensptrT(ade::iFunctor *, std::unordered_set< size_t >, ade::ArgsT)> PruneFuncT
Pruning functor type.
Definition: shear.hpp:29
std::unordered_set< ade::iTensor * > founds_
Set of leaf nodes found.
Definition: shear.hpp:80
ParentMapT parents_
Map of parent nodes in path.
Definition: shear.hpp:83
ade::TensptrT prune(ade::TensptrT root)
Prune graph of root Tensptr.
Definition: shear.hpp:97