# Source code for mlprodict.onnxrt.optim.onnx_optimisation_unused
"""
Optimisation of :epkg:`ONNX` graphs.
:githublink:`%|py|5`
"""
from onnx.helper import make_graph
from ._onnx_optimisation_common import ( # pylint: disable=E0611
_apply_optimisation_on_graph, _apply_remove_node_fct_node)
def onnx_remove_node_unused(onnx_model, recursive=True, debug_info=None,
                            **options):
    """
    Removes unused nodes of the graph. An unused node
    is not involved in the output computation.

    The graph is walked backward from its outputs: a node is kept
    when at least one of its outputs is needed, and every name a kept
    node reads becomes needed in turn. Names read inside the subgraphs
    of a kept node (bodies of control flow operators) are considered
    as read by that node, so that results and initializers only
    consumed from a nested scope are preserved. Initializers whose
    name is never needed are dropped as well.

    :param onnx_model: onnx model, or a graph (any object without
        an attribute *graph* is processed as a graph)
    :param recursive: looks into subgraphs
    :param debug_info: debug information (private), a list of strings
        describing the path followed to reach *onnx_model*
    :param options: unused by this function, forwarded when
        *onnx_model* is a model
    :return: new onnx model, or new graph if *onnx_model* is a graph
    """
    type_name = str(type(onnx_model)).split('.')[-1].strip("'>")
    if debug_info is None:
        debug_info = [type_name]
    else:
        # A new list is built to leave the caller's list untouched.
        debug_info = debug_info + [type_name]

    if hasattr(onnx_model, 'graph'):
        # A model: the helper is in charge of calling this function
        # on the graph it contains.
        return _apply_optimisation_on_graph(
            onnx_remove_node_unused, onnx_model,
            recursive=recursive, debug_info=debug_info,
            **options)

    graph = onnx_model
    nodes = list(graph.node)

    # Nodes are identified by their position, not by their name:
    # names are optional in ONNX and may be empty or duplicated.
    producers = {}  # result name -> positions of the nodes producing it
    consumed = []   # position -> names the node needs
    for index, node in enumerate(nodes):
        for out in node.output:
            producers.setdefault(out, []).append(index)
        names = set(node.input)
        if node.attribute:
            names |= _subgraph_input_names(node)
        consumed.append(names)

    # Backward propagation from the graph outputs.
    used_names = set()
    kept = set()
    pending = [out.name for out in graph.output]
    while pending:
        name = pending.pop()
        # An empty name stands for an omitted optional input or
        # output, it does not link two nodes.
        if not name or name in used_names:
            continue
        used_names.add(name)
        for index in producers.get(name, ()):
            if index not in kept:
                kept.add(index)
                pending.extend(consumed[index])

    new_nodes = [node for index, node in enumerate(nodes) if index in kept]
    new_inits = [init for init in graph.initializer
                 if init.name in used_names]

    if recursive:
        # Handles subgraphs.
        for index, node in enumerate(new_nodes):
            if node is None or not node.attribute:
                continue
            new_nodes[index] = _apply_remove_node_fct_node(
                onnx_remove_node_unused,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    new_nodes = [node for node in new_nodes if node is not None]
    new_graph = make_graph(new_nodes, graph.name,
                           graph.input, graph.output,
                           new_inits)
    new_graph.value_info.extend(graph.value_info)  # pylint: disable=E1101
    return new_graph


def _subgraph_input_names(node):
    """
    Collects the names read inside the subgraphs attached to *node*.

    Every graph stored in the attributes of *node* is visited,
    including graphs nested at any depth. The result is a superset of
    the names the subgraphs take from the enclosing graph: names
    defined inside a subgraph are returned too, which is harmless
    for the caller as it only looks for names it knows.

    :param node: node with an attribute *attribute*
    :return: set of names
    """
    names = set()
    stack = [node]
    while stack:
        current = stack.pop()
        for att in current.attribute:
            # *g* holds a single subgraph, *graphs* a list of them,
            # both are empty for attributes of any other type.
            for sub in [att.g] + list(att.graphs):
                names.update(out.name for out in sub.output)
                for sub_node in sub.node:
                    names.update(sub_node.input)
                    stack.append(sub_node)
    return names