Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from ._op import OpRun
class If(OpRun):
    """
    Conditional operator (ONNX ``If``): evaluates *cond* and runs either the
    ``then_branch`` or the ``else_branch`` subgraph, returning that
    subgraph's outputs in the order given by its ``output_names``.
    """

    # Attributes expected on the node; presumably populated on ``self`` by
    # ``OpRun.__init__`` (they are read as ``self.then_branch`` /
    # ``self.else_branch`` right after it) — TODO confirm in OpRun.
    atts = {
        'then_branch': None,
        'else_branch': None,
    }

    def __init__(self, onnx_node, desc=None, **options):
        """
        :param onnx_node: node describing the operator
        :param desc: optional description forwarded to ``OpRun``
        :param options: extra options forwarded to ``OpRun``
        :raises RuntimeError: if either branch has no ``run`` method
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=If.atts,
                       **options)
        if not hasattr(self.then_branch, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'then_branch' must have a method 'run', "
                "type {}.".format(type(self.then_branch)))
        if not hasattr(self.else_branch, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'else_branch' must have a method 'run', "
                "type {}.".format(type(self.else_branch)))

        # Prefer ``run_in_scan`` when the branch runtime provides it; either
        # method is later called with the single ``named_inputs`` mapping.
        self._run_meth_then = (self.then_branch.run_in_scan
                               if hasattr(self.then_branch, 'run_in_scan')
                               else self.then_branch.run)
        self._run_meth_else = (self.else_branch.run_in_scan
                               if hasattr(self.else_branch, 'run_in_scan')
                               else self.else_branch.run)

    @staticmethod
    def _check_branch_inputs(branch, named_inputs):
        """
        Checks that *named_inputs* provides every name listed in
        ``branch.input_names``.

        :param branch: subgraph runtime exposing ``input_names``
        :param named_inputs: mapping name -> value given to the operator
        :raises RuntimeError: if the branch needs inputs and *named_inputs*
            is empty, or if one required name is missing
        """
        if len(branch.input_names) == 0:
            return
        if not named_inputs:
            raise RuntimeError(  # pragma: no cover
                "named_inputs is empty but the graph needs {}.".format(
                    branch.input_names))
        for k in branch.input_names:
            if k not in named_inputs:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find named input '{}' in\n{}.".format(
                        k, "\n".join(sorted(named_inputs))))

    @staticmethod
    def _cond_is_true(cond):
        """
        Tells whether *cond* selects the ``then`` branch, i.e. whether all
        of its elements are true.

        :param cond: condition, usually a boolean numpy array or scalar
        :return: bool
        """
        if hasattr(cond, 'all'):
            # numpy arrays of any rank (including 0-d) and numpy scalars:
            # the builtin all() cannot iterate over a 0-d array.
            return bool(cond.all())
        if hasattr(cond, '__iter__'):
            return all(cond)
        return bool(cond)

    def _run(self, cond, named_inputs=None):  # pylint: disable=W0221
        """
        Runs the branch selected by *cond*.

        :param cond: condition; the ``then`` branch runs when every element
            is true, the ``else`` branch otherwise
        :param named_inputs: mapping name -> value passed to the branch
            (defaults to an empty dict)
        :return: tuple of the selected branch's outputs, ordered by its
            ``output_names``
        :raises RuntimeError: if an input required by either branch is
            missing (both branches are checked, whichever one runs)
        """
        if named_inputs is None:
            named_inputs = {}
        self._check_branch_inputs(self.then_branch, named_inputs)
        self._check_branch_inputs(self.else_branch, named_inputs)

        if self._cond_is_true(cond):
            branch, run_meth = self.then_branch, self._run_meth_then
        else:
            branch, run_meth = self.else_branch, self._run_meth_else
        outputs = run_meth(named_inputs)
        return tuple(outputs[name] for name in branch.output_names)

    def _infer_shapes(self, cond, named_inputs=None):  # pylint: disable=W0221
        """
        Returns the shapes inferred for the ``then_branch`` outputs, ordered
        by its ``output_names``. Only ``then_branch`` is consulted
        (NOTE(review): assumes both branches produce the same shapes —
        confirm).
        """
        res = self.then_branch._set_shape_inference_runtime()
        return tuple(res[name] for name in self.then_branch.output_names)

    def _infer_types(self, cond, named_inputs=None):  # pylint: disable=W0221
        """
        Returns the types inferred for the ``then_branch`` outputs, ordered
        by its ``output_names``. Only ``then_branch`` is consulted
        (NOTE(review): assumes both branches produce the same types —
        confirm).
        """
        res = self.then_branch._set_type_inference_runtime()
        return tuple(res[name] for name in self.then_branch.output_names)