Tenncor
rmdups.hpp
Go to the documentation of this file.
1 
9 #include "opt/stats.hpp"
10 
11 #ifndef OPT_RMDUPS_HPP
12 #define OPT_RMDUPS_HPP
13 
14 namespace opt
15 {
16 
20 void replace_parents (const teq::ParentFinder& pfinder,
21  teq::iTensor* source, teq::TensptrT target);
22 
29 template <typename T>
30 std::vector<T> remove_duplicates (teq::TensptrsT& roots, std::vector<T> tens,
31  const teq::ParentFinder& pfinder,
32  tag::TagRegistry& registry = tag::get_reg())
33 {
34  if (tens.empty())
35  {
36  return {};
37  }
38 
39  teq::TensSetT priorities;
40  std::unordered_map<teq::iTensor*,std::vector<size_t>> rindices;
41  for (size_t i = 0, n = roots.size(); i < n; ++i)
42  {
43  teq::TensptrT& root = roots[i];
44  priorities.emplace(root.get());
45  rindices[root.get()].push_back(i);
46  }
47 
48  std::sort(tens.begin(), tens.end(),
49  [&priorities](T& a, T& b) { return lt(priorities, a.get(), b.get()); });
50  T last = tens[0];
51  std::vector<T> uniques = {last};
52  size_t n = tens.size();
53  uniques.reserve(n - 1);
54  for (size_t i = 1; i < n; ++i)
55  {
56  T& cur = tens[i];
57  if (is_equal(last.get(), cur.get()))
58  {
59  logs::debugf("replacing %s", cur->to_string().c_str());
60  // remove equivalent node
61  replace_parents(pfinder, cur.get(), last);
62 
63  auto it = rindices.find(cur.get());
64  if (rindices.end() != it)
65  {
66  for (size_t ridx : it->second)
67  {
68  roots[ridx] = last;
69  }
70  }
71 
72  // todo: mark parents as uninitialized, reinitialize entire graph, or uninitialize everything to begin with
73 
74  // inherit tags
75  registry.move_tags(last, cur.get());
76  }
77  else
78  {
79  uniques.push_back(cur);
80  last = cur;
81  }
82  }
83  return uniques;
84 }
85 
87 using ImmutablesT = std::vector<teq::LeafptrT>;
88 
90 using HFunctorsT = std::vector<std::vector<teq::FuncptrT>>;
91 
94 void populate_graph (ImmutablesT& immutables, HFunctorsT& functors,
95  const teq::TensptrsT& roots);
96 
100  ImmutablesT& immutables, HFunctorsT& functors);
101 
102 }
103 
104 #endif // OPT_RMDUPS_HPP
std::vector< std::vector< teq::FuncptrT > > HFunctorsT
Matrix of functors.
Definition: rmdups.hpp:90
bool is_equal(teq::CoordptrT a, teq::CoordptrT b)
Return true if a is equal to b.
std::unordered_set< teq::iTensor * > TensSetT
Hash set of raw tensor pointers.
Definition: itensor.hpp:63
bool lt(teq::CoordptrT a, teq::CoordptrT b)
Return true if a < b according to some internal ordinal rule.
tag::TagRegistry: registry for associating tensors with tag collectives.
Definition: tag.hpp:165
void populate_graph(ImmutablesT &immutables, HFunctorsT &functors, const teq::TensptrsT &roots)
Definition: candidate.hpp:19
std::vector< teq::LeafptrT > ImmutablesT
Vector of presumably immutable leaves.
Definition: rmdups.hpp:87
teq::ParentFinder: traveler that, for each child, tracks its relationship to all of its parents.
Definition: traveler.hpp:162
std::shared_ptr< iTensor > TensptrT
Tensor smart pointer.
Definition: itensor.hpp:51
std::vector< TensptrT > TensptrsT
Vector of tensor smart pointers.
Definition: itensor.hpp:60
void replace_parents(const teq::ParentFinder &pfinder, teq::iTensor *source, teq::TensptrT target)
TagRegistry & get_reg(void)
Return reference to global tag registry.
teq::iTensor: interface of traversable and differentiable nodes with shape information.
Definition: itensor.hpp:36
void remove_all_duplicates(teq::TensptrsT &roots, ImmutablesT &immutables, HFunctorsT &functors)
std::vector< T > remove_duplicates(teq::TensptrsT &roots, std::vector< T > tens, const teq::ParentFinder &pfinder, tag::TagRegistry &registry=tag::get_reg())
Definition: rmdups.hpp:30