Tenncor
rnn.hpp
Go to the documentation of this file.
1 #include "eteq/generated/api.hpp"
2 
3 #ifndef MODL_RNN_HPP
4 #define MODL_RNN_HPP
5 
6 namespace modl
7 {
8 
9 struct RNN final : public iMarshalSet
10 {
11  RNN (teq::DimT n_input, teq::DimT n_output, size_t timestep,
12  NonLinearF nonlin, std::string label) :
13  iMarshalSet(label), nonlin_(nonlin),
14  bias_(eteq::make_variable_scalar<PybindT>(
15  0., teq::Shape({n_output}), "bias")
16  {
17  assert(timestep > 0);
18  {
19  PybindT bound = 1. / std::sqrt(n_input);
20  std::uniform_real_distribution<PybindT> dist(-bound, bound);
21  auto gen = [&dist]()
22  {
23  return dist(eteq::get_engine());
24  };
25  std::vector<PybindT> wdata(n_output * n_input);
26  std::generate(wdata.begin(), wdata.end(), gen);
27 
28  eteq::VarptrT<PybindT> weight = eteq::make_variable<PybindT>(
29  wdata.data(), teq::Shape({n_output, n_input}), "weight_0");
30  layers_.push_back(std::make_shared<MarshalVar>(weight));
31  }
32  for (size_t i = 1; i < timestep; ++i)
33  {
34  teq::Shape weight_shape({n_output, n_output});
35  teq::NElemT nweight = weight_shape.n_elems();
36 
37  PybindT bound = 1. / std::sqrt(n_output);
38  std::uniform_real_distribution<PybindT> dist(-bound, bound);
39  auto gen = [&dist]()
40  {
41  return dist(eteq::get_engine());
42  };
43  std::vector<PybindT> wdata(nweight);
44  std::generate(wdata.begin(), wdata.end(), gen);
45 
46  eteq::VarptrT<PybindT> weight = eteq::make_variable<PybindT>(
47  wdata.data(), weight_shape, fmts::sprintf("weight_%d", i));
48 
49  layers_.push_back(std::make_shared<MarshalVar>(weight));
50  }
51  }
52 
53  RNN (const RNN& other) : iMarshalSet(other)
54  {
55  copy_helper(other);
56  }
57 
58  RNN& operator = (const RNN& other)
59  {
60  if (this != &other)
61  {
62  iMarshalSet::operator = (other);
63  copy_helper(other);
64  }
65  return *this;
66  }
67 
68  RNN (RNN&& other) = default;
69 
70  RNN& operator = (RNN&& other) = default;
71 
72 
73  // expect all inputs of shape <n_input, n_batch>
74  eteq::NodesT<PybindT> operator () (eteq::NodesT<PybindT> inputs)
75  {
76  // sanity check
77  const teq::Shape& in_shape = input->shape();
78  uint8_t ninput = get_ninput();
79  if (in_shape.at(0) != ninput)
80  {
81  logs::fatalf("cannot dbn with input shape %s against n_input %d",
82  in_shape.to_string().c_str(), ninput);
83  }
84 
85  size_t nins = inputs.size();
86  if (weights_.size() != nins)
87  {
88  logs::fatalf("cannot connect %d inputs with %d weights",
89  nins, weights_.size());
90  }
91 
93  outs.reserve(nins);
94  outs.push_back(nonlin_(tenncor::nn::fully_connect(
95  {inputs[0]}, {weights_[0]}, bias_)));
96  for (uint8_t i = 1; i < ninput; ++i)
97  {
98  outs.push_back(nonlin(tenncor::nn::fully_connect(
99  {outs.back(), inputs[i]},
100  {weights_[i - 1], weights_[i]}, bias_)));
101  }
102 
103  return outs;
104  }
105 
106  teq::DimT get_ninput (void) const
107  {
108  return weights_.front()->var_->shape().at(1);
109  }
110 
111  teq::DimT get_noutput (void) const
112  {
113  return weights_.back()->var_->shape().at(0);
114  }
115 
116  MarsarrT get_subs (void) const override
117  {
118  MarsarrT out = weights_;
119  out.push_back(bias_);
120  return out;
121  }
122 
123  MarsarrT weights_;
124 
125  MarVarsptrT bias_;
126 
127  NonLinearF nonlin_;
128 
129 private:
130  void copy_helper (const RNN& other)
131  {
132  weights_.clear();
133  for (const auto& weight : other.weights_)
134  {
135  weights_.push_back(
136  std::make_shared<MarshalVar>(*weight));
137  }
138  bias_ = std::make_shared<MarshalVar>(*other.bias_);
139  nonlin_ = other.nonlin_;
140  }
141 
142  iMarshaler* clone_impl (void) const override
143  {
144  return new RNN(*this);
145  }
146 };
147 
148 }
149 
150 #endif // MODL_RNN_HPP
std::shared_ptr< VariableNode< T > > VarptrT
Smart pointer of variable nodes to preserve assign functions.
Definition: variable.hpp:210
std::string to_string(void) const
Return string representation of shape.
Definition: shape.hpp:148
uint64_t NElemT
Definition: shape.hpp:44
Definition: rnn.hpp:9
EigenptrT< T > sqrt(teq::Shape &outshape, const OpArg< T > &in)
Definition: operator.hpp:427
Definition: shape.hpp:62
std::vector< NodeptrT< T > > NodesT
Vector of nodes.
Definition: inode.hpp:67
Definition: rnn.hpp:6
uint16_t DimT
Type used for shape dimension.
Definition: shape.hpp:31
DimT at(RankT idx) const
Return DimT element at idx for any index in range [0:rank_cap)
Definition: shape.hpp:108
EngineT & get_engine(void)
Return global random generator.
NElemT n_elems(void) const
Return the total number of elements represented by the shape.
Definition: shape.hpp:118