Tenncor
session.hpp
Go to the documentation of this file.
1 
#include <algorithm>
#include <list>
#include <unordered_set>
#include <utility>
#include <vector>

#include "teq/traveler.hpp"

#include "opt/optimize.hpp"

#include "eteq/constant.hpp"
#include "eteq/functor.hpp"
19 
20 #ifndef ETEQ_SESSION_HPP
21 #define ETEQ_SESSION_HPP
22 
23 namespace eteq
24 {
25 
27 struct iSession
28 {
29  virtual ~iSession (void) = default;
30 
32  virtual void track (teq::TensptrsT roots) = 0;
33 
37  virtual void update (teq::TensSetT ignored = {}) = 0;
38 
42  virtual void update_target (teq::TensSetT target, teq::TensSetT ignored = {}) = 0;
43 };
44 
47 struct Session final : public iSession
48 {
50  void track (teq::TensptrsT roots) override
51  {
52  ops_.clear();
53  tracked_.insert(roots.begin(), roots.end());
54 
55  teq::GraphStat stat;
56  for (teq::TensptrT& root : roots)
57  {
58  root->accept(stat);
59  }
60  auto& statmap = stat.graphsize_;
61 
62  for (auto& statpair : statmap)
63  {
64  if (0 < statpair.second.upper_)
65  {
66  // ensure we only track operable functors
67  auto op = dynamic_cast<teq::iOperableFunc*>(statpair.first);
68  if (nullptr == op)
69  {
70  logs::fatalf("cannot track non-operable functor %s",
71  statpair.first->to_string().c_str());
72  }
73  ops_.push_back(op);
74  }
75  }
76  std::sort(ops_.begin(), ops_.end(),
77  [&statmap](teq::iOperableFunc* a, teq::iOperableFunc* b)
78  { return statmap[a].upper_ < statmap[b].upper_; });
79  }
80 
82  void update (teq::TensSetT ignored = {}) override
83  {
84  std::list<teq::iOperableFunc*> reqs;
85  teq::TensSetT acceptable;
86  for (auto& root : tracked_)
87  {
88  acceptable.emplace(root.get());
89  }
90  // ignored tensors will never populate reqs
91  for (auto rit = ops_.rbegin(), ret = ops_.rend();
92  rit != ret; ++rit)
93  {
94  auto& op = *rit;
95  if (estd::has(acceptable, op) &&
96  false == estd::has(ignored, op))
97  {
98  reqs.push_front(op);
99  auto& children = op->get_children();
100  for (auto& child : children)
101  {
102  acceptable.emplace(child.get_tensor().get());
103  }
104  }
105  }
106 
107  for (auto& op : reqs)
108  {
109  op->update();
110  }
111  }
112 
114  void update_target (teq::TensSetT target, teq::TensSetT ignored = {}) override
115  {
116  std::list<teq::iOperableFunc*> reqs;
117  teq::TensSetT acceptable;
118  for (auto& root : target)
119  {
120  acceptable.emplace(root);
121  }
122  // ignored tensors will never populate reqs
123  for (auto rit = ops_.rbegin(), ret = ops_.rend();
124  rit != ret; ++rit)
125  {
126  auto& op = *rit;
127  if (estd::has(acceptable, op) &&
128  false == estd::has(ignored, op))
129  {
130  reqs.push_front(op);
131  auto& children = op->get_children();
132  for (auto& child : children)
133  {
134  acceptable.emplace(child.get_tensor().get());
135  }
136  }
137  }
138 
139  for (auto& op : reqs)
140  {
141  op->update();
142  }
143  }
144 
146  void optimize (const opt::OptCtx& rules)
147  {
148  teq::TensptrsT tracked(tracked_.begin(), tracked_.end());
149  opt::optimize(tracked, rules);
150  track(tracked);
151  }
152 
156 
158  std::vector<teq::iOperableFunc*> ops_;
159 };
160 
161 }
162 
163 #endif // ETEQ_SESSION_HPP
std::unordered_set< teq::iTensor * > TensSetT
Hash set of raw tensor pointers.
Definition: itensor.hpp:63
std::unordered_map< iTensor *, estd::NumRange< size_t > > graphsize_
Definition: traveler.hpp:105
A functor node with direct access to evaluated data.
Definition: iopfunc.hpp:20
std::unordered_set< teq::TensptrT > TensptrSetT
Hash set of tensor smart pointers.
Definition: itensor.hpp:66
Definition: constant.hpp:17
virtual void update_target(teq::TensSetT target, teq::TensSetT ignored={})=0
Definition: session.hpp:47
void track(teq::TensptrsT roots) override
Implementation of iSession.
Definition: session.hpp:50
void optimize(const opt::OptCtx &rules)
Apply input optimization rules using opt module, then re-track.
Definition: session.hpp:146
teq::TensptrsT optimize(teq::TensptrsT roots, const OptCtx &opts)
Session interface that tracks and rapidly updates subgraphs.
Definition: session.hpp:27
std::vector< teq::iOperableFunc * > ops_
Operable functors ordered by height in the tracked graph.
Definition: session.hpp:158
teq::TensptrSetT tracked_
Definition: session.hpp:155
Encapsulation of all conversion rules.
Definition: optimize.hpp:23
Traveler that maps each tensor to its subtree's maximum depth.
Definition: traveler.hpp:57
std::shared_ptr< iTensor > TensptrT
Tensor smart pointer.
Definition: itensor.hpp:51
std::vector< TensptrT > TensptrsT
Vector of tensor smart pointers.
Definition: itensor.hpp:60
virtual ~iSession(void)=default
virtual void track(teq::TensptrsT roots)=0
Record subgraphs of roots.
void update(teq::TensSetT ignored={}) override
Implementation of iSession.
Definition: session.hpp:82
void update_target(teq::TensSetT target, teq::TensSetT ignored={}) override
Implementation of iSession.
Definition: session.hpp:114
virtual void update(teq::TensSetT ignored={})=0