11 #include <unordered_set> 20 #ifndef ETEQ_SESSION_HPP 21 #define ETEQ_SESSION_HPP 53 tracked_.insert(roots.begin(), roots.end());
62 for (
auto& statpair : statmap)
64 if (0 < statpair.second.upper_)
70 logs::fatalf(
"cannot track non-operable functor %s",
71 statpair.first->to_string().c_str());
78 {
return statmap[a].upper_ < statmap[b].upper_; });
84 std::list<teq::iOperableFunc*> reqs;
88 acceptable.emplace(root.get());
91 for (
auto rit =
ops_.rbegin(), ret =
ops_.rend();
95 if (estd::has(acceptable, op) &&
96 false == estd::has(ignored, op))
99 auto& children = op->get_children();
100 for (
auto& child : children)
102 acceptable.emplace(child.get_tensor().get());
107 for (
auto& op : reqs)
116 std::list<teq::iOperableFunc*> reqs;
118 for (
auto& root : target)
120 acceptable.emplace(root);
123 for (
auto rit =
ops_.rbegin(), ret =
ops_.rend();
127 if (estd::has(acceptable, op) &&
128 false == estd::has(ignored, op))
131 auto& children = op->get_children();
132 for (
auto& child : children)
134 acceptable.emplace(child.get_tensor().get());
139 for (
auto& op : reqs)
158 std::vector<teq::iOperableFunc*>
ops_;
163 #endif // ETEQ_SESSION_HPP std::unordered_set< teq::iTensor * > TensSetT
Hash set of raw tensor pointers.
Definition: itensor.hpp:63
std::unordered_map< iTensor *, estd::NumRange< size_t > > graphsize_
Definition: traveler.hpp:105
A functor node with direct access to evaluated data.
Definition: iopfunc.hpp:20
std::unordered_set< teq::TensptrT > TensptrSetT
Hash set of tensor smart pointers.
Definition: itensor.hpp:66
Definition: constant.hpp:17
virtual void update_target(teq::TensSetT target, teq::TensSetT ignored={})=0
Definition: session.hpp:47
void track(teq::TensptrsT roots) override
Implementation of iSession.
Definition: session.hpp:50
void optimize(const opt::OptCtx &rules)
Apply input optimization rules using opt module, then re-track.
Definition: session.hpp:146
teq::TensptrsT optimize(teq::TensptrsT roots, const OptCtx &opts)
Session interface that tracks and rapidly updates subgraphs.
Definition: session.hpp:27
std::vector< teq::iOperableFunc * > ops_
Operable functors ordered by height in the tracked graph.
Definition: session.hpp:158
teq::TensptrSetT tracked_
Definition: session.hpp:155
Encapsulation of all conversion rules.
Definition: optimize.hpp:23
Traveler that maps each tensor to its subtree's maximum depth.
Definition: traveler.hpp:57
std::shared_ptr< iTensor > TensptrT
Tensor smart pointer.
Definition: itensor.hpp:51
std::vector< TensptrT > TensptrsT
Vector of tensor smart pointers.
Definition: itensor.hpp:60
virtual ~iSession(void)=default
virtual void track(teq::TensptrsT roots)=0
Record subgraphs of roots.
void update(teq::TensSetT ignored={}) override
Implementation of iSession.
Definition: session.hpp:82
void update_target(teq::TensSetT target, teq::TensSetT ignored={}) override
Implementation of iSession.
Definition: session.hpp:114
virtual void update(teq::TensSetT ignored={})=0