# Source code for mlprodict.onnxrt.onnx_tools
"""
Functions to manipulate ONNX files.
:githublink:`%|py|5`
"""
from onnx import helper
def find_node_name(model, name):
    """
    Finds a node by its name.

    The nodes of ``model.graph.node`` are scanned in order and the first
    one whose ``name`` attribute equals *name* is returned.

    :param model: onnx graph (any object exposing ``graph.node``)
    :param name: node name
    :return: node pointer, or None if no node has this name
    :raises TypeError: if *model* has no ``graph`` attribute

    :githublink:`%|py|14`
    """
    if not hasattr(model, "graph"):
        raise TypeError(  # pragma: no cover
            "Parameter model is not an ONNX model but "
            "{}".format(type(model)))
    # First match wins; None signals that the name is unknown.
    return next(
        (node for node in model.graph.node if node.name == name), None)
def insert_node(model, op_type, node, input_index=0, new_name=None, **attrs):
    """
    Inserts a node before one node input.

    The inserted node consumes the result previously plugged into the
    selected input of *node* and produces a new result which *node*
    consumes instead. The consumer node is renamed in place, so the
    nodes of *model* are modified as well; the returned model is a new
    object built from the updated list of nodes.

    :param model: onnx graph
    :param op_type: operator type of the inserted node
    :param node: node or node name
    :param input_index: input index or input name
    :param new_name: name of the result produced by the inserted node,
        by default the lower-cased *op_type*; a numerical suffix is
        added if the name is already used in the graph
    :param attrs: node attributes
    :return: updated graph
    :raises RuntimeError: if the node name or the input name cannot be
        found, or if the new graph does not have the same number of
        inputs as the original one

    :githublink:`%|py|47`
    """
    if isinstance(node, str):
        inode = find_node_name(model, node)
        if inode is None:
            # Fail early with a clear message instead of an
            # AttributeError on None further down.
            raise RuntimeError(
                "Unable to find node %r in the graph." % node)
    else:
        inode = node

    if isinstance(input_index, str):
        # The lookup must use the resolved node: *node* may be a name.
        input_index_ = find_node_input_name(inode, input_index)
        if input_index_ == -1:
            raise RuntimeError(
                "Unable to find input_index %r in node %r." % (
                    input_index, inode.name))  # pylint: disable=E1120
        input_index = input_index_

    # guess a new name: collects every result name already used so the
    # new output does not shadow an existing one
    names = set()
    for n in model.graph.node:
        names.update(n.input)
        names.update(n.output)
    # graph inputs, outputs and initializers may not be referenced by
    # any node but their names are reserved as well
    names.update(i.name for i in model.graph.input)
    names.update(o.name for o in model.graph.output)
    names.update(init.name for init in model.graph.initializer)

    if new_name is None:
        new_name = op_type.lower()
    root_name = new_name
    i = 0
    while new_name in names:
        new_name = "%s_%d" % (root_name, i)
        i += 1

    new_node = helper.make_node(
        op_type, [inode.input[input_index]], [new_name], **attrs)
    inode.input[input_index] = new_name

    # ONNX expects nodes to be topologically sorted: the new node must
    # come before its consumer. If the consumer is not one of the main
    # graph nodes, the new node is appended at the end.
    keep_nodes = list(model.graph.node)
    position = len(keep_nodes)
    for index, candidate in enumerate(keep_nodes):
        if candidate is inode or candidate == inode:
            position = index
            break
    keep_nodes.insert(position, new_node)

    graph = helper.make_graph(
        keep_nodes, model.graph.name, model.graph.input,
        model.graph.output, model.graph.initializer)
    onnx_model = helper.make_model(graph)

    # copies the model level information
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:
        values = {p.key: p.value for p in model.metadata_props}
        helper.set_model_props(onnx_model, values)

    # make_model adds a default opset, it is replaced by the original ones
    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    if len(onnx_model.graph.input) != len(model.graph.input):  # pylint: disable=E1101
        # inputs are stored in the graph, not in the model itself
        raise RuntimeError(  # pragma: no cover
            "Input mismatch {} != {}".format(
                len(onnx_model.graph.input),  # pylint: disable=E1101
                len(model.graph.input)))
    return onnx_model