30 return shape.
at(0) + shape.
at(1);
37 return fanio<T>(shape) / 2;
44 size_t max_repick = 5)
47 out = std::vector<T>(n);
48 auto gen = eteq::norm_gen<T>(mean, stdev);
49 std::generate(out.begin(), out.end(), gen);
51 T upperbound = mean + 2 * stdev;
52 T lowerbound = mean - 2 * stdev;
53 for (
size_t i = 0; i < n; ++i)
56 for (
size_t retry = 0;
57 (out[i] > upperbound || out[i] < lowerbound) && max_repick;
63 if (out[i] > upperbound)
67 else if (out[i] < lowerbound)
81 return eteq::make_variable_scalar<T>(0, shape, label);
91 [factor, sfactor](
teq::Shape shape, std::string label)
94 T stdev =
std::sqrt(factor / sfactor(shape));
95 truncated_normal<T>(vec, shape, 0, stdev);
102 template <
typename T>
108 std::vector<T> vec(shape.
n_elems());
109 T bound = factor *
std::sqrt(6. / fanio<T>(shape));
110 std::generate(vec.begin(), vec.end(), eteq::unif_gen<T>(-bound, bound));
117 template <
typename T>
123 std::vector<T> vec(shape.
n_elems());
124 T stdev = factor *
std::sqrt(2. / fanio<T>(shape));
125 std::generate(vec.begin(), vec.end(), eteq::norm_gen<T>(0., stdev));
132 #endif // LAYR_INIT_HPP
InitF< T > variance_scaling_init(T factor, ShapeFactorF< T > sfactor=fanavg< T >)
Definition: init.hpp:88
std::function< T(teq::Shape)> ShapeFactorF
Function that returns some metric of a shape.
Definition: init.hpp:24
EigenptrT< T > sqrt(teq::Shape &outshape, const OpArg< T > &in)
Definition: operator.hpp:427
InitF< T > norm_xavier_init(T factor=1)
Definition: init.hpp:118
InitF< T > zero_init(void)
Return initialization function that makes zero variables.
Definition: init.hpp:76
std::function< eteq::VarptrT< T >(teq::Shape, std::string)> InitF
Function that produces a variable given the variable's shape and label.
Definition: init.hpp:20
InitF< T > unif_xavier_init(T factor=1)
Definition: init.hpp:103
T fanavg(teq::Shape shape)
Return the mean of the first 2 dimensions of a shape.
Definition: init.hpp:35
T fanio(teq::Shape shape)
Return the sum of the first 2 dimensions of a shape.
Definition: init.hpp:28
DimT at(RankT idx) const
Return DimT element at idx for any index in range [0:rank_cap)
Definition: shape.hpp:108
VarptrT< T > make_variable(teq::Shape shape, std::string label="")
Return zero-initialized variable node of specified shape.
Definition: variable.hpp:230
void truncated_normal(std::vector< T > &out, teq::Shape shape, T mean, T stdev, size_t max_repick=5)
Definition: init.hpp:43
NElemT n_elems(void) const
Return the total number of elements represented by the shape.
Definition: shape.hpp:118