Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Optimisation of :epkg:`ONNX` graphs. 

4""" 

5import copy 

6import hashlib 

7from onnx.helper import make_graph 

8from ._onnx_optimisation_common import ( # pylint: disable=E0611 

9 _rename_node_input, 

10 _rename_node_output, 

11 _apply_optimisation_on_graph, 

12 _apply_remove_node_fct_node 

13) 

14 

15 

def _hash_obj_content(obj, max_size=1000):
    """
    Hashes the content of an object with SHA-256.

    Two kinds of objects are handled:

    * an object exposing ``op_type`` (an operator): the hash covers
      the operator type, the number of outputs, the number of inputs,
      every input name in order and, when the object has an
      ``attribute`` field, every attribute (name and hashed content).
      Output names are not part of the hash, only their count.
    * any other object (an initializer, an attribute): a deep copy is
      made, its ``name`` and ``doc_string`` are blanked and the result
      of ``SerializeToString()`` is hashed. The object received as
      argument is left unchanged.

    @param      obj         object to hash
    @param      max_size    maximum number of bytes kept from the digest
    @return                 digest (bytes), truncated to *max_size* bytes
                            if it is longer
    """
    m = hashlib.sha256()

    def _update_text(text):
        # Every string is prefixed by its length so that consecutive
        # fields cannot run into each other: inputs ('ab', 'c') and
        # ('a', 'bc'), or ('x', '', 'y') and ('x', 'y', ''), must not
        # produce the same hash. utf-8 is used because names are not
        # guaranteed to be plain ascii.
        data = text.encode('utf-8')
        m.update(len(data).to_bytes(8, byteorder='big'))
        m.update(data)

    if hasattr(obj, 'op_type'):
        # An operator.
        _update_text(obj.op_type)
        m.update(len(obj.output).to_bytes(8, byteorder='big'))
        m.update(len(obj.input).to_bytes(8, byteorder='big'))
        for name in obj.input:
            _update_text(name)
        if hasattr(obj, 'attribute'):
            for att in obj.attribute:
                _update_text(att.name)
                # The attribute goes through the second branch below,
                # the full digest is fed into the current hash.
                m.update(_hash_obj_content(att))
    else:
        # An initializer (or an attribute): the name and the
        # documentation must not influence the hash, they are removed
        # from a copy to keep the original object intact.
        obj = copy.deepcopy(obj)
        obj.name = ""
        obj.doc_string = ""
        m.update(obj.SerializeToString())

    content = m.digest()
    if len(content) > max_size:
        content = content[:max_size]
    return content

42 

43 

def onnx_remove_node_redundant(onnx_model, recursive=True, debug_info=None,
                               max_hash_size=1000, **options):
    """
    Removes redundant part of the graph. A redundant part is
    a set of nodes which takes the same inputs and produces
    the same outputs. It first starts by looking into duplicated
    initializers, then looks into nodes taking the same inputs
    and sharing the same type and parameters.

    If *onnx_model* has an attribute ``graph``, the work is delegated
    to ``_apply_optimisation_on_graph`` which receives this function
    as the optimisation to apply. Otherwise *onnx_model* is processed
    as a graph and a new graph is built with ``make_graph``.

    @param      onnx_model      onnx model
    @param      recursive       looks into subgraphs
    @param      debug_info      debug information (private)
    @param      max_hash_size   limit the size of a hash used to detect
                                identical subgraphs
    @param      options         additional options (unused)
    @return                     new onnx model
    """
    # Short name of the processed type, appended to the debugging trail
    # without modifying the list given by the caller.
    model_kind = str(type(onnx_model)).split('.')[-1].strip("'>")
    if debug_info is None:
        debug_info = [model_kind]
    else:
        debug_info = debug_info + [model_kind]

    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            onnx_remove_node_redundant, onnx_model,
            recursive=recursive, debug_info=debug_info,
            max_hash_size=max_hash_size, **options)

    def _enumerate_rename_list_nodes_inputs(nodes, rename):
        # Yields (renamed, index, node) for every position in *nodes*.
        # Removed nodes (None) are kept as None to preserve indices.
        for i, node in enumerate(nodes):
            if node is None:
                yield False, i, None
                continue
            if any(set(node.input) & set(rename)):
                yield True, i, _rename_node_input(node, rename)
                continue
            yield False, i, node

    graph = onnx_model

    # Detects duplicated initializers: the first initializer seen with
    # a given hash is kept, the following ones are renamed into it.
    hashes = {}
    kept_names = set()
    rename = {}
    for init in graph.initializer:
        hs = _hash_obj_content(init, max_size=max_hash_size)
        if hs in hashes:
            # Already seen.
            rename[init.name] = hashes[hs]  # pragma: no cover
        else:
            # New.
            hashes[hs] = init.name
            kept_names.add(init.name)

    new_inits = [init for init in graph.initializer
                 if init.name in kept_names]

    # Renames node inputs.
    new_nodes = [
        item[2] for item in _enumerate_rename_list_nodes_inputs(
            list(graph.node), rename)]

    # Detects duplicated operators.
    graph_outputs = set(o.name for o in graph.output)
    node_hashes = {}  # hash -> index of the first node with that hash
    replace = {}  # index of a removed node -> index of its replacement
    changed = 1
    while changed > 0:
        # *changed* counts the nodes removed during one pass, the loop
        # stops after a pass which removes nothing. It must not be
        # reused as a loop variable below.
        changed = 0
        nnodes = len(new_nodes)
        for i in range(nnodes):
            if i in replace:
                # Already removed.
                continue
            node = new_nodes[i]
            node_hash = _hash_obj_content(node, max_size=max_hash_size)
            if node_hash not in node_hashes:
                node_hashes[node_hash] = i
                continue

            ni = node_hashes[node_hash]
            if ni == i:
                # The node was registered during a previous pass.
                continue
            replace[i] = ni
            changed += 1

            # Specifies what to rename.
            # One exception: the output is one of the graph output,
            # the name of the graph output is kept and the replacement
            # node is modified to produce it.
            rep = new_nodes[ni]
            for old, nn in zip(node.output, rep.output):
                if old in graph_outputs:
                    rename[nn] = old
                    new_nodes[ni] = _rename_node_output(
                        new_nodes[ni], nn, old)
                else:
                    rename[old] = nn

            # Renames inputs and remembers which nodes were modified.
            renamed_nodes = []
            renew_index = set()
            for was_renamed, ci, renamed in \
                    _enumerate_rename_list_nodes_inputs(new_nodes, rename):
                if was_renamed:
                    renew_index.add(ci)
                renamed_nodes.append(renamed)
            new_nodes = renamed_nodes

            # Renews hashes: a node whose inputs changed has a new hash,
            # its former entry is dropped and is computed again
            # the next time the node is visited.
            renew_hash = [
                k for k, v in node_hashes.items() if v in renew_index]
            for hs in renew_hash:
                del node_hashes[hs]
            new_nodes[i] = None

    if recursive:
        # Handles subgraphs.
        # NOTE(review): max_hash_size and options are not forwarded to
        # the helper, confirm its signature before changing that.
        for i, node in enumerate(new_nodes):
            if node is None or not node.attribute:
                continue
            new_nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_redundant,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    nodes = [n for n in new_nodes if n is not None]
    graph = make_graph(nodes, onnx_model.name,
                       onnx_model.input, onnx_model.output,
                       new_inits)

    graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    return graph