Tenncor
voter.hpp
Go to the documentation of this file.
1 
9 #include "opt/ivoter.hpp"
10 
11 #ifndef OPT_VOTER_HPP
12 #define OPT_VOTER_HPP
13 
14 namespace opt
15 {
16 
18 struct OrdrVoter final : public iVoter
19 {
20  OrdrVoter (std::string label) : label_(label) {}
21 
23  void emplace (VoterArgsT args, Symbol sym) override
24  {
25  args_.emplace(args, sym);
26  }
27 
29  CandsT inspect (const CandArgsT& args) const override
30  {
31  CandsT out;
32  out.reserve(args_.size());
33  for (const auto& vpair : args_)
34  {
35  CtxsT ctxs;
36  const VoterArgsT& vargs = vpair.first;
37  if (vargs.size() != args.size())
38  {
39  continue;
40  }
41  if ([&]() -> bool
42  {
43  for (size_t i = 0, n = args.size(); i < n; ++i)
44  {
45  if (false == vargs[i].match(ctxs, args[i]))
46  {
47  return true;
48  }
49  }
50  return false;
51  }())
52  {
53  // failure to match one of the arguments
54  continue;
55  }
56  out[vpair.second].insert(ctxs.begin(), ctxs.end());
57  }
58  return out;
59  }
60 
62  std::string label_;
63 
65  std::unordered_map<VoterArgsT,Symbol,OrdrHasher> args_;
66 };
67 
68 // todo: ensure comm voters are inspected after all available ordr voters are inspected (optimization)
70 struct CommVoter final : public iVoter
71 {
72  CommVoter (std::string label) : label_(label) {}
73 
75  void emplace (VoterArgsT args, Symbol sym) override
76  {
77  // sort args
78  SegVArgs segs;
79  for (VoterArg& arg : args)
80  {
81  switch (arg.type_)
82  {
84  segs.scalars_.push_back(arg);
85  break;
87  segs.branches_.push_back(arg);
88  break;
90  default:
91  segs.anys_.push_back(arg);
92  }
93  }
94  if (segs.branches_.size() > 1)
95  {
96  logs::fatal("implementation limit: "
97  "cannot have more than 1 operator as an argument of the "
98  "commutative operator for the source subgraph");
99  }
100  sort_vargs(segs.scalars_);
101  sort_vargs(segs.anys_);
102  args_.emplace(segs, sym);
103  }
104 
106  CandsT inspect (const CandArgsT& args) const override
107  {
108  CandsT out;
109  out.reserve(args_.size());
110  for (const auto& vpair : args_)
111  {
112  const SegVArgs& vargs = vpair.first;
113  if (vargs.size() != args.size())
114  {
115  continue;
116  }
117  CtxsT ctxs;
118  std::list<CandArg> unmatched(args.begin(), args.end());
119  // attempt matching scalars first
120  bool match_failed = false;
121  for (const VoterArg& sarg : vargs.scalars_)
122  {
123  match_failed = true;
124  for (auto it = unmatched.begin(), et = unmatched.end();
125  it != et; ++it)
126  {
127  if (sarg.match(ctxs, *it))
128  {
129  unmatched.erase(it);
130  match_failed = false;
131  break;
132  }
133  }
134  if (match_failed)
135  {
136  break;
137  }
138  }
139  // none of the unmatched args matched
140  // a scalar voter argument
141  if (match_failed)
142  {
143  continue;
144  }
145 
146  for (const VoterArg& barg : vargs.branches_)
147  {
148  match_failed = true;
149  for (auto it = unmatched.begin(), et = unmatched.end();
150  it != et; ++it)
151  {
152  if (barg.match(ctxs, *it))
153  {
154  unmatched.erase(it);
155  match_failed = false;
156  break;
157  }
158  }
159  if (match_failed)
160  {
161  break;
162  }
163  }
164  // none of the unmatched args matched
165  // a branch voter argument
166  if (match_failed)
167  {
168  continue;
169  }
170 
171  // create permutations of remaining against anys matches
172  std::vector<CandArg> remaining(unmatched.begin(), unmatched.end());
173  size_t nremaining = remaining.size();
174  std::vector<size_t> indices(nremaining);
175  std::iota(indices.begin(), indices.end(), 0);
176  do
177  {
178  bool matched = true;
179  CtxsT local_ctxs = ctxs;
180  for (size_t i = 0; i < nremaining && matched; ++i)
181  {
182  matched = vargs.anys_[i].match(local_ctxs,
183  remaining[indices[i]]);
184  }
185  if (false == matched)
186  {
187  continue;
188  }
189  out[vpair.second].insert(local_ctxs.begin(), local_ctxs.end());
190  }
191  while (std::next_permutation(indices.begin(), indices.end()));
192  }
193  return out;
194  }
195 
197  std::string label_;
198 
200  std::unordered_map<SegVArgs,Symbol,CommHasher> args_;
201 };
202 
204 struct VariadicVoter final : public iVoter
205 {
206  VariadicVoter (std::string label, std::string variadic) :
207  label_(label), variadic_(variadic) {}
208 
210  void emplace (VoterArgsT args, Symbol sym) override
211  {
212  // sort args
213  SegVArgs segs;
214  for (VoterArg& arg : args)
215  {
216  switch (arg.type_)
217  {
219  segs.scalars_.push_back(arg);
220  break;
222  segs.branches_.push_back(arg);
223  break;
225  default:
226  segs.anys_.push_back(arg);
227  }
228  }
229  if (segs.branches_.size() > 1)
230  {
231  logs::fatal("implementation limit: "
232  "cannot have more than 1 operator as an argument of the "
233  "commutative operator for the source subgraph");
234  }
235  sort_vargs(segs.scalars_);
236  sort_vargs(segs.anys_);
237  args_.emplace(segs, sym);
238  }
239 
241  CandsT inspect (const CandArgsT& args) const override
242  {
243  CandsT out;
244  out.reserve(args_.size());
245  for (const auto& vpair : args_)
246  {
247  const SegVArgs& vargs = vpair.first;
248  if (vargs.size() > args.size())
249  {
250  // not enough voter arguments to match candidate arguments
251  continue;
252  }
253  CtxsT ctxs;
254  std::list<CandArg> unmatched(args.begin(), args.end());
255  // attempt matching scalars first
256  bool match_failed = false;
257  for (const VoterArg& sarg : vargs.scalars_)
258  {
259  match_failed = true;
260  for (auto it = unmatched.begin(), et = unmatched.end();
261  it != et; ++it)
262  {
263  if (sarg.match(ctxs, *it))
264  {
265  unmatched.erase(it);
266  match_failed = false;
267  break;
268  }
269  }
270  if (match_failed)
271  {
272  break;
273  }
274  }
275  // none of the unmatched args matched
276  // a scalar voter argument
277  if (match_failed)
278  {
279  continue;
280  }
281 
282  for (const VoterArg& barg : vargs.branches_)
283  {
284  match_failed = true;
285  for (auto it = unmatched.begin(), et = unmatched.end();
286  it != et; ++it)
287  {
288  if (barg.match(ctxs, *it))
289  {
290  unmatched.erase(it);
291  match_failed = false;
292  break;
293  }
294  }
295  if (match_failed)
296  {
297  break;
298  }
299  }
300  // none of the unmatched args matched
301  // a branch voter argument
302  if (match_failed)
303  {
304  continue;
305  }
306 
307  // create permutations of remaining against anys matches
308  std::vector<CandArg> remaining(unmatched.begin(), unmatched.end());
309  size_t nremaining = remaining.size();
310  std::vector<size_t> indices(nremaining);
311  std::iota(indices.begin(), indices.end(), 0);
312 
313  size_t nanys = vargs.anys_.size();
314  do
315  {
316  // select first nanys indices,
317  // and dump remaining as variadic
318  bool matched = true;
319  CtxsT local_ctxs = ctxs;
320  for (size_t i = 0; i < nanys && matched; ++i)
321  {
322  matched = vargs.anys_[i].match(local_ctxs,
323  remaining[indices[i]]);
324  }
325  if (false == matched)
326  {
327  continue;
328  }
329  CtxValT cvals;
330  for (size_t i = nanys; i < nremaining; ++i)
331  {
332  // dump remaining[indices[i]] as variadic
333  cvals.emplace(remaining[indices[i]].tensor_); // todo: also store coorder and shaper
334  }
335  CtxsT& out_ctxs = out[vpair.second];
336  for (ContexT ctx : local_ctxs)
337  {
338  ctx.emplace(variadic_, cvals);
339  out_ctxs.emplace(ctx);
340  }
341  }
342  while (std::next_permutation(indices.begin(), indices.end()));
343  }
344  return out;
345  }
346 
348  std::string label_;
349 
351  std::string variadic_;
352 
354  std::unordered_map<SegVArgs,Symbol,CommHasher> args_;
355 };
356 
357 }
358 
359 #endif // OPT_VOTER_HPP
std::string label_
Label type of the functor.
Definition: voter.hpp:197
CandsT inspect(const CandArgsT &args) const override
Implementation of iVoter.
Definition: voter.hpp:241
args
Definition: csv_to_png.py:105
VariadicVoter(std::string label, std::string variadic)
Definition: voter.hpp:206
std::set< teq::TensptrT > CtxValT
Set of tensors that potentially match some id.
Definition: candidate.hpp:23
VoterArgsT branches_
Branch-typed arguments (functors/groups)
Definition: ivoter.hpp:157
std::unordered_map< SegVArgs, Symbol, CommHasher > args_
Map functor arguments to emplaced symbols.
Definition: voter.hpp:200
std::string label_
Label type of the functor.
Definition: voter.hpp:62
std::string variadic_
Symbol of variadic argument.
Definition: voter.hpp:351
Definition: candidate.hpp:19
std::unordered_map< VoterArgsT, Symbol, OrdrHasher > args_
Map functor arguments to emplaced symbols.
Definition: voter.hpp:65
VoterArgsT anys_
Any-typed leaf arguments.
Definition: ivoter.hpp:160
Branching node.
Definition: def.h:33
void emplace(VoterArgsT args, Symbol sym) override
Implementation of iVoter.
Definition: voter.hpp:210
std::unordered_set< ContexT, boost::hash< ContexT > > CtxsT
Set of contexts that serve as a candidates of a conversion rule.
Definition: candidate.hpp:29
std::vector< CandArg > CandArgsT
Vector of candidate arguments.
Definition: candidate.hpp:95
void emplace(VoterArgsT args, Symbol sym) override
Implementation of iVoter.
Definition: voter.hpp:75
Implement voter for variadic groups.
Definition: voter.hpp:204
void emplace(VoterArgsT args, Symbol sym) override
Implementation of iVoter.
Definition: voter.hpp:23
Variadic/commutative branch voter arguments.
Definition: ivoter.hpp:146
std::map< std::string, CtxValT > ContexT
Map of rule graph leaf identifiers to corresponding matches.
Definition: candidate.hpp:26
Rule tree leaf that represents any real node.
Definition: def.h:31
std::unordered_map< Symbol, CtxsT, SymbolHash > CandsT
Map of conversion symbols to their potential candidate conversion rules.
Definition: candidate.hpp:76
VoterArgsT scalars_
Scalar-typed arguments.
Definition: ivoter.hpp:154
void sort_vargs(VoterArgsT &args)
Normalize voter arguments to facilitate matching.
Rule tree node that identifies and selects matching candidates.
Definition: ivoter.hpp:234
CommVoter(std::string label)
Definition: voter.hpp:72
std::string label_
Label type of group.
Definition: voter.hpp:348
CandsT inspect(const CandArgsT &args) const override
Implementation of iVoter.
Definition: voter.hpp:29
Generic representation of a conversion rule.
Definition: candidate.hpp:45
OrdrVoter(std::string label)
Definition: voter.hpp:20
size_t size(void) const
Definition: ivoter.hpp:148
std::vector< VoterArg > VoterArgsT
Vector of voter arguments for branching nodes.
Definition: ivoter.hpp:143
std::unordered_map< SegVArgs, Symbol, CommHasher > args_
Map functor arguments to emplaced symbols.
Definition: voter.hpp:354
CandsT inspect(const CandArgsT &args) const override
Implementation of iVoter.
Definition: voter.hpp:106
Implement voter for ordered (non-commutative) functors.
Definition: voter.hpp:18
Definitive scalar constant.
Definition: def.h:29
Implement voter for commutative functors.
Definition: voter.hpp:70
Argument voter for functors.
Definition: ivoter.hpp:23
bool match(CtxsT &ctxs, const CandArg &arg) const
Return true if arg matches this voter argument; only add to ctxs if it matches.
Definition: ivoter.hpp:35