11 #include <boost/asio/thread_pool.hpp> 12 #include <boost/asio/post.hpp> 27 using SessReqsT = std::vector<std::pair<teq::iOperableFunc*,size_t>>;
30 using LSessReqsT = std::list<std::pair<teq::iOperableFunc*,size_t>>;
47 tracked_.insert(roots.begin(), roots.end());
57 root->accept(pfinder);
63 for (
auto& group : groups)
66 reqs.reserve(group.size());
69 auto&
args = func->get_children();
73 auto tens = arg.get_tensor().get();
76 unique_children.emplace(tens);
81 unique_children.size()
87 for (
auto& assocs : pfinder.
parents_)
89 for (
auto& parent_pair : assocs.second)
92 static_cast<teq::iOperableFunc*>(parent_pair.first));
99 if (tpair.second.upper_ > 0)
101 ops_.emplace(static_cast<teq::iOperableFunc*>(tpair.first));
110 std::vector<LSessReqsT> indep_requirements(nthreads);
111 for (
size_t i = 0; i < nthreads; ++i)
114 auto& indep_reqs = indep_requirements[i];
118 acceptable.emplace(root.get());
121 for (
auto rit = reqs.rbegin(), ret = reqs.rend();
124 auto& op = rit->first;
125 if (estd::has(acceptable, op) &&
126 false == estd::has(ignored, op))
128 indep_reqs.push_front({op, rit->second});
129 auto& children = op->get_children();
130 for (
auto& child : children)
132 acceptable.emplace(child.get_tensor().get());
141 fulfilments.emplace(op, 0);
144 for (
auto ig : ignored)
146 std::unordered_set<teq::iOperableFunc*> op_parents;
147 if (estd::get(op_parents,
parents_, ig))
149 for (
auto& op_parent : op_parents)
151 ++fulfilments.at(op_parent);
157 boost::asio::thread_pool pool(nthreads);
158 for (
auto& reqs : indep_requirements)
161 boost::asio::post(pool,
162 [
this, &reqs, &fulfilments]()
164 for (
auto& op : reqs)
167 auto& ff = fulfilments.at(op.first);
168 if (ff++ == op.second)
171 std::unordered_set<teq::iOperableFunc*> op_parents;
172 if (estd::get(op_parents,
175 for (
auto& op_parent : op_parents)
177 ++fulfilments.at(op_parent);
194 std::vector<LSessReqsT> indep_requirements(nthreads);
195 for (
size_t i = 0; i < nthreads; ++i)
198 auto& indep_reqs = indep_requirements[i];
200 for (
auto& root : target)
202 acceptable.emplace(root);
205 for (
auto rit = reqs.rbegin(), ret = reqs.rend();
208 auto& op = rit->first;
209 if (estd::has(acceptable, op) &&
210 false == estd::has(ignored, op))
212 indep_reqs.push_front({op, rit->second});
213 auto& children = op->get_children();
214 for (
auto& child : children)
216 acceptable.emplace(child.get_tensor().get());
225 fulfilments.emplace(op, 0);
228 for (
auto ig : ignored)
230 std::unordered_set<teq::iOperableFunc*> op_parents;
231 if (estd::get(op_parents,
parents_, ig))
233 for (
auto& op_parent : op_parents)
235 ++fulfilments.at(op_parent);
241 boost::asio::thread_pool pool(nthreads);
242 for (
auto& reqs : indep_requirements)
245 boost::asio::post(pool,
246 [
this, &reqs, &fulfilments]()
248 for (
auto& op : reqs)
251 auto& ff = fulfilments.at(op.first);
252 if (ff++ == op.second)
255 std::unordered_set<teq::iOperableFunc*> op_parents;
256 if (estd::get(op_parents,
259 for (
auto& op_parent : op_parents)
261 ++fulfilments.at(op_parent);
299 std::unordered_set<teq::iOperableFunc*>
ops_;
304 #endif // CCUR_SESS_HPP args
Definition: csv_to_png.py:105
std::unordered_set< teq::iTensor * > TensSetT
Hash set of raw tensor pointers.
Definition: itensor.hpp:63
std::unordered_map< iTensor *, estd::NumRange< size_t > > graphsize_
Definition: traveler.hpp:105
A functor node with direct access to evaluated data.
Definition: iopfunc.hpp:20
std::vector< std::pair< teq::iOperableFunc *, size_t > > SessReqsT
Definition: session.hpp:27
std::vector< std::vector< teq::iFunctor * > > PartGroupsT
Groups of functors.
Definition: partition.hpp:20
size_t nthreads_
Definition: session.hpp:295
Interface of iOperation-defined operation node.
Definition: ifunctor.hpp:28
std::unordered_set< teq::TensptrT > TensptrSetT
Hash set of tensor smart pointers.
Definition: itensor.hpp:66
std::unordered_map< iTensor *, ParentMapT > parents_
Definition: traveler.hpp:189
teq::TensptrsT optimize(teq::TensptrsT roots, const OptCtx &opts)
Session interface that tracks and rapidly updates subgraphs.
Definition: session.hpp:27
teq::TensptrSetT tracked_
Definition: session.hpp:284
std::unordered_map< teq::iOperableFunc *, std::atomic< long > > AtomicFulfilMapT
Definition: session.hpp:35
std::unordered_set< teq::iOperableFunc * > ops_
Definition: session.hpp:299
Traveler that for each child tracks the relationship to all parents.
Definition: traveler.hpp:162
Coordinate mapper and tensor pair.
Definition: funcarg.hpp:21
Encapsulation of all conversion rules.
Definition: optimize.hpp:23
Traveler that maps each tensor to its subtree's maximum depth.
Definition: traveler.hpp:57
Definition: session.hpp:39
Definition: partition.hpp:16
std::unordered_map< teq::iTensor *, std::unordered_set< teq::iOperableFunc * > > parents_
Map of tensor to the set of the tensor's parents.
Definition: session.hpp:288
PartGroupsT k_partition(teq::TensptrsT roots, size_t k, OpWeightT weights=OpWeightT())
Return k groups of graphs under roots given some weight.
void update_target(teq::TensSetT target, teq::TensSetT ignored={}) override
Implementation of iSession.
Definition: session.hpp:190
std::unordered_map< size_t, double > OpWeightT
Map functor opcode to the operation's weight value.
Definition: partition.hpp:23
std::shared_ptr< iTensor > TensptrT
Tensor smart pointer.
Definition: itensor.hpp:51
std::vector< TensptrT > TensptrsT
Vector of tensor smart pointers.
Definition: itensor.hpp:60
Interface of traversible and differentiable nodes with shape information.
Definition: itensor.hpp:36
void track(teq::TensptrsT roots) override
Implementation of iSession.
Definition: session.hpp:45
void update(teq::TensSetT ignored={}) override
Implementation of iSession.
Definition: session.hpp:107
std::list< std::pair< teq::iOperableFunc *, size_t > > LSessReqsT
Same as SessReqsT except as a list.
Definition: session.hpp:30
OpWeightT weights_
Definition: session.hpp:297
std::vector< SessReqsT > requirements_
Definition: session.hpp:292
void optimize(const opt::OptCtx &rules)
Apply input optimization rules using opt module, then re-track.
Definition: session.hpp:274
Session(size_t nthreads=2, OpWeightT weights=OpWeightT())
Definition: session.hpp:41