Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Optimisation of :epkg:`ONNX` graphs.
4"""
5from onnx.helper import make_graph
6from ._onnx_optimisation_common import ( # pylint: disable=E0611
7 _rename_node_input,
8 _rename_node_output,
9 _apply_optimisation_on_graph,
10 _apply_remove_node_fct_node
11)
def _names_used_in_subgraphs(nodes):
    """
    Collects the names read by the subgraphs held in the attributes
    of *nodes* (e.g. the branches of *If* or the body of *Loop*),
    at any depth: the inputs of their nodes and their outputs.
    A subgraph may read a result of the outer graph without that name
    appearing among the inputs of the node holding the subgraph,
    so these names must not be renamed while optimising the outer graph.
    Names local to the subgraphs are collected as well, which is only
    conservative: at worst an *Identity* node is kept.

    @param      nodes       iterable of nodes, None entries are skipped
    @return                 set of names
    """
    names = set()
    for node in nodes:
        if node is None:
            continue
        for att in node.attribute:
            subgraphs = list(att.graphs)
            if att.HasField('g'):
                subgraphs.append(att.g)
            for sub in subgraphs:
                names.update(o.name for o in sub.output)
                for sub_node in sub.node:
                    names.update(sub_node.input)
                names.update(_names_used_in_subgraphs(sub.node))
    return names


def onnx_remove_node_identity(onnx_model, recursive=True, debug_info=None, **options):
    """
    Removes as many *Identity* nodes as possible.
    The function looks into every node and subgraphs if
    *recursive* is True for identity node. Unless such a
    node directly connects one input to one output, it will
    be removed and every other node gets its inputs or
    outputs accordingly renamed.

    An *Identity* node is kept whenever removing it would require
    renaming a graph input, a graph output, an initializer,
    a result coming from outside the graph (outer scope of a subgraph)
    or a result read by a subgraph.

    @param      onnx_model      onnx model
    @param      recursive       looks into subgraphs
    @param      debug_info      debug information (private)
    @param      options         additional options (unused)
    @return                     new onnx model
    """
    model_type = str(type(onnx_model)).split('.')[-1].strip("'>")
    if debug_info is None:
        debug_info = [model_type]
    else:
        debug_info = debug_info + [model_type]

    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            onnx_remove_node_identity, onnx_model,
            recursive=recursive, debug_info=debug_info, **options)

    graph = onnx_model

    inputs = set(i.name for i in graph.input)
    # Initializers are not necessarily listed among the graph inputs
    # (IR version >= 4) and cannot be renamed either.
    inits = set(i.name for i in graph.initializer)
    # Names read by subgraphs cannot be renamed: they are handled
    # exactly like graph outputs.
    outputs = (set(o.name for o in graph.output) |
               _names_used_in_subgraphs(graph.node))

    nodes = list(graph.node)
    changes = 1
    while changes > 0:
        changes = 0
        # (position, input name, output name) of every remaining Identity node.
        identities = [
            (i, node.input[0], node.output[0])
            for i, node in enumerate(nodes)
            if node is not None and node.op_type == 'Identity']
        restart = False
        for i, inp, out in identities:
            if restart:
                # A renamed node was an Identity node, the tuples collected
                # above are stale: they are collected again by the while loop.
                break  # pragma: no cover
            if nodes[i] is None:
                # Already removed.
                continue  # pragma: no cover

            if out not in outputs:
                # *out* can disappear: every consumer of *out* reads *inp*.
                for j in range(len(nodes)):  # pylint: disable=C0200
                    if nodes[j] is None or out not in nodes[j].input:
                        continue
                    nodes[j] = _rename_node_input(nodes[j], out, inp)
                    changes += 1
                    if nodes[j].op_type == 'Identity':
                        restart = True  # pragma: no cover
                nodes[i] = None
                changes += 1
                continue

            # *out* must be kept, *inp* has to be renamed into *out*.
            if inp in inputs or inp in inits or inp in outputs:
                # We cannot change an input, initializer or output name.
                continue
            if not any(node is not None and inp in node.output
                       for node in nodes):
                # *inp* is produced outside this graph (outer scope of a
                # subgraph): renaming it would leave *out* without producer.
                continue
            # The producer of *inp* now produces *out* and
            # every consumer of *inp* reads *out*.
            for j in range(len(nodes)):  # pylint: disable=C0200
                if nodes[j] is None:
                    continue
                if inp in nodes[j].output:
                    nodes[j] = _rename_node_output(nodes[j], inp, out)
                    changes += 1
                    if nodes[j].op_type == 'Identity':
                        restart = True  # pragma: no cover
                if inp in nodes[j].input:
                    nodes[j] = _rename_node_input(nodes[j], inp, out)
                    changes += 1
                    if nodes[j].op_type == 'Identity':
                        # Always true for the removed node itself.
                        restart = True
            nodes[i] = None
            changes += 1

    if recursive:
        # Handles subgraphs.
        for i, node in enumerate(nodes):
            if node is None or not node.attribute:
                continue
            nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_identity,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph, keeping the original documentation.
    nodes = [node for node in nodes if node is not None]
    new_graph = make_graph(nodes, onnx_model.name,
                           onnx_model.input, onnx_model.output,
                           onnx_model.initializer,
                           doc_string=onnx_model.doc_string)

    new_graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    return new_graph