Tenncor
teq::iGradientBuilder Struct Reference (abstract)

#include <grad_def.hpp>

[Inheritance diagram for teq::iGradientBuilder]

Public Member Functions

virtual ~iGradientBuilder (void)=default
 
virtual TensptrT local_derivative (FuncptrT op, size_t arg_idx) const =0
 
virtual TensptrT chain_rule (FuncptrT op, const TensptrT &local_der, TensptrT supcomp_grad, size_t arg_idx) const =0
 
virtual TensptrT get_const_one (Shape shape) const =0
 Return a tensor representing the constant 1.
 
virtual TensptrT get_const_zero (Shape shape) const =0
 Return a tensor representing the constant 0.
 
virtual TensptrT add (TensptrT &lhs, TensptrT &rhs) const =0
 Return a functor representing lhs + rhs.
 
TensptrT derive (TensptrT root, TensptrT target) const
 Return the derivative of root with respect to target.
 

Detailed Description

Defines the mandatory operations required for tensor differentiation. For some graph F(G(x)), the chain rule for calculating dF/dx is applied in the following order:

  1. Calculate dF/dG: the local derivative of F, which is also the derivative of the super-composition passed down to G (supcomp_grad for G).
  2. Calculate dG/dx: the local derivative of G.
  3. Chain dF/dG (supcomp_grad) and dG/dx (local_der) to obtain dF/dx.

This top-down approach updates tensor shape information so that the output derivative dF/dx has the shape of x.

Constructor & Destructor Documentation

◆ ~iGradientBuilder()

virtual teq::iGradientBuilder::~iGradientBuilder (void)
virtual, default

Member Function Documentation

◆ add()

virtual TensptrT teq::iGradientBuilder::add (TensptrT &lhs, TensptrT &rhs) const
pure virtual

Return a functor representing lhs + rhs.

Implemented in eteq::GradientBuilder< T >.

◆ chain_rule()

virtual TensptrT teq::iGradientBuilder::chain_rule (FuncptrT op, const TensptrT &local_der, TensptrT supcomp_grad, size_t arg_idx) const
pure virtual

Let op be a functor F with arguments args, and let local_der be the derivative of F with respect to one of those arguments (say x). Let supcomp_grad be dG/dF, where G is some super-functor that uses F. Return the derivative of G with respect to x by applying the chain rule.

Implemented in eteq::GradientBuilder< T >.

◆ derive()

TensptrT teq::iGradientBuilder::derive (TensptrT root, TensptrT target) const
inline

Return the derivative of root with respect to target.

◆ get_const_one()

virtual TensptrT teq::iGradientBuilder::get_const_one (Shape shape) const
pure virtual

Return a tensor of the given shape representing the constant 1.

Implemented in eteq::GradientBuilder< T >.

◆ get_const_zero()

virtual TensptrT teq::iGradientBuilder::get_const_zero (Shape shape) const
pure virtual

Return a tensor of the given shape representing the constant 0.

Implemented in eteq::GradientBuilder< T >.

◆ local_derivative()

virtual TensptrT teq::iGradientBuilder::local_derivative (FuncptrT op, size_t arg_idx) const
pure virtual

Let op be a functor F with arguments args. Return the derivative of F with respect to args[arg_idx].

Implemented in eteq::GradientBuilder< T >.


The documentation for this struct was generated from the following file: