Source code for mlprodict.onnxrt.ops_cpu.op_if

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
Runtime operator.


:githublink:`%|py|7`
"""
from ._op import OpRun


class If(OpRun):
    """
    Runtime for an *If* node: evaluates one of two sub-graphs
    (``then_branch`` or ``else_branch``) depending on a condition
    and returns the outputs of the selected sub-graph.
    """

    # Attributes requested from OpRun.__init__ through
    # ``expected_attributes``; both are later read as ``self.then_branch``
    # and ``self.else_branch``.
    atts = {
        'then_branch': None,
        'else_branch': None,
    }

    def __init__(self, onnx_node, desc=None, **options):
        """
        Initializes the operator and selects the method used to run
        each branch.

        :param onnx_node: node forwarded to ``OpRun.__init__``
        :param desc: description forwarded to ``OpRun.__init__``
        :param options: additional options forwarded to
            ``OpRun.__init__``
        :raises RuntimeError: if one of the branches has no method
            ``run``
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=If.atts,
                       **options)
        if not hasattr(self.then_branch, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'then_branch' must have a method 'run', "
                "type {}.".format(type(self.then_branch)))
        if not hasattr(self.else_branch, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'else_branch' must have a method 'run', "
                "type {}.".format(type(self.else_branch)))

        # Method ``run_in_scan`` is preferred over ``run`` whenever
        # the branch exposes it.
        self._run_meth_then = (
            self.then_branch.run_in_scan
            if hasattr(self.then_branch, 'run_in_scan')
            else self.then_branch.run)
        self._run_meth_else = (
            self.else_branch.run_in_scan
            if hasattr(self.else_branch, 'run_in_scan')
            else self.else_branch.run)

    @staticmethod
    def _check_branch_inputs(branch, named_inputs):
        """
        Checks that every input name required by *branch* is available
        in *named_inputs*.

        :param branch: sub-graph exposing ``input_names``
        :param named_inputs: dictionary of the available named inputs
        :raises RuntimeError: if *named_inputs* is empty although the
            branch needs inputs, or if one required name is missing
        """
        if len(branch.input_names) == 0:
            return
        if len(named_inputs) == 0:
            raise RuntimeError(  # pragma: no cover
                "named_inputs is empty but the graph needs {}.".format(
                    branch.input_names))
        for k in branch.input_names:
            if k not in named_inputs:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find named input '{}' in\n{}.".format(
                        k, "\n".join(sorted(named_inputs))))

    def _run(self, cond, named_inputs=None):  # pylint: disable=W0221
        """
        Runs the branch selected by *cond*.

        :param cond: condition, either an iterable of booleans (the
            *then* branch is chosen when all of them are true) or
            a scalar (0-d array, boolean)
        :param named_inputs: dictionary of named inputs given to the
            selected branch, an empty dictionary is used if None
        :return: tuple with the outputs of the selected branch, ordered
            as its ``output_names``
        :raises RuntimeError: if an input required by one of the
            branches is missing from *named_inputs*
        """
        if named_inputs is None:
            named_inputs = {}

        # Both branches are validated, whichever is executed.
        self._check_branch_inputs(self.then_branch, named_inputs)
        self._check_branch_inputs(self.else_branch, named_inputs)

        try:
            take_then = all(cond)
        except TypeError:
            # A scalar condition (0-d array, plain boolean) cannot be
            # iterated over, its own truth value is used instead.
            take_then = bool(cond)

        if take_then:
            outputs = self._run_meth_then(named_inputs)
            return tuple(outputs[name]
                         for name in self.then_branch.output_names)
        outputs = self._run_meth_else(named_inputs)
        return tuple(outputs[name]
                     for name in self.else_branch.output_names)

    def _infer_shapes(self, cond, named_inputs=None):  # pylint: disable=W0221
        """
        Returns the output shapes computed from the *then* branch.

        :param cond: unused
        :param named_inputs: unused
        :return: tuple with one entry per output name of the *then*
            branch
        """
        # NOTE(review): only the then branch is considered, this
        # presumably assumes both branches produce the same shapes.
        res = self.then_branch._set_shape_inference_runtime()
        return tuple(res[name] for name in self.then_branch.output_names)