Tenncor
tag.hpp
Go to the documentation of this file.
1 
9 #include <map>
10 #include <set>
11 
12 #include "teq/teq.hpp"
13 
14 #ifndef TAG_TAG_HPP
15 #define TAG_TAG_HPP
16 
17 namespace tag
18 {
19 
struct iTag;

/// Owning handle to a single tag instance
using TagptrT = std::unique_ptr<iTag>;

/// Tag key mapped to the series of labels registered under that key
using TagRepsT = std::map<std::string,std::vector<std::string>>;

/// Interface implemented by every kind of tag, so that heterogeneous tags
/// can be stored together (grouped by tag_id) and merged into one another
struct iTag
{
	virtual ~iTag (void) = default;

	/// Return the type hash shared by all tags of the same kind
	virtual size_t tag_id (void) const = 0;

	/// Add the key and labels pairs of other into this tag
	/// (TagCollective only forwards tags whose tag_id matches this one)
	virtual void absorb (TagptrT&& other) = 0;

	/// Return all key-labels associations carried by this tag
	virtual TagRepsT get_tags (void) const = 0;
};
41 
47 struct TagCollective final
48 {
49  TagCollective (void) = default;
50 
51  TagCollective (TagCollective&& other) : tags_(std::move(other.tags_)) {}
52 
54  {
55  if (this != &other)
56  {
57  tags_ = std::move(other.tags_);
58  }
59  return *this;
60  }
61 
63  void absorb (TagCollective&& other)
64  {
65  for (auto& tagpair : other.tags_)
66  {
67  size_t tid = tagpair.first;
68  auto it = tags_.find(tid);
69  if (tags_.end() == it)
70  {
71  tags_.emplace(tid, std::move(tagpair.second));
72  }
73  else
74  {
75  it->second->absorb(std::move(tagpair.second));
76  }
77  }
78  other.tags_.clear();
79  }
80 
83  void add (TagptrT entry)
84  {
85  size_t tid = entry->tag_id();
86  auto it = tags_.find(tid);
87  if (tags_.end() == it)
88  {
89  tags_.emplace(tid, std::move(entry));
90  }
91  else
92  {
93  it->second->absorb(std::move(entry));
94  }
95  }
96 
98  TagRepsT get_tags (void) const
99  {
100  TagRepsT tags;
101  for (auto& tpair : tags_)
102  {
103  auto temp = tpair.second->get_tags();
104  tags.insert(temp.begin(), temp.end());
105  }
106  return tags;
107  }
108 
109 private:
110  std::unordered_map<size_t,TagptrT> tags_;
111 };
112 
114 struct TensKey final
115 {
116  TensKey (teq::TensrefT tens) : val_(tens.lock().get()), ref_(tens) {}
117 
118  // used to match keys
119  TensKey (teq::iTensor* tens) : val_(tens) {}
120 
121  TensKey (const teq::iTensor* tens) : val_(tens) {}
122 
124  operator const teq::iTensor*() const
125  {
126  return val_;
127  }
128 
130  bool expired (void) const
131  {
132  return ref_.expired();
133  }
134 
138 
141 };
142 
144 struct TensKeyHash final
145 {
146  size_t operator() (const TensKey& key) const
147  {
148  return std::hash<const void*>()(key.val_);
149  }
150 };
151 
153 inline bool operator == (const TensKey& lhs, const TensKey& rhs)
154 {
155  TensKeyHash hasher;
156  return hasher(lhs) == hasher(rhs);
157 }
158 
160 using TagrF = std::function<void(teq::TensrefT,std::string)>;
161 
162 // todo: move tag registry to some session that claims global context
163 // todo: make an interface for this
165 struct TagRegistry final
166 {
169  {
170  if (tens.expired())
171  {
172  logs::fatal("cannot tag with expired tensor ref");
173  }
174  auto it = registry_.find(TensKey(tens));
175  // clear out previous entry that is expired
176  if (registry_.end() != it && it->first.expired())
177  {
178  registry_.erase(tens.lock().get());
179  }
180  registry_[tens].add(std::move(tag));
181  }
182 
185  {
186  auto it = registry_.find(TensKey(tens));
187  if (registry_.end() == it || it->first.expired())
188  {
189  return {};
190  }
191  return it->second.get_tags();
192  }
193 
199  void move_tags (teq::TensrefT dest, const teq::iTensor* source)
200  {
201  if (dest.expired())
202  {
203  logs::fatal("cannot move with expired destination tensor");
204  }
205  auto src_it = registry_.find(TensKey(source));
206  auto dest_it = registry_.find(TensKey(dest));
207  if (registry_.end() == src_it || src_it->first.expired())
208  {
209  return;
210  }
211 
212  if (registry_.end() == dest_it || dest_it->first.expired())
213  {
214  registry_[dest] = std::move(src_it->second);
215  }
216  else
217  {
218  dest_it->second.absorb(std::move(src_it->second));
219  }
220  registry_.erase(TensKey(source));
221  }
222 
224  TagrF tagr_by_key (std::string tag_key)
225  {
226  return estd::must_getf(key_tagr_assoc_, tag_key,
227  "cannot find tagr associated with %s", tag_key.c_str());
228  }
229 
232  std::string register_tagr (std::string tag_key, TagrF tagr)
233  {
234  key_tagr_assoc_.emplace(tag_key, tagr);
235  return tag_key;
236  }
237 
239  std::unordered_map<TensKey,TagCollective,TensKeyHash> registry_;
240 
241 private:
242  std::unordered_map<std::string,TagrF> key_tagr_assoc_;
243 };
244 
246 TagRegistry& get_reg (void);
247 
250 void recursive_tag (teq::TensptrT root,
251  teq::TensSetT stops,
252  std::function<void(teq::TensrefT)> tag_op);
253 
255 using LTensT = std::unordered_map<std::string,std::vector<teq::iTensor*>>;
256 
258 using TTensT = std::unordered_map<std::string,LTensT>;
259 
262 struct Query final : public teq::OnceTraveler
263 {
264  Query (TagRegistry& reg = get_reg()) : reg_(reg) {}
265 
267  void visit_leaf (teq::iLeaf* leaf) override
268  {
269  auto tags = reg_.get_tags(leaf);
270  save_tags(tags, leaf);
271  }
272 
274  void visit_func (teq::iFunctor* func) override
275  {
276  auto& children = func->get_children();
277  for (auto child : children)
278  {
279  child.get_tensor()->accept(*this);
280  }
281 
282  auto tags = reg_.get_tags(func);
283  save_tags(tags, func);
284  }
285 
288 
291 
292 private:
294  {
295  for (auto& tpair : tag)
296  {
297  auto& labs = labels_[tpair.first];
298  for (auto lpair : tpair.second)
299  {
300  labs[lpair].push_back(tens);
301  }
302  }
303  }
304 };
305 
306 }
307 
308 #endif // TAG_TAG_HPP
TensKey hasher.
Definition: tag.hpp:144
std::unordered_set< teq::iTensor * > TensSetT
Hash set of raw tensor pointers.
Definition: itensor.hpp:63
void recursive_tag(teq::TensptrT root, teq::TensSetT stops, std::function< void(teq::TensrefT)> tag_op)
std::function< void(teq::TensrefT, std::string)> TagrF
Function that associates a tag key with a tensor ref.
Definition: tag.hpp:160
TagRepsT get_tags(void) const
Return all key-labels pairs under the collected iTags.
Definition: tag.hpp:98
virtual const ArgsT & get_children(void) const =0
Return children nodes as a vector of raw pointers.
Definition: tag.hpp:25
virtual void absorb(std::unique_ptr< iTag > &&other)=0
Add the key and labels pairs of other.
const teq::iTensor * val_
Definition: tag.hpp:137
Interface of iOperation-defined operation node.
Definition: ifunctor.hpp:28
Registry for associating tensors to tag collectives.
Definition: tag.hpp:165
std::unordered_map< std::string, TagrF > key_tagr_assoc_
Definition: tag.hpp:242
TensKey(teq::TensrefT tens)
Definition: tag.hpp:116
std::unordered_map< std::string, std::vector< teq::iTensor * > > LTensT
Map tag label to every tensor with that label.
Definition: tag.hpp:255
virtual TagRepsT get_tags(void) const =0
Return all key-labels associations.
void move_tags(teq::TensrefT dest, const teq::iTensor *source)
Definition: tag.hpp:199
Definition: group.hpp:17
std::unique_ptr< iTag > TagptrT
Unique pointer of tag.
Definition: tag.hpp:40
TagCollective(void)=default
virtual ~iTag(void)=default
std::map< std::string, std::vector< std::string > > TagRepsT
Map tag key to a series of labels.
Definition: tag.hpp:21
void visit_func(teq::iFunctor *func) override
Gather the tag key-label to tensor associations of the visited functor.
Definition: tag.hpp:274
std::unordered_map< TensKey, TagCollective, TensKeyHash > registry_
Map tensor to tag collective.
Definition: tag.hpp:239
void visit_leaf(teq::iLeaf *leaf) override
Gather the tag key-label to tensor associations of the visited leaf.
Definition: tag.hpp:267
TagRegistry & reg_
Tag registry providing the reverse tensor to <key:labels> associations.
Definition: tag.hpp:290
void save_tags(TagRepsT &tag, teq::iTensor *tens)
Definition: tag.hpp:293
std::string register_tagr(std::string tag_key, TagrF tagr)
Definition: tag.hpp:232
Query(TagRegistry &reg=get_reg())
Definition: tag.hpp:264
Collective of generic iTag instances.
Definition: tag.hpp:47
void add(TagptrT entry)
Definition: tag.hpp:83
bool operator==(const TensKey &lhs, const TensKey &rhs)
TensKey equality overload.
Definition: tag.hpp:153
Definition: tag.hpp:262
std::unordered_map< size_t, TagptrT > tags_
Definition: tag.hpp:110
std::shared_ptr< iTensor > TensptrT
Tensor smart pointer.
Definition: itensor.hpp:51
TensKey(teq::iTensor *tens)
Definition: tag.hpp:119
size_t operator()(const TensKey &key) const
Definition: tag.hpp:146
teq::TensrefT ref_
Weak reference of tensor.
Definition: tag.hpp:140
TTensT labels_
Map <key:label> to tensors found under the tag registry.
Definition: tag.hpp:287
TagrF tagr_by_key(std::string tag_key)
Return tagger associated to TagRepsT key.
Definition: tag.hpp:224
Tensor ref key wrapper.
Definition: tag.hpp:114
TagCollective & operator=(TagCollective &&other)
Definition: tag.hpp:53
TagRegistry & get_reg(void)
Return reference to global tag registry.
TagRepsT get_tags(const teq::iTensor *tens)
Return all key-labels under the collective associated with tens.
Definition: tag.hpp:184
bool expired(void) const
Return true if weak reference is expired.
Definition: tag.hpp:130
Interface of traversible and differentiable nodes with shape information.
Definition: itensor.hpp:36
void add_tag(teq::TensrefT tens, TagptrT tag)
Add tag to collective referenced by tens.
Definition: tag.hpp:168
Extremely generic traveler that visits every node in the graph once.
Definition: traveler.hpp:22
TagCollective(TagCollective &&other)
Definition: tag.hpp:51
virtual size_t tag_id(void) const =0
Return type hash of tag instance.
std::weak_ptr< iTensor > TensrefT
Tensor weak pointers.
Definition: itensor.hpp:54
void absorb(TagCollective &&other)
Absorb iTags of the other collective.
Definition: tag.hpp:63
Leaf of the graph commonly representing the variable in an equation.
Definition: ileaf.hpp:19
std::unordered_map< std::string, LTensT > TTensT
Map tag key to label-tensor association.
Definition: tag.hpp:258
TensKey(const teq::iTensor *tens)
Definition: tag.hpp:121