13 #ifndef TEQ_GRAD_DEF_HPP 14 #define TEQ_GRAD_DEF_HPP 42 TensptrT supcomp_grad,
size_t arg_idx)
const = 0;
76 std::list<iFunctor*> parents;
77 std::transform(pathmap.begin(), pathmap.end(),
78 std::back_inserter(parents),
79 [](std::pair<iTensor*,std::vector<size_t>> parent)
81 return static_cast<iFunctor*
>(parent.first);
88 if (aheight == bheight)
92 return aheight > bheight;
97 std::unordered_map<const iTensor*,TensptrsT> grads = {
104 for (
size_t i = 1, n = gargs.size(); i < n; ++i)
106 bwd =
add(bwd, gargs[i]);
109 auto& grad_indices = pathmap[parent];
110 ArgsT children = parent->get_children();
111 size_t nchildren = children.size();
115 std::transform(children.begin(), children.end(),
args.begin(),
118 return arg.get_tensor();
122 std::list<size_t> ordered(grad_indices.begin(), grad_indices.end());
124 for (
size_t i : ordered)
126 auto parent_ptr = std::static_pointer_cast<
iFunctor>(
127 owners[parent].lock());
129 auto grad_step =
chain_rule(parent_ptr, local, bwd, i);
130 grads[
args[i].get()].push_back(grad_step);
133 TensptrsT& outargs = grads[target.get()];
135 for (
size_t i = 1, n = outargs.size(); i < n; ++i)
137 out =
add(out, outargs[i]);
145 #endif // TEQ_GRAD_DEF_HPP args
Definition: csv_to_png.py:105
std::unordered_map< iTensor *, estd::NumRange< size_t > > graphsize_
Definition: traveler.hpp:105
Interface of iOperation-defined operation node.
Definition: ifunctor.hpp:28
std::vector< FuncArg > ArgsT
Type of functor arguments.
Definition: funcarg.hpp:101
virtual TensptrT add(TensptrT &lhs, TensptrT &rhs) const =0
Return functor representing lhs + rhs.
Coordinate mapper and tensor pair.
Definition: funcarg.hpp:21
Definition: traveler.hpp:116
Traveler that maps each tensor to its subtree's maximum depth.
Definition: traveler.hpp:57
Definition: grad_def.hpp:28
std::shared_ptr< iTensor > TensptrT
Tensor smart pointer.
Definition: itensor.hpp:51
std::vector< TensptrT > TensptrsT
Vector of tensor smart pointers.
Definition: itensor.hpp:60
virtual TensptrT get_const_one(Shape shape) const =0
Return tensor representing 1 constant.
virtual ~iGradientBuilder(void)=default
virtual std::string to_string(void) const =0
Return the string representation of the tensor.
ParentMapT parents_
Map of parent to child indices that lead to target tensor.
Definition: traveler.hpp:158
virtual TensptrT get_const_zero(Shape shape) const =0
Return tensor representing 0 constant.
std::shared_ptr< iFunctor > FuncptrT
Functor smart pointer.
Definition: ifunctor.hpp:49
TensptrT derive(TensptrT root, TensptrT target) const
Return derivative of root with respect to target.
Definition: grad_def.hpp:54
virtual TensptrT local_derivative(FuncptrT op, size_t arg_idx) const =0
OwnerMapT track_owners(TensptrsT roots)
virtual TensptrT chain_rule(FuncptrT op, const TensptrT &local_der, TensptrT supcomp_grad, size_t arg_idx) const =0