# Source code for mlprodict.onnxrt.optim.onnx_optimisation_identity
"""
Optimisation of :epkg:`ONNX` graphs.
:githublink:`%|py|5`
"""
from onnx.helper import make_graph
from ._onnx_optimisation_common import ( # pylint: disable=E0611
_rename_node_input,
_rename_node_output,
_apply_optimisation_on_graph,
_apply_remove_node_fct_node
)
def _retrieve_identity_nodes(nodes):
    """
    Lists the *Identity* nodes still present in *nodes*.

    :param nodes: list of nodes, a removed node is replaced by None
    :return: list of tuples ``(index, input name, output name)``
    """
    return [(i, node.input[0], node.output[0])
            for i, node in enumerate(nodes)
            if node is not None and node.op_type == 'Identity']


def _names_used_in_subgraphs(nodes):
    """
    Collects every name read by a node inside a subgraph attribute
    (``g`` or ``graphs``) of *nodes*, at any nesting depth, and the
    output names of these subgraphs. Such names may refer to values
    of the enclosing graph, renaming them at the enclosing level only
    would leave the subgraphs with dangling references.

    :param nodes: iterable of nodes
    :return: set of names
    """
    names = set()
    pending = list(nodes)
    while pending:
        node = pending.pop()
        for att in node.attribute:
            subgraphs = list(att.graphs)
            if att.HasField('g'):
                subgraphs.append(att.g)
            for sub in subgraphs:
                names.update(o.name for o in sub.output)
                for subnode in sub.node:
                    names.update(subnode.input)
                    # Nested subgraphs are explored as well.
                    pending.append(subnode)
    return names


def onnx_remove_node_identity(onnx_model, recursive=True, debug_info=None, **options):
    """
    Removes as many *Identity* nodes as possible.
    The function looks for identity nodes in every node, and in
    subgraphs if *recursive* is True. Unless such a node directly
    connects one input to one output, it is removed and every other
    node gets its inputs or outputs renamed accordingly.

    An *Identity* node is kept when removing it would require
    renaming a graph input, a graph output, a name referenced from
    inside a subgraph attribute of another node, or a name with no
    producer in this graph (an initializer or a name coming from an
    outer scope).

    :param onnx_model: onnx model (anything with a ``graph`` attribute)
        or onnx graph
    :param recursive: looks into subgraphs
    :param debug_info: debug information (private)
    :param options: additional options, only forwarded to
        ``_apply_optimisation_on_graph`` when *onnx_model* has a
        ``graph`` attribute
    :return: new onnx model (or new graph if a graph was given)
    :githublink:`%|py|28`
    """
    # Short type name (e.g. 'GraphProto') appended to the debug trail.
    type_name = str(type(onnx_model)).split('.')[-1].strip("'>")
    if debug_info is None:
        debug_info = [type_name]
    else:
        debug_info = debug_info + [type_name]

    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            onnx_remove_node_identity, onnx_model,
            recursive=recursive, debug_info=debug_info, **options)

    graph = onnx_model
    inputs = set(i.name for i in graph.input)
    # Names which must never disappear: the graph outputs and every
    # name read inside a subgraph (renaming only happens at this level).
    protected = set(o.name for o in graph.output)
    protected |= _names_used_in_subgraphs(graph.node)

    nodes = list(graph.node)
    rem = 1
    while rem > 0:
        rem = 0
        restart = False
        for i, inp, out in _retrieve_identity_nodes(nodes):
            if restart:
                # A renaming modified another Identity node, the
                # candidates collected for this pass are outdated.
                break  # pragma: no cover
            if nodes[i] is None:
                # Already removed.
                continue  # pragma: no cover
            if inp in inputs and out in protected:
                # Cannot be removed.
                continue

            if out not in protected:
                # `out` may disappear: every consumer reads `inp` instead.
                for j, node in enumerate(nodes):
                    if node is None or out not in node.input:
                        continue
                    nodes[j] = _rename_node_input(node, out, inp)
                    rem += 1
                    if nodes[j].op_type == 'Identity':
                        restart = True  # pragma: no cover
                nodes[i] = None
                rem += 1
                continue

            if inp in inputs or inp in protected:
                # We cannot change an input name or an output name.
                continue
            if not any(node is not None and inp in node.output
                       for node in nodes):
                # `inp` has no producer here (initializer or outer-scope
                # name): it cannot be renamed, removing the node would
                # leave `out` without any producer.
                continue

            # `inp` may disappear: its producer writes `out` directly and
            # its consumers (including this Identity node) read `out`.
            for j in range(len(nodes)):  # pylint: disable=C0200
                if nodes[j] is None:
                    continue
                if inp in nodes[j].output:
                    nodes[j] = _rename_node_output(nodes[j], inp, out)
                    rem += 1
                    if nodes[j].op_type == 'Identity':
                        restart = True  # pragma: no cover
                if inp in nodes[j].input:
                    nodes[j] = _rename_node_input(nodes[j], inp, out)
                    rem += 1
                    if nodes[j].op_type == 'Identity':
                        restart = True
            nodes[i] = None
            rem += 1

    if recursive:
        # Handles subgraphs.
        for i, node in enumerate(nodes):
            if node is None or not node.attribute:
                continue
            nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_identity,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    nodes = [node for node in nodes if node is not None]
    graph = make_graph(nodes, onnx_model.name,
                       onnx_model.input, onnx_model.output,
                       onnx_model.initializer)
    graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    return graph