Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Optimisation of :epkg:`ONNX` graphs.
4"""
5import copy
6import hashlib
7from onnx.helper import make_graph
8from ._onnx_optimisation_common import ( # pylint: disable=E0611
9 _rename_node_input,
10 _rename_node_output,
11 _apply_optimisation_on_graph,
12 _apply_remove_node_fct_node
13)
16def _hash_obj_content(obj, max_size=1000):
17 """
18 Hash the content of an object.
19 """
20 m = hashlib.sha256()
21 if hasattr(obj, 'op_type'):
22 # An operator.
23 m.update(obj.op_type.encode('ascii'))
24 m.update(len(obj.output).to_bytes(8, byteorder='big'))
25 for i in obj.input:
26 m.update(i.encode('ascii'))
27 if hasattr(obj, 'attribute'):
28 for att in obj.attribute:
29 m.update(att.name.encode('ascii'))
30 m.update(_hash_obj_content(att))
31 else:
32 # An initializer.
33 obj = copy.deepcopy(obj)
34 obj.name = ""
35 obj.doc_string = ""
36 m.update(obj.SerializeToString())
38 content = m.digest()
39 if len(content) > max_size:
40 content = content[:max_size]
41 return content
def onnx_remove_node_redundant(onnx_model, recursive=True, debug_info=None,
                               max_hash_size=1000, **options):
    """
    Removes redundant part of the graph. A redundant part is
    a set of nodes which takes the same inputs and produces
    the same outputs. It first starts by looking into duplicated
    initializers, then looks into nodes taking the same inputs
    and sharing the same type and parameters.

    If *onnx_model* has a ``graph`` attribute, the work is delegated
    to ``_apply_optimisation_on_graph`` with this function as the
    callback. Otherwise *onnx_model* is processed as the graph itself
    and a new graph built with ``make_graph`` is returned.

    @param      onnx_model      onnx model
    @param      recursive       looks into subgraphs
    @param      debug_info      debug information (private), a list
                                extended with the type name of
                                *onnx_model* at every call
    @param      max_hash_size   limit the size of a hash used to detect
                                identical subgraphs
    @param      options         additional options, forwarded to
                                ``_apply_optimisation_on_graph``,
                                otherwise unused
    @return                     new onnx _model
    """
    # Appends the short type name of the processed object to the
    # debug trail (a new list is built, the caller's list is untouched).
    if debug_info is None:
        debug_info = [str(type(onnx_model)).split('.')[-1].strip("'>")]
    else:
        debug_info = debug_info + \
            [str(type(onnx_model)).split('.')[-1].strip("'>")]

    # An object with a ``graph`` attribute is handled by the shared helper,
    # which receives this function as the optimisation to apply.
    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            onnx_remove_node_redundant, onnx_model,
            recursive=recursive, debug_info=debug_info,
            max_hash_size=max_hash_size, **options)

    def _enumerate_rename_list_nodes_inputs(nodes, rename):
        # Yields ``(was_renamed, index, node)`` for every entry of *nodes*.
        # ``None`` entries (removed nodes) are passed through unchanged.
        for i, node in enumerate(nodes):
            if node is None:
                yield False, i, None
                continue
            # True when at least one non-empty input name is a key of *rename*.
            if any(set(node.input) & set(rename)):
                yield True, i, _rename_node_input(node, rename)
                continue
            yield False, i, node

    graph = onnx_model

    # Detects duplicated initializers.
    # hashes: content hash -> name of the first initializer with that content
    # names: names of the initializers to keep
    # rename: removed name -> name replacing it
    hashes = {}
    names = []
    rename = {}
    for init in graph.initializer:
        hs = _hash_obj_content(init, max_size=max_hash_size)
        if hs in hashes:
            # Already seen.
            rename[init.name] = hashes[hs]  # pragma: no cover
        else:
            # New.
            hashes[hs] = init.name
            names.append(init.name)

    new_inits = [init for init in graph.initializer if init.name in set(names)]

    # Renames node inputs.
    new_nodes = []
    new_nodes = list(graph.node)
    new_nodes = list(
        _[2] for _ in _enumerate_rename_list_nodes_inputs(new_nodes, rename))

    # Detects duplicated operators.
    # node_hashes: node hash -> index of the node owning that hash
    # replace: index of a removed node -> index of the node replacing it
    graph_outputs = set(o.name for o in graph.output)
    node_hashes = {}
    changed = 1
    replace = {}
    while changed > 0:
        changed = 0
        nnodes = len(new_nodes)
        for i in range(nnodes):
            if i in replace:
                # Already removed.
                continue
            node = new_nodes[i]
            hash = _hash_obj_content(node, max_size=max_hash_size)
            if hash in node_hashes:
                ni = node_hashes[hash]
                if ni == i:
                    # The node is the owner of its own hash.
                    continue
                replace[i] = ni
                changed += 1

                # Specifies what to rename.
                # One exception: the output is one of the graph output.
                # In that case the kept node takes over the graph
                # output name instead of the other way around.
                rep = new_nodes[ni]
                for old, nn in zip(node.output, rep.output):
                    if old in graph_outputs:
                        rename[nn] = old
                        new_nodes[ni] = _rename_node_output(
                            new_nodes[ni], nn, old)
                    else:
                        rename[old] = nn

                # Renames inputs.
                # NOTE(review): this loop rebinds ``changed`` (the counter
                # tested by the enclosing ``while``) and ``node``. Once it
                # has run, ``changed`` holds the boolean yielded for the
                # last node, not the number of replacements. Confirm this
                # is intended before renaming the loop variables, as that
                # would change how many passes the ``while`` performs.
                new_new_nodes = []
                renew_index = set()
                for changed, ci, node in _enumerate_rename_list_nodes_inputs(new_nodes, rename):
                    if changed:
                        renew_index.add(ci)
                    new_new_nodes.append(node)
                new_nodes = new_new_nodes

                # Renews hashes.
                # Hashes owned by nodes whose inputs were just renamed
                # are stale, since the inputs are part of the hash.
                renew_hash = set(
                    k for k, v in node_hashes.items() if v in renew_index)
                for hs in renew_hash:
                    del node_hashes[hs]
                # The duplicated node is dropped, ``None`` keeps indices stable.
                new_nodes[i] = None
            else:
                node_hashes[hash] = i

    if recursive:
        # Handles subgraphs.
        # Only nodes with attributes are passed to the helper.
        for i in range(len(new_nodes)):  # pylint: disable=C0200
            node = new_nodes[i]
            if node is None or not (node.attribute):  # pylint: disable=C0325
                continue
            new_nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_redundant,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    # Removed nodes (``None``) are filtered out, inputs and outputs are
    # taken unchanged from the original graph.
    nodes = list(filter(lambda n: n is not None, new_nodes))
    graph = make_graph(nodes, onnx_model.name,
                       onnx_model.input, onnx_model.output,
                       new_inits)

    graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    return graph