Tenncor
grad_def.hpp
Go to the documentation of this file.
1 
9 #include <list>
10 
11 #include "teq/traveler.hpp"
12 
13 #ifndef TEQ_GRAD_DEF_HPP
14 #define TEQ_GRAD_DEF_HPP
15 
16 namespace teq
17 {
18 
29 {
/// Defaulted virtual destructor, so implementations can be destroyed through an iGradientBuilder pointer
30  virtual ~iGradientBuilder (void) = default;
31 
/// Return the derivative of op with respect to its arg_idx-th argument.
/// derive() calls this once for every child index of op that lies on a path
/// to the target (the "dThis/dChild" step of the traversal)
34  virtual TensptrT local_derivative (FuncptrT op, size_t arg_idx) const = 0;
35 
/// Combine a local derivative with the gradient flowing in from above.
/// derive() passes the result of local_derivative(op, arg_idx) as local_der and
/// the summed gradient of the root with respect to op as supcomp_grad, then
/// records the returned tensor as one gradient term for op's arg_idx-th child
41  virtual TensptrT chain_rule (FuncptrT op, const TensptrT& local_der,
42  TensptrT supcomp_grad, size_t arg_idx) const = 0;
43 
/// Return tensor representing 1 constant (presumably of the given shape;
/// derive() uses it to seed the gradient of the root with respect to itself)
45  virtual TensptrT get_const_one (Shape shape) const = 0;
46 
/// Return tensor representing 0 constant (presumably of the given shape;
/// derive() returns it when no path from root to target exists)
48  virtual TensptrT get_const_zero (Shape shape) const = 0;
49 
/// Return functor representing lhs + rhs.
/// derive() uses it to sum the gradient terms that reach the same tensor
51  virtual TensptrT add (TensptrT& lhs, TensptrT& rhs) const = 0;
52 
54  TensptrT derive (TensptrT root, TensptrT target) const
55  {
56  if (root == target)
57  {
58  return get_const_one(target->shape());
59  }
60 
61  PathFinder finder(target.get());
62  root->accept(finder);
63 
64  auto& pathmap = finder.parents_;
65  // no path to wrt
66  if (pathmap.empty())
67  {
68  return get_const_zero(target->shape());
69  }
70  // else there exists a path to wrt
71  // using pathfinder, breadth first traverse from this to wrt
72  GraphStat stat;
73  root->accept(stat);
74  auto owners = track_owners({root});
75 
76  std::list<iFunctor*> parents;
77  std::transform(pathmap.begin(), pathmap.end(),
78  std::back_inserter(parents),
79  [](std::pair<iTensor*,std::vector<size_t>> parent)
80  {
81  return static_cast<iFunctor*>(parent.first);
82  });
83  parents.sort(
84  [&](iFunctor* a, iFunctor* b)
85  {
86  size_t aheight = stat.graphsize_[a].upper_;
87  size_t bheight = stat.graphsize_[b].upper_;
88  if (aheight == bheight) // make traversal more deterministic
89  {
90  return a->to_string() > b->to_string();
91  }
92  return aheight > bheight;
93  });
94 
95  // map functor to its respective super composite derivative
96  // let L = root, F = key functor, value of F in grads is dL/dF
97  std::unordered_map<const iTensor*,TensptrsT> grads = {
98  {root.get(), {get_const_one(root->shape())}}
99  };
100  for (iFunctor* parent : parents)
101  {
102  TensptrsT& gargs = grads[parent];
103  TensptrT bwd = gargs[0];
104  for (size_t i = 1, n = gargs.size(); i < n; ++i)
105  {
106  bwd = add(bwd, gargs[i]);
107  }
108 
109  auto& grad_indices = pathmap[parent];
110  ArgsT children = parent->get_children();
111  size_t nchildren = children.size();
112  // assert: all nnary-children use identity mapping,
113  // so no children-arg is direct mapping
114  TensptrsT args(nchildren);
115  std::transform(children.begin(), children.end(), args.begin(),
116  [](FuncArg& arg)
117  {
118  return arg.get_tensor();
119  });
120  // for each painted child, calculate dThis/dChild
121  // go through grads in order
122  std::list<size_t> ordered(grad_indices.begin(), grad_indices.end());
123  ordered.sort();
124  for (size_t i : ordered)
125  {
126  auto parent_ptr = std::static_pointer_cast<iFunctor>(
127  owners[parent].lock());
128  auto local = local_derivative(parent_ptr, i);
129  auto grad_step = chain_rule(parent_ptr, local, bwd, i);
130  grads[args[i].get()].push_back(grad_step);
131  }
132  }
133  TensptrsT& outargs = grads[target.get()];
134  TensptrT out = outargs[0];
135  for (size_t i = 1, n = outargs.size(); i < n; ++i)
136  {
137  out = add(out, outargs[i]);
138  }
139  return out;
140  }
141 };
142 
143 }
144 
145 #endif // TEQ_GRAD_DEF_HPP
args
Definition: csv_to_png.py:105
std::unordered_map< iTensor *, estd::NumRange< size_t > > graphsize_
Definition: traveler.hpp:105
Interface of iOperation-defined operation node.
Definition: ifunctor.hpp:28
std::vector< FuncArg > ArgsT
Type of functor arguments.
Definition: funcarg.hpp:101
virtual TensptrT add(TensptrT &lhs, TensptrT &rhs) const =0
Return functor representing lhs + rhs.
Definition: shape.hpp:62
Coordinate mapper and tensor pair.
Definition: funcarg.hpp:21
Definition: traveler.hpp:116
Traveler that maps each tensor to its subtree's maximum depth.
Definition: traveler.hpp:57
Definition: grad_def.hpp:28
std::shared_ptr< iTensor > TensptrT
Tensor smart pointer.
Definition: itensor.hpp:51
std::vector< TensptrT > TensptrsT
Vector of tensor smart pointers.
Definition: itensor.hpp:60
virtual TensptrT get_const_one(Shape shape) const =0
Return tensor representing 1 constant.
virtual ~iGradientBuilder(void)=default
virtual std::string to_string(void) const =0
Return the string representation of the tensor.
Definition: coord.hpp:16
ParentMapT parents_
Map of parent to child indices that lead to target tensor.
Definition: traveler.hpp:158
virtual TensptrT get_const_zero(Shape shape) const =0
Return tensor representing 0 constant.
std::shared_ptr< iFunctor > FuncptrT
Functor smart pointer.
Definition: ifunctor.hpp:49
TensptrT derive(TensptrT root, TensptrT target) const
Return derivative of root with respect to target.
Definition: grad_def.hpp:54
virtual TensptrT local_derivative(FuncptrT op, size_t arg_idx) const =0
OwnerMapT track_owners(TensptrsT roots)
virtual TensptrT chain_rule(FuncptrT op, const TensptrT &local_der, TensptrT supcomp_grad, size_t arg_idx) const =0