Source code for mlprodict.onnxrt.ops_cpu.op_loop
# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator.
:githublink:`%|py|7`
"""
import numpy
from ._op import OpRun
[docs]class Loop(OpRun):
atts = {
'body': None,
}
[docs] def __init__(self, onnx_node, desc=None, **options):
OpRun.__init__(self, onnx_node, desc=desc,
expected_attributes=Loop.atts,
**options)
if not hasattr(self.body, 'run'):
raise RuntimeError("Parameter 'body' must have a method 'run', "
"type {}.".format(type(self.body)))
self._run_meth = (self.body.run_in_scan
if hasattr(self.body, 'run_in_scan')
else self.body.run)
[docs] def _run(self, M, cond, v_initial, *args): # pylint: disable=W0221
inputs = {name: None for name in self.body.input_names}
inputs[self.body.input_names[0]] = cond
inputs[self.body.input_names[1]] = v_initial
cond_name = self.body.output_names[0]
if len(args) > 0:
begin = len(self.body.input_names) - len(args)
for name, val in zip(self.body.input_names[begin:], args):
inputs[name] = val
it = 0
while cond and it < M:
outputs = self._run_meth_then(inputs)
cond = outputs[cond_name]
for i, o in zip(self.body.input_names[2:],
self.body.output_names[2:]):
inputs[i] = outputs[o]
it += 1
if it == 0:
outputs = {self.body.output_names[1]: cond}
for i, o in zip(self.body.input_names[2:],
self.body.output_names[2:]):
outputs[o] = inputs[i]
for o in self.body.output_names:
if o not in outputs:
outputs[o] = numpy.empty(shape=tuple())
return tuple([outputs[name] for name in self.body.output_names[1:]])
[docs] def _infer_shapes(self, M, cond, v_initial, *args): # pylint: disable=W0221
res = self.body._set_shape_inference_runtime()
return tuple([res[name] for name in self.body.output_names[1:]])