21 using TagRepsT = std::map<std::string,std::vector<std::string>>;
27 virtual ~iTag (
void) =
default;
30 virtual size_t tag_id (
void)
const = 0;
33 virtual void absorb (std::unique_ptr<iTag>&& other) = 0;
57 tags_ = std::move(other.tags_);
65 for (
auto& tagpair : other.tags_)
67 size_t tid = tagpair.first;
68 auto it =
tags_.find(tid);
69 if (
tags_.end() == it)
71 tags_.emplace(tid, std::move(tagpair.second));
75 it->second->absorb(std::move(tagpair.second));
85 size_t tid = entry->tag_id();
86 auto it =
tags_.find(tid);
87 if (
tags_.end() == it)
89 tags_.emplace(tid, std::move(entry));
93 it->second->absorb(std::move(entry));
101 for (
auto& tpair :
tags_)
103 auto temp = tpair.second->get_tags();
104 tags.insert(temp.begin(), temp.end());
110 std::unordered_map<size_t,TagptrT>
tags_;
132 return ref_.expired();
148 return std::hash<const void*>()(key.
val_);
156 return hasher(lhs) == hasher(rhs);
160 using TagrF = std::function<void(teq::TensrefT,std::string)>;
172 logs::fatal(
"cannot tag with expired tensor ref");
176 if (
registry_.end() != it && it->first.expired())
187 if (
registry_.end() == it || it->first.expired())
191 return it->second.get_tags();
203 logs::fatal(
"cannot move with expired destination tensor");
207 if (
registry_.end() == src_it || src_it->first.expired())
212 if (
registry_.end() == dest_it || dest_it->first.expired())
214 registry_[dest] = std::move(src_it->second);
218 dest_it->second.absorb(std::move(src_it->second));
227 "cannot find tagr associated with %s", tag_key.c_str());
239 std::unordered_map<TensKey,TagCollective,TensKeyHash>
registry_;
255 using LTensT = std::unordered_map<std::string,std::vector<teq::iTensor*>>;
258 using TTensT = std::unordered_map<std::string,LTensT>;
277 for (
auto child : children)
279 child.get_tensor()->accept(*
this);
295 for (
auto& tpair :
tag)
297 auto& labs =
labels_[tpair.first];
298 for (
auto lpair : tpair.second)
300 labs[lpair].push_back(tens);
308 #endif // TAG_TAG_HPP
TensKey hasher.
Definition: tag.hpp:144
std::unordered_set< teq::iTensor * > TensSetT
Hash set of raw tensor pointers.
Definition: itensor.hpp:63
void recursive_tag(teq::TensptrT root, teq::TensSetT stops, std::function< void(teq::TensrefT)> tag_op)
std::function< void(teq::TensrefT, std::string)> TagrF
Function that associates a tag key with a tensor ref.
Definition: tag.hpp:160
TagRepsT get_tags(void) const
Return all key-values under collected iTag.
Definition: tag.hpp:98
virtual const ArgsT & get_children(void) const =0
Return children nodes as a vector of raw pointers.
virtual void absorb(std::unique_ptr< iTag > &&other)=0
Add the key-labels pairs of other.
const teq::iTensor * val_
Definition: tag.hpp:137
Interface of iOperation-defined operation node.
Definition: ifunctor.hpp:28
Registry for associating tensors to tag collectives.
Definition: tag.hpp:165
std::unordered_map< std::string, TagrF > key_tagr_assoc_
Definition: tag.hpp:242
TensKey(teq::TensrefT tens)
Definition: tag.hpp:116
std::unordered_map< std::string, std::vector< teq::iTensor * > > LTensT
Map tag label to any tensor with label.
Definition: tag.hpp:255
virtual TagRepsT get_tags(void) const =0
Return all key-labels associations.
void move_tags(teq::TensrefT dest, const teq::iTensor *source)
Definition: tag.hpp:199
std::unique_ptr< iTag > TagptrT
Unique pointer of tag.
Definition: tag.hpp:40
TagCollective(void)=default
virtual ~iTag(void)=default
std::map< std::string, std::vector< std::string > > TagRepsT
Map tag key to a series of labels.
Definition: tag.hpp:21
void visit_func(teq::iFunctor *func) override
Gather the tag key-label to tensor association of visited functor.
Definition: tag.hpp:274
std::unordered_map< TensKey, TagCollective, TensKeyHash > registry_
Map tensor to tag collective.
Definition: tag.hpp:239
void visit_leaf(teq::iLeaf *leaf) override
Gather the tag key-label to tensor association of visited leaf.
Definition: tag.hpp:267
TagRegistry & reg_
Tag registry providing the reverse tensor to <key:labels> associations.
Definition: tag.hpp:290
void save_tags(TagRepsT &tag, teq::iTensor *tens)
Definition: tag.hpp:293
std::string register_tagr(std::string tag_key, TagrF tagr)
Definition: tag.hpp:232
Query(TagRegistry &reg=get_reg())
Definition: tag.hpp:264
Collective of generic iTag instances.
Definition: tag.hpp:47
void add(TagptrT entry)
Definition: tag.hpp:83
bool operator==(const TensKey &lhs, const TensKey &rhs)
TensKey equality overload.
Definition: tag.hpp:153
std::unordered_map< size_t, TagptrT > tags_
Definition: tag.hpp:110
std::shared_ptr< iTensor > TensptrT
Tensor smart pointer.
Definition: itensor.hpp:51
TensKey(teq::iTensor *tens)
Definition: tag.hpp:119
size_t operator()(const TensKey &key) const
Definition: tag.hpp:146
teq::TensrefT ref_
Weak reference of tensor.
Definition: tag.hpp:140
TTensT labels_
Map <key:label> to tensors found under tag registry.
Definition: tag.hpp:287
TagrF tagr_by_key(std::string tag_key)
Return tagger associated with TagRepsT key.
Definition: tag.hpp:224
Tensor ref key wrapper.
Definition: tag.hpp:114
TagCollective & operator=(TagCollective &&other)
Definition: tag.hpp:53
TagRegistry & get_reg(void)
Return reference to global tag registry.
TagRepsT get_tags(const teq::iTensor *tens)
Return all key-labels under the collective associated with tens.
Definition: tag.hpp:184
bool expired(void) const
Return true if weak reference is expired.
Definition: tag.hpp:130
Interface of traversable and differentiable nodes with shape information.
Definition: itensor.hpp:36
void add_tag(teq::TensrefT tens, TagptrT tag)
Add tag to collective referenced by tens.
Definition: tag.hpp:168
Extremely generic traveler that visits every node in the graph once.
Definition: traveler.hpp:22
TagCollective(TagCollective &&other)
Definition: tag.hpp:51
virtual size_t tag_id(void) const =0
Return type hash of tag instance.
std::weak_ptr< iTensor > TensrefT
Tensor weak pointers.
Definition: itensor.hpp:54
void absorb(TagCollective &&other)
Absorb iTags of the other collective.
Definition: tag.hpp:63
Leaf of the graph commonly representing the variable in an equation.
Definition: ileaf.hpp:19
std::unordered_map< std::string, LTensT > TTensT
Map tag key to label-tensor association.
Definition: tag.hpp:258
TensKey(const teq::iTensor *tens)
Definition: tag.hpp:121