/// RNN: a marshal set (derives from iMarshalSet and is declared final) that
/// owns one weight variable per timestep plus a single scalar bias, which is
/// shared by every fully_connect step of the forward pass.
1 #include "eteq/generated/api.hpp" 9 struct RNN final :
public iMarshalSet
/// RNN constructor (fragment: the leading signature lines are not shown).
/// Stores the nonlinearity and label and creates the scalar bias variable.
/// It then builds one weight variable per timestep:
///   - "weight_0" with shape {n_output, n_input};
///   - "weight_1" .. "weight_<timestep-1>" with weight_shape (computed in
///     lines not shown).
/// Weight values come from std::uniform_real_distribution over
/// [-bound, bound). How `bound` and the generator `gen` are derived is not
/// visible here; presumably `gen` draws from the global engine (get_engine).
/// TODO confirm.
12 NonLinearF nonlin, std::string label) :
13 iMarshalSet(label), nonlin_(nonlin),
14 bias_(eteq::make_variable_scalar<PybindT>(
20 std::uniform_real_distribution<PybindT> dist(-bound, bound);
25 std::vector<PybindT> wdata(n_output * n_input);
26 std::generate(wdata.begin(), wdata.end(), gen);
29 wdata.data(),
teq::Shape({n_output, n_input}),
"weight_0");
// NOTE(review): weights are appended to `layers_`, but get_ninput,
// get_noutput, the forward pass, get_subs and copy_helper all read `weights_`.
// Confirm which member is intended. If `layers_` is not `weights_`, the
// weights built here are invisible to the rest of the class.
30 layers_.push_back(std::make_shared<MarshalVar>(weight));
32 for (
size_t i = 1; i < timestep; ++i)
38 std::uniform_real_distribution<PybindT> dist(-bound, bound);
43 std::vector<PybindT> wdata(nweight);
44 std::generate(wdata.begin(), wdata.end(), gen);
// NOTE(review): `i` is size_t but is formatted with "%d". If fmts::sprintf
// follows printf rules, this is a format mismatch; prefer "%zu" or a cast.
47 wdata.data(), weight_shape, fmts::sprintf(
"weight_%d", i));
49 layers_.push_back(std::make_shared<MarshalVar>(weight));
/// Copy/move special members.
/// Copy construction and copy assignment forward to iMarshalSet's copy
/// operations. Their bodies (lines not shown) presumably copy the weights,
/// bias and nonlinearity via copy_helper. TODO confirm.
/// Move construction and move assignment are defaulted.
53 RNN (
const RNN& other) : iMarshalSet(other)
58 RNN& operator = (
const RNN& other)
62 iMarshalSet::operator = (other);
68 RNN (
RNN&& other) =
default;
70 RNN& operator = (
RNN&& other) =
default;
/// Forward pass (fragment; the enclosing signature is not shown here).
/// First it validates the inputs:
///   - in_shape.at(0) must equal get_ninput();
///   - there must be exactly one weight per input.
/// Then it chains fully_connect plus the nonlinearity:
///   - the first output uses inputs[0] with weights_[0];
///   - each later output combines the previous output with inputs[i], using
///     weights_[i - 1] and weights_[i].
/// Every step shares the same bias_.
// NOTE(review): `ninput` is uint8_t, but get_ninput() reads shape().at(1),
// and Shape::at returns DimT (uint16_t). Input widths above 255 would be
// silently truncated here; consider storing it as DimT.
78 uint8_t ninput = get_ninput();
79 if (in_shape.
at(0) != ninput)
// NOTE(review): the message says "dbn", which looks copy-pasted from the DBN
// layer. The runtime string is intentionally left unchanged in this edit.
81 logs::fatalf(
"cannot dbn with input shape %s against n_input %d",
85 size_t nins = inputs.size();
86 if (weights_.size() != nins)
// NOTE(review): nins and weights_.size() are size_t but formatted with %d.
// If fatalf follows printf rules, this is a format mismatch; use %zu or cast.
88 logs::fatalf(
"cannot connect %d inputs with %d weights",
89 nins, weights_.size());
94 outs.push_back(nonlin_(tenncor::nn::fully_connect(
95 {inputs[0]}, {weights_[0]}, bias_)));
// NOTE(review): the loop bound is `ninput` (the input feature width from
// in_shape.at(0)), not `nins` (the number of inputs, already checked against
// weights_.size()). inputs[i] and weights_[i] go out of range whenever
// ninput > nins. `nins` was most likely intended.
// NOTE(review): the loop body calls `nonlin`, while the first push_back above
// uses the member `nonlin_`. Unless a local named `nonlin` exists in lines not
// shown, this is a compile error or an inconsistency.
// NOTE(review): the previous output is multiplied by weights_[i - 1], the
// weight used for the previous input. Confirm that sharing it with the
// recurrent path is intended.
96 for (uint8_t i = 1; i < ninput; ++i)
98 outs.push_back(nonlin(tenncor::nn::fully_connect(
99 {outs.back(), inputs[i]},
100 {weights_[i - 1], weights_[i]}, bias_)));
/// Input width: dimension 1 of the first weight's shape. The constructor
/// builds weight_0 with shape {n_output, n_input}, so this is n_input.
// NOTE(review): front() and back() below require weights_ to be non-empty.
108 return weights_.front()->var_->shape().at(1);
/// Output width: dimension 0 of the last weight's shape.
113 return weights_.back()->var_->shape().at(0);
/// Builds the list of marshaled sub-variables: a copy of weights_ followed by
/// bias_. The return statement is not shown in this listing.
116 MarsarrT get_subs (
void)
const override 118 MarsarrT out = weights_;
119 out.push_back(bias_);
/// Copies another RNN's state into this one:
///   - each of other's weights is re-wrapped in a new MarshalVar constructed
///     from *weight;
///   - the bias is re-wrapped the same way;
///   - nonlin_ is copied.
/// Whether this is a deep copy depends on MarshalVar's copy constructor
/// (not shown).
// NOTE(review): the lines that clear weights_ and push the copies are not
// shown. If weights_ is not cleared first, copy assignment would append to
// the existing weights instead of replacing them.
130 void copy_helper (
const RNN& other)
133 for (
const auto& weight : other.weights_)
136 std::make_shared<MarshalVar>(*weight));
138 bias_ = std::make_shared<MarshalVar>(*other.bias_);
139 nonlin_ = other.nonlin_;
/// Polymorphic clone hook: returns a heap-allocated copy of this RNN made with
/// the copy constructor. The result is an owning raw pointer, so the caller
/// (presumably a public clone wrapper on iMarshaler) must take ownership.
/// TODO confirm.
142 iMarshaler* clone_impl (
void)
const override 144 return new RNN(*
this);
150 #endif // MODL_RNN_HPP std::shared_ptr< VariableNode< T > > VarptrT
Smart pointer of variable nodes to preserve assign functions.
Definition: variable.hpp:210
std::string to_string(void) const
Return string representation of shape.
Definition: shape.hpp:148
uint64_t NElemT
Definition: shape.hpp:44
EigenptrT< T > sqrt(teq::Shape &outshape, const OpArg< T > &in)
Definition: operator.hpp:427
std::vector< NodeptrT< T > > NodesT
Vector of nodes.
Definition: inode.hpp:67
uint16_t DimT
Type used for shape dimension.
Definition: shape.hpp:31
DimT at(RankT idx) const
Return DimT element at idx for any index in range [0:rank_cap)
Definition: shape.hpp:108
EngineT & get_engine(void)
Return global random generator.
NElemT n_elems(void) const
Return the total number of elements represented by the shape.
Definition: shape.hpp:118