Cortenn
shear.hpp
Go to the documentation of this file.
1 
9 #include <list>
10 
11 #include "ade/ade.hpp"
12 
13 #ifndef OPT_SHEAR_HPP
14 #define OPT_SHEAR_HPP
15 
16 namespace opt
17 {
18 
20 template <typename T>
21 using GetLeafValT = std::function<T(ade::iLeaf*)>;
22 
24 using ParentMapT = std::unordered_map<
25  ade::iTensor*,std::unordered_set<size_t>>;
26 
28 using PruneFuncT = std::function<ade::TensptrT(ade::iFunctor*,\
29  std::unordered_set<size_t>,ade::ArgsT)>;
30 
33 template <typename T>
34 struct LeafFinder final : public ade::iTraveler
35 {
36  LeafFinder (T target, GetLeafValT<T> get_leaf) :
37  target_(target), get_leaf_(get_leaf) {}
38 
40  void visit (ade::iLeaf* leaf) override
41  {
42  if (target_ == get_leaf_(leaf))
43  {
44  founds_.emplace(leaf);
45  }
46  }
47 
49  void visit (ade::iFunctor* func) override
50  {
51  if (parents_.end() == parents_.find(func))
52  {
53  auto& children = func->get_children();
54  size_t n = children.size();
55  std::unordered_set<size_t> path;
56  for (size_t i = 0; i < n; ++i)
57  {
58  ade::TensptrT tens = children[i].get_tensor();
59  tens->accept(*this);
60  if (parents_.end() != parents_.find(tens.get()) ||
61  founds_.end() != founds_.find(tens.get()))
62  {
63  path.emplace(i);
64  }
65  }
66  if (false == path.empty())
67  {
68  parents_[func] = path;
69  }
70  }
71  }
72 
75 
78 
80  std::unordered_set<ade::iTensor*> founds_;
81 
84 };
85 
90 template <typename T>
92 {
93  TargetPruner (T target, GetLeafValT<T> find_target, PruneFuncT pruner) :
94  finder_(target, find_target), pruner_(pruner) {}
95 
97  ade::TensptrT prune (ade::TensptrT root)
98  {
99  // assert that context will be unaffected by prune,
100  // since source will never be touched
101  root->accept(finder_);
102  auto& pathmap = finder_.parents_;
103  if (pathmap.empty()) // not path to target or root is not a parent
104  {
105  return root;
106  }
107  ade::GraphStat stat;
108  root->accept(stat);
109  // grab the intersection of stat.funcs_ and pathmap
110  std::list<ade::iFunctor*> parents;
111  std::transform(pathmap.begin(), pathmap.end(),
112  std::back_inserter(parents),
113  [](std::pair<ade::iTensor*,std::unordered_set<size_t>> parent)
114  {
115  return static_cast<ade::iFunctor*>(parent.first);
116  });
117  parents.sort(
118  [&](ade::iTensor* a, ade::iTensor* b)
119  {
120  return stat.graphsize_[a] < stat.graphsize_[b];
121  });
122 
123  // each proceeding node in parents list is closer to target
124  // start pruning according to each parent node in order
125  std::unordered_set<ade::iTensor*> targets = finder_.founds_;
126 
127  std::unordered_map<ade::iTensor*,ade::TensptrT> mapping;
128  std::unordered_map<ade::iTensor*,bool> targetmap;
129  for (ade::iFunctor* func : parents)
130  {
131  ade::ArgsT children = func->get_children();
132  // indices lead to target nodes
133  std::unordered_set<size_t> indices = pathmap[func];
134  for (auto it = indices.begin(), et = indices.end(); it != et;)
135  {
136  ade::MappedTensor& child = children[*it];
137  ade::iTensor* tens = child.get_tensor().get();
138  // child is not target, so erase ot from indices
139  auto zit = targets.find(tens);
140  if (targets.end() == zit)
141  {
142  it = indices.erase(it);
143  }
144  else
145  {
146  ++it;
147  }
148  }
149  auto fwd = pruner_(func, indices, children);
150  mapping.emplace(func, fwd);
151  // func maps to target, so store in targets
152  if (ade::iLeaf* fwdleaf = dynamic_cast<ade::iLeaf*>(fwd.get()))
153  {
154  if (finder_.get_leaf_(fwdleaf) == finder_.target_)
155  {
156  targets.emplace(func);
157  }
158  }
159  }
160  auto it = mapping.find(root.get());
161  if (mapping.end() == it)
162  {
163  logs::fatal(
164  "GraphStat failed to identify children of root subgraph");
165  }
166  return it->second;
167  }
168 
169 private:
172 
175 };
176 
177 }
178 
179 #endif // OPT_SHEAR_HPP
void visit(ade::iLeaf *leaf) override
Implementation of iTraveler.
Definition: shear.hpp:40
PruneFuncT pruner_
Prune functor defining how to prune a given graph.
Definition: shear.hpp:174
std::function< T(ade::iLeaf *)> GetLeafValT
Functor for getting leaf values.
Definition: shear.hpp:21
GetLeafValT< T > get_leaf_
Leaf value getter.
Definition: shear.hpp:77
std::unordered_map< ade::iTensor *, std::unordered_set< size_t > > ParentMapT
Type mapping each functor on a path to the set of its child indices that lead toward the target.
Definition: shear.hpp:25
void visit(ade::iFunctor *func) override
Implementation of iTraveler.
Definition: shear.hpp:49
Definition: shear.hpp:16
TargetPruner(T target, GetLeafValT< T > find_target, PruneFuncT pruner)
Definition: shear.hpp:93
LeafFinder(T target, GetLeafValT< T > get_leaf)
Definition: shear.hpp:36
LeafFinder< T > finder_
Target finding traveler.
Definition: shear.hpp:171
Definition: shear.hpp:34
T target_
Target leaf value that all recorded paths lead to.
Definition: shear.hpp:74
std::function< ade::TensptrT(ade::iFunctor *, std::unordered_set< size_t >, ade::ArgsT)> PruneFuncT
Pruning functor type.
Definition: shear.hpp:29
std::unordered_set< ade::iTensor * > founds_
Set of leaf nodes found.
Definition: shear.hpp:80
ParentMapT parents_
Map of parent nodes in path.
Definition: shear.hpp:83
Definition: shear.hpp:91
ade::TensptrT prune(ade::TensptrT root)
Prune graph of root Tensptr.
Definition: shear.hpp:97