Tenncor
init.hpp
Go to the documentation of this file.
1 
9 #include "eteq/variable.hpp"
10 #include "eteq/random.hpp"
11 
12 #ifndef LAYR_INIT_HPP
13 #define LAYR_INIT_HPP
14 
15 namespace layr
16 {
17 
19 template <typename T>
20 using InitF = std::function<eteq::VarptrT<T>(teq::Shape,std::string)>;
21 
23 template <typename T>
24 using ShapeFactorF = std::function<T(teq::Shape)>;
25 
27 template <typename T>
28 T fanio (teq::Shape shape)
29 {
30  return shape.at(0) + shape.at(1);
31 }
32 
34 template <typename T>
35 T fanavg (teq::Shape shape)
36 {
37  return fanio<T>(shape) / 2;
38 }
39 
42 template <typename T>
43 void truncated_normal (std::vector<T>& out, teq::Shape shape, T mean, T stdev,
44  size_t max_repick = 5)
45 {
46  size_t n = shape.n_elems();
47  out = std::vector<T>(n);
48  auto gen = eteq::norm_gen<T>(mean, stdev);
49  std::generate(out.begin(), out.end(), gen);
50  // if T is not decimal, program would fail to compile therefore T is signed
51  T upperbound = mean + 2 * stdev;
52  T lowerbound = mean - 2 * stdev;
53  for (size_t i = 0; i < n; ++i)
54  {
55  // keep repicking until we give-up (statistical unlikely)
56  for (size_t retry = 0;
57  (out[i] > upperbound || out[i] < lowerbound) && max_repick;
58  ++retry)
59  {
60  out[i] = gen();
61  }
62  // clip
63  if (out[i] > upperbound)
64  {
65  out[i] = upperbound;
66  }
67  else if (out[i] < lowerbound)
68  {
69  out[i] = lowerbound;
70  }
71  }
72 }
73 
75 template <typename T>
77 {
78  return
79  [](teq::Shape shape, std::string label)
80  {
81  return eteq::make_variable_scalar<T>(0, shape, label);
82  };
83 }
84 
87 template <typename T>
88 InitF<T> variance_scaling_init (T factor, ShapeFactorF<T> sfactor=fanavg<T>)
89 {
90  return
91  [factor, sfactor](teq::Shape shape, std::string label)
92  {
93  std::vector<T> vec;
94  T stdev = std::sqrt(factor / sfactor(shape));
95  truncated_normal<T>(vec, shape, 0, stdev);
96  return eteq::make_variable(vec.data(), shape, label);
97  };
98 }
99 
102 template <typename T>
104 {
105  return
106  [factor](teq::Shape shape, std::string label)
107  {
108  std::vector<T> vec(shape.n_elems());
109  T bound = factor * std::sqrt(6. / fanio<T>(shape));
110  std::generate(vec.begin(), vec.end(), eteq::unif_gen<T>(-bound, bound));
111  return eteq::make_variable(vec.data(), shape, label);
112  };
113 }
114 
117 template <typename T>
119 {
120  return
121  [factor](teq::Shape shape, std::string label)
122  {
123  std::vector<T> vec(shape.n_elems());
124  T stdev = factor * std::sqrt(2. / fanio<T>(shape));
125  std::generate(vec.begin(), vec.end(), eteq::norm_gen<T>(0., stdev));
126  return eteq::make_variable(vec.data(), shape, label);
127  };
128 }
129 
130 }
131 
132 #endif // LAYR_INIT_HPP
InitF< T > variance_scaling_init(T factor, ShapeFactorF< T > sfactor=fanavg< T >)
Definition: init.hpp:88
std::function< T(teq::Shape)> ShapeFactorF
Function that returns some metric of a shape.
Definition: init.hpp:24
EigenptrT< T > sqrt(teq::Shape &outshape, const OpArg< T > &in)
Definition: operator.hpp:427
Definition: shape.hpp:62
InitF< T > norm_xavier_init(T factor=1)
Definition: init.hpp:118
Definition: conv.hpp:16
InitF< T > zero_init(void)
Return initialization function that makes zero variables.
Definition: init.hpp:76
std::function< eteq::VarptrT< T >(teq::Shape, std::string)> InitF
Function that produces a variable given the variable's shape and label.
Definition: init.hpp:20
InitF< T > unif_xavier_init(T factor=1)
Definition: init.hpp:103
T fanavg(teq::Shape shape)
Return the mean of the first 2 dimensions of a shape.
Definition: init.hpp:35
T fanio(teq::Shape shape)
Return the sum of the first 2 dimensions of a shape.
Definition: init.hpp:28
DimT at(RankT idx) const
Return DimT element at idx for any index in range [0:rank_cap)
Definition: shape.hpp:108
VarptrT< T > make_variable(teq::Shape shape, std::string label="")
Return zero-initialized variable node of specified shape.
Definition: variable.hpp:230
void truncated_normal(std::vector< T > &out, teq::Shape shape, T mean, T stdev, size_t max_repick=5)
Definition: init.hpp:43
NElemT n_elems(void) const
Return the total number of elements represented by the shape.
Definition: shape.hpp:118