Source code for mlprodict.onnxrt.onnx_inference_manipulations

"""
Implements a class able to compute the predictions
from on an :epkg:`ONNX` model.


:githublink:`%|py|6`
"""
from onnx import helper


[docs]def enumerate_model_node_outputs(model, add_node=False): """ Enumerates all the nodes of a model. :param model: :epkg:`ONNX` graph :param add_node: if False, the function enumerates all output names from every node, otherwise, it enumerates tuple (output name, node) :return: enumerator :githublink:`%|py|18` """ if not hasattr(model, "graph"): raise TypeError( # pragma: no cover "Parameter model is not an ONNX model but " "{}".format(type(model))) for node in model.graph.node: for out in node.output: yield (out, node) if add_node else out
[docs]def select_model_inputs_outputs(model, outputs=None, inputs=None): """ Takes a model and changes its outputs. :param model: :epkg:`ONNX` model :param inputs: new inputs, same ones if None :param outputs: new outputs, same ones if None :return: modified model The function removes unneeded files. :githublink:`%|py|38` """ if inputs is not None: raise NotImplementedError( # pragma: no cover "Parameter inputs cannot be empty.") if outputs is None: raise RuntimeError( # pragma: no cover "Parameter outputs cannot be None.") if not isinstance(outputs, list): outputs = [outputs] mark_var = {} for out in enumerate_model_node_outputs(model): mark_var[out] = 0 for inp in model.graph.input: mark_var[inp.name] = 0 for out in outputs: if out not in mark_var: raise ValueError( # pragma: no cover "Output '{}' not found in model.".format(out)) mark_var[out] = 1 nodes = model.graph.node[::-1] mark_op = {} for node in nodes: mark_op[node.name] = 0 # We mark all the nodes we need to keep. nb = 1 while nb > 0: nb = 0 for node in nodes: if mark_op[node.name] == 1: continue mod = False for out in node.output: if mark_var[out] == 1: mark_op[node.name] = 1 mod = True break if not mod: continue nb += 1 for inp in node.input: if mark_var.get(inp, 0) == 1: continue mark_var[inp] = 1 nb += 1 # All nodes verifies mark_op[node.name] == 1 keep_nodes = [node for node in nodes if mark_op[node.name] == 1] var_out = [] for out in outputs: value_info = helper.ValueInfoProto() value_info.name = out var_out.append(value_info) graph = helper.make_graph(keep_nodes, model.graph.name, model.graph.input, var_out, model.graph.initializer) onnx_model = helper.make_model(graph) onnx_model.ir_version = model.ir_version onnx_model.producer_name = model.producer_name onnx_model.producer_version = model.producer_version onnx_model.domain = model.domain onnx_model.model_version = model.model_version onnx_model.doc_string = model.doc_string if len(model.metadata_props) > 0: # pragma: no cover values = {p.key: p.value for p in model.metadata_props} helper.set_model_props(onnx_model, values) del onnx_model.opset_import[:] # pylint: disable=E1101 for oimp in model.opset_import: op_set = onnx_model.opset_import.add() # pylint: disable=E1101 op_set.domain = oimp.domain op_set.version = oimp.version if len(onnx_model.graph.input) != len(model.graph.input): # pylint: disable=E1101 raise RuntimeError( # pragma: no cover "Input mismatch {} != {}".format( len(onnx_model.input), len(model.input))) # pylint: disable=E1101 return onnx_model