# Source code for mlprodict.onnxrt.onnx_inference_node

"""
OnnxInferenceNode definition.


:githublink:`%|py|5`
"""
import sys
import pprint
import numpy
from onnx import onnx_pb as onnx_proto
from .ops import load_op


class OnnxInferenceNode:
    """
    A node to execute.

    It keeps the ONNX node, its internal description and, once
    :meth:`setup_runtime` has been called, the runtime operator
    stored in attribute ``ops_``.
    """

    def __init__(self, onnx_node, desc, global_index):
        """
        :param onnx_node: onnx_node
        :param desc: internal description
        :param global_index: function mapping a variable name to a unique
            integer index, it is called on every input and output name
        :raises ValueError: if *desc* is None
        """
        if desc is None:
            raise ValueError("desc should not be None.")  # pragma: no cover
        self.desc = desc
        self.onnx_node = onnx_node
        self._init(global_index)

    @property
    def name(self):
        "Returns the ONNX name."
        return "_".join(
            [self.desc['domain'], self.onnx_node.op_type]).replace(
                ".", "_").replace('__', '_').strip('_')

    def _init(self, global_index):
        """
        Prepares the node: copies the input and output names of the
        ONNX node and resolves their indices with *global_index*.

        :param global_index: function mapping a variable name to an index
        """
        self.op_type = self.onnx_node.op_type
        self.order = -1
        self.variable_to_clean = []
        self.inputs = list(self.onnx_node.input)
        self.outputs = list(self.onnx_node.output)
        self.inplaces = []
        self.inputs_indices = [global_index(name) for name in self.inputs]
        self.outputs_indices = [global_index(name) for name in self.outputs]

    def set_order(self, order):
        """
        Defines the order of execution.

        :param order: position of the node in the execution sequence
        """
        self.order = order

    def add_variable_to_clean(self, name):
        """
        Adds a variable which can be cleaned after the node execution.

        :param name: variable name
        """
        self.variable_to_clean.append(name)

    def __str__(self):
        "usual"
        return "Onnx-{}({}) -> {}".format(
            self.op_type, ", ".join(self.inputs), ", ".join(self.outputs))

    def __repr__(self):
        "usual"
        return self.__str__()

    def setup_runtime(self, runtime=None, variables=None, rt_class=None,
                      target_opset=None, dtype=None, domain=None,
                      ir_version=None, runtime_options=None):
        """
        Loads runtime.

        :param runtime: runtime name (``'onnxruntime2'``,
            ``'python_compiled'``, ``'python_compiled_debug'``, ...),
            stored as option ``'provider'`` when not empty
        :param variables: registered variables created by previous operators
        :param rt_class: runtime class used to compute prediction of subgraphs
        :param target_opset: use a specific target opset
        :param dtype: float computational type, only forwarded to
            :func:`load_op` for runtime ``'onnxruntime2'``
        :param domain: node domain
        :param ir_version: if not None, changes the default value
            given by :epkg:`ONNX`
        :param runtime_options: runtime options, they update (and may
            override) the options built from the other parameters
        :raises AttributeError: if the node has no description
        """
        if self.desc is None:
            raise AttributeError(
                "desc should not be None.")  # pragma: no cover
        self.preprocess_parameters(
            runtime, rt_class, ir_version=ir_version,
            target_opset=target_opset)

        options = {'provider': runtime} if runtime else {}
        if domain is not None:
            options['domain'] = domain
        if target_opset is not None:
            options['target_opset'] = target_opset
        if ir_version is not None:
            options['ir_version'] = ir_version
        if runtime_options is not None:
            options.update(runtime_options)

        extra = {}
        if runtime == 'onnxruntime2':
            # Only this runtime receives the computational type.
            extra['dtype'] = dtype
        elif runtime in ('python_compiled', 'python_compiled_debug'):
            # Compiled runtimes load the operator with the 'python' provider,
            # this override is applied after runtime_options on purpose.
            options['provider'] = 'python'
        self.ops_ = load_op(self.onnx_node, desc=self.desc,
                            options=options if options else None,
                            variables=variables, **extra)

    def preprocess_parameters(self, runtime, rt_class, ir_version=None,
                              target_opset=None):
        """
        Preprocesses the parameters, loads *GraphProto*
        (equivalent to :epkg:`ONNX` graph with less metadata).
        Every attribute holding a *GraphProto* receives a new key
        ``'value_rt'`` with an instance of *rt_class* built on that graph.

        :param runtime: runtime options
        :param rt_class: runtime class used to compute prediction of subgraphs
        :param ir_version: if not None, overwrites the default value
        :param target_opset: use a specific target opset
        :raises TypeError: if a subgraph is found but *rt_class* is None
        """
        if 'atts' not in self.desc:
            return  # pragma: no cover
        for att_name, v in self.desc['atts'].items():
            if 'value' not in v:
                continue  # pragma: no cover
            value = v['value']
            if isinstance(value, onnx_proto.GraphProto):
                if rt_class is None:
                    # Fails early with an explicit message instead of
                    # "'NoneType' object is not callable".
                    raise TypeError(
                        "rt_class must be specified to load the subgraph "
                        "stored in attribute %r of node %r." % (
                            att_name, self.op_type))
                v['value_rt'] = rt_class(
                    value, runtime=runtime, ir_version=ir_version,
                    target_opset=target_opset)

    def run(self, values):
        """
        Runs the node.
        the function updates values with outputs.

        :param values: list of existing values, indexed with
            ``inputs_indices``/``outputs_indices``, or by name when
            those indices are None
        :raises RuntimeError: if the operator fails with a TypeError,
            does not return a tuple or returns an unexpected number
            of outputs
        """
        # This code takes times if the graph contains many nodes.
        # Maybe a C++ container would help in that case (to skip GIL).
        if self.inputs_indices is None:
            args = [values[k] for k in self.inputs]
        else:
            args = [values[k] for k in self.inputs_indices]
        try:
            res = self.ops_.run(*args)
        except TypeError as e:
            raise RuntimeError(
                "Unable to run operator %r." % type(self.ops_)) from e
        if not isinstance(res, tuple):
            raise RuntimeError(  # pragma: no cover
                "Results of operator %r should be a tuple." % type(self.ops_))
        if len(self.outputs) != len(res):
            raise RuntimeError(  # pragma: no cover
                "Mismatch number of outputs got {} for names {}.\n{}".format(
                    len(res), list(sorted(self.outputs)),
                    pprint.pformat(self.desc)))

        # This code takes times if the graph contains many nodes.
        # Maybe a C++ container would help in that case (to skip GIL).
        if self.outputs_indices is None:
            for name, value in zip(self.outputs, res):
                values[name] = value
        else:
            for i, r in enumerate(res):
                values[self.outputs_indices[i]] = r

    def switch_initializers_dtype(self, dtype_in=numpy.float32,
                                  dtype_out=numpy.float64):
        """
        Switches all initializers of type *dtype_in* to *dtype_out*
        (``numpy.float32`` to ``numpy.float64`` by default).
        This only works if the runtime is ``'python'``.

        :param dtype_in: previous type
        :param dtype_out: next type
        :return: done operations, tuples starting with ``'+'`` for a
            converted array, ``'-'`` for an array left unchanged,
            ``'ops_'`` for operations reported by the runtime operator
        """
        done = []
        # A node without attributes has nothing to convert, this is
        # consistent with preprocess_parameters.
        for k, v in self.desc.get('atts', {}).items():
            if 'value_rt' not in v:
                continue
            if isinstance(v['value_rt'], numpy.ndarray):
                if v['value_rt'].dtype == dtype_in:
                    v['value_rt'] = v['value_rt'].astype(dtype_out)
                    done.append(("+", "desc", k, v['value_rt']))
                else:
                    done.append(("-", "desc", k, v['value_rt']))
        if hasattr(self, 'ops_') and self.ops_ is not None:
            res = self.ops_.switch_initializers_dtype(dtype_in, dtype_out)
            for r in res:
                done.append(("ops_", ) + r)
        return done

    def _set_shape_inference_runtime(self, values):
        """
        Updates *values* with the shapes of the outputs.

        :param values: container for shapes, indexed by variable name
        :return: *values*
        :raises TypeError: if the operator cannot infer shapes
            from the given arguments
        """
        args = [values[k] for k in self.inputs]
        try:
            res = self.ops_.infer_shapes(*args)
        except (TypeError, ValueError) as e:
            raise TypeError(
                "Unable to call infer_shapes with {} arguments for class"
                " '{}' ({})".format(len(args), self.ops_.__class__.__name__,
                                    self.ops_.infer_shapes)) from e
        if not isinstance(res, tuple):
            raise RuntimeError(  # pragma: no cover
                "Results of an operator should be a tuple for operator '{}'"
                ".".format(type(self.ops_)))
        if len(self.outputs) != len(res):
            raise RuntimeError(  # pragma: no cover
                "Mismatch number of outputs got {} != {} for names {} (node='{}')."
                "\n{}".format(
                    len(res), len(self.outputs), list(self.outputs),
                    self.ops_.__class__.__name__,
                    pprint.pformat(self.desc, depth=2)))
        for name, value in zip(self.outputs, res):
            values[name] = value
        return values

    def enable_inplace_compute(self, name):
        """
        Let the node know that one input can be overwritten.

        :param name: input name
        :raises ValueError: if *name* is not an input of the node,
            in that case the node is left unchanged
        """
        # Validates the name before recording it so that a wrong name
        # does not leave a stale entry in self.inplaces.
        index = self.inputs.index(name)
        self.inplaces.append(name)
        self.ops_.enable_inplace_compute(index)

    @property
    def inputs_args(self):
        """
        Returns the list of arguments as well as
        the list of parameters with the default values
        (close to the signature).
        """
        if not hasattr(self, 'ops_'):
            raise AttributeError(
                "Attribute 'ops_' is missing.")  # pragma: no cover
        sigs = []
        mand = self.ops_.args_mandatory
        if mand is None:
            mand = self.python_inputs
        sigs.extend(mand)
        if len(self.ops_.args_optional) > 0:
            sigs.extend(self.ops_.args_optional)
            if sys.version_info[:2] >= (3, 8):
                # positional-only marker, supported since python 3.8
                sigs.append('/')
        sigs.extend(self.ops_.args_default)
        return sigs

    @property
    def python_inputs(self):
        """
        Returns the python arguments, the operator's own list
        if it defines one, the node inputs otherwise.
        """
        if not hasattr(self, 'ops_'):
            raise AttributeError(
                "Attribute 'ops_' is missing.")  # pragma: no cover
        if hasattr(self.ops_, 'python_inputs'):
            return self.ops_.python_inputs
        return self.inputs

    @property
    def modified_args(self):
        """
        Returns the list of modified parameters.
        """
        if not hasattr(self, 'ops_'):
            raise AttributeError(
                "Attribute 'ops_' is missing.")  # pragma: no cover
        return self.ops_.args_default_modified

    def to_python(self, inputs):
        """
        Returns a python code for this operator.

        :param inputs: inputs name
        :return: imports, python code, both as strings
        """
        if not hasattr(self, 'ops_'):
            raise AttributeError(
                "Attribute 'ops_' is missing.")  # pragma: no cover
        return self.ops_.to_python(inputs)