Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Common functions to reduce the number of 

4nodes of an :epkg:`ONNX` graph.

5""" 

6from onnx.helper import make_graph, ValueInfoProto, make_model 

7from onnx import AttributeProto, NodeProto 

8from onnx.helper import make_attribute 

9 

10 

11def _apply_optimisation_on_graph(fct, onnx_model, recursive=True, debug_info=None, 

12 **kwargs): 

13 """ 

14 Applies an optimisation function *fct* on a graph 

15 and not on the model. 

16 

17 @param fct function to optimise like 

18 @see fn onnx_remove_node_identity 

19 @param onnx_model onnx model 

20 @param recursive looks into subgraphs 

21 @param debug_info debug information (private) 

22 @param kwargs additional parameters 

23 @return new onnx _model 

24 """ 

25 if hasattr(onnx_model, 'graph'): 

26 if debug_info is None: 

27 debug_info = [] 

28 graph = fct( 

29 onnx_model.graph, debug_info=debug_info + ['GRAPH'], 

30 **kwargs) 

31 new_model = make_model(graph) 

32 new_model.ir_version = onnx_model.ir_version 

33 new_model.producer_name = onnx_model.producer_name 

34 new_model.producer_version = onnx_model.producer_version 

35 new_model.domain = onnx_model.domain 

36 new_model.model_version = onnx_model.model_version 

37 new_model.doc_string = onnx_model.doc_string 

38 if hasattr(onnx_model, 'value_info'): 

39 graph.value_info.extend(onnx_model.value_info) # pragma: no cover 

40 while len(new_model.opset_import) > 0: # pylint: disable=E1101 

41 new_model.opset_import.pop() # pylint: disable=E1101 

42 for oimp in onnx_model.opset_import: 

43 op_set = new_model.opset_import.add() # pylint: disable=E1101 

44 op_set.domain = oimp.domain 

45 op_set.version = oimp.version 

46 return new_model 

47 raise TypeError( # pragma: no cover 

48 "This function only works on 'ModelProto' anod not not on" 

49 " {}.".format(type(onnx_model))) 

50 

51 

52def _apply_remove_node_fct_node(fct, node, recursive, debug_info): 

53 """ 

54 Applies an optimizing function on a subgraphs. 

55 

56 @param node onnx node 

57 @param recursive does it in subgraphs as well 

58 @return new node 

59 """ 

60 if not hasattr(node, 'attribute'): 

61 return node # pragma: no cover 

62 modified = 0 

63 new_atts = [] 

64 for att in node.attribute: 

65 if att.name == 'body': 

66 new_body = fct( 

67 att.g, recursive=recursive, 

68 debug_info=debug_info + [att.name]) 

69 new_atts.append(_make_att_graph(att.name, new_body)) 

70 modified += 1 

71 else: 

72 new_atts.append(att) 

73 if modified > 0: 

74 new_node = _make_node(node.op_type, node.input, 

75 node.output, name=node.name, 

76 attributes=new_atts) 

77 return new_node 

78 return node 

79 

80 

def _make_node(op_type, inputs, outputs, name=None, doc_string=None,
               domain=None, attributes=None):
    """
    Builds a NodeProto from its components.

    :param op_type: (string): The name of the operator to construct
    :param inputs: list of input names
    :param outputs: list of output names
    :param name: optional unique identifier for NodeProto,
        ignored if empty
    :param doc_string: optional documentation
        string for NodeProto, ignored if empty
    :param domain: optional domain for NodeProto.
        If it's None, the domain of the node is left untouched
    :param attributes: the attributes of the node, either a dictionary
        ``{name: value}`` converted with :func:`make_attribute`
        (names are processed in sorted order), or an iterable
        of attributes added as they are
    :return: node
    """
    proto = NodeProto()
    proto.op_type = op_type
    proto.input.extend(inputs)  # pylint: disable=E1101
    proto.output.extend(outputs)  # pylint: disable=E1101

    # optional fields are only set when a value is given
    if name:
        proto.name = name
    if doc_string:
        proto.doc_string = doc_string  # pragma: no cover
    if domain is not None:
        proto.domain = domain

    if isinstance(attributes, dict):
        if attributes:  # pragma: no cover
            # keys of a dictionary are unique, sorting the keys gives
            # the same order as sorting the items
            proto.attribute.extend(  # pylint: disable=E1101
                make_attribute(key, attributes[key])
                for key in sorted(attributes))
    elif attributes:
        # attributes are already built, they are added in the same order
        proto.attribute.extend(attributes)  # pylint: disable=E1101
    return proto

117 

118 

119def _replace(name, old_name, new_name): 

120 if isinstance(old_name, dict) and new_name is None: 

121 return old_name.get(name, name) 

122 if name == old_name: 

123 return new_name 

124 return name 

125 

126 

def _rename_node_input(onnx_node, old_name, new_name=None):
    """
    Renames an input from a node.

    @param      onnx_node   onnx_node
    @param      old_name    old name, or a dictionary
                            ``{old name: new name}`` if *new_name* is None
    @param      new_name    new name or None if *old_name* is a dictionary
    @return                 new node

    If the node has a subgraph in attribute ``body``,
    the subgraph is modified with @see fn _rename_graph_input.
    With a dictionary, only the names the subgraph refers to
    (graph inputs or node inputs) are renamed in the subgraph,
    one after the other.
    """
    inputs = [_replace(name, old_name, new_name) for name in onnx_node.input]
    outputs = list(onnx_node.output)
    # same condition as in _replace to detect a dictionary of renamings
    mapping = (old_name
               if isinstance(old_name, dict) and new_name is None
               else None)

    if hasattr(onnx_node, 'attribute'):
        atts = []
        for att in onnx_node.attribute:
            if att.name != 'body':
                atts.append(att)
                continue
            if mapping is None:
                new_body = _rename_graph_input(att.g, old_name, new_name)
            else:
                # _rename_graph_input only handles one name at a time,
                # giving it the dictionary would create an Identity node
                # with invalid names. Renamings are applied one by one
                # and restricted to names used by the subgraph.
                # NOTE(review): assumes no new name is also a key of
                # the dictionary (no chained renaming) -- to confirm.
                used = {i.name for i in att.g.input}
                for sub_node in att.g.node:
                    used.update(sub_node.input)
                new_body = att.g
                for old, new in mapping.items():
                    if old in used:
                        new_body = _rename_graph_input(new_body, old, new)
            atts.append(_make_att_graph(att.name, new_body))
    else:
        atts = None  # pragma: no cover

    return _make_node(
        onnx_node.op_type, inputs, outputs, name=onnx_node.name,
        domain=onnx_node.domain, attributes=atts)

157 

158 

def _copy_value_info_proto(new_name, obj):
    """
    Copies a *ValueInfoProto* and gives the copy a new name.

    @param      new_name    name of the copy
    @param      obj         object to copy, its type and
                            its documentation are copied
    @return                 new *ValueInfoProto*
    """
    value_info = ValueInfoProto()
    value_info.name = new_name
    value_info.type.CopyFrom(obj.type)  # pylint: disable=E1101
    # The documentation is stored in the ValueInfoProto itself,
    # not in its type (TypeProto has no field doc_string).
    if obj.doc_string:
        value_info.doc_string = obj.doc_string
    return value_info

166 

167 

def _rename_graph_output(graph, old_name, new_name):
    """
    Builds a copy of a graph where output *old_name* is exposed
    as *new_name*. An *Identity* node ``old_name -> new_name``
    is appended to the nodes to connect the dots.

    @param      graph       ONNX graph
    @param      old_name    current name of the output
    @param      new_name    name of the output in the new graph
    @return                 modified graph
    """
    # outputs keep their order, only the matching one is replaced
    renamed_outputs = [
        _copy_value_info_proto(new_name, out) if out.name == old_name else out
        for out in graph.output]
    bridge = _make_node('Identity', [old_name], [new_name])
    new_nodes = [*graph.node, bridge]
    result = make_graph(new_nodes, graph.name, graph.input,
                        renamed_outputs, graph.initializer)
    result.value_info.extend(graph.value_info)  # pylint: disable=E1101
    return result

188 

189 

def _rename_graph_input(graph, old_name, new_name):
    """
    Builds a copy of a graph where input *old_name* is exposed
    as *new_name*. An *Identity* node ``new_name -> old_name``
    is appended to the nodes to connect the dots.

    @param      graph       ONNX graph
    @param      old_name    current name of the input
    @param      new_name    name of the input in the new graph
    @return                 modified graph
    """
    # inputs keep their order, only the matching one is replaced
    renamed_inputs = [
        _copy_value_info_proto(new_name, inp) if inp.name == old_name else inp
        for inp in graph.input]
    bridge = _make_node('Identity', [new_name], [old_name])
    new_nodes = [*graph.node, bridge]
    result = make_graph(new_nodes, graph.name, renamed_inputs,
                        graph.output, graph.initializer)
    result.value_info.extend(graph.value_info)  # pylint: disable=E1101
    return result

210 

211 

def _make_att_graph(name, new_body):
    """
    Wraps a graph into an attribute of type *GRAPH*.

    @param      name        attribute name
    @param      new_body    graph, it is copied into the attribute
    @return                 new *AttributeProto*
    """
    graph_att = AttributeProto()
    graph_att.name = name
    graph_att.type = AttributeProto.GRAPH  # pylint: disable=E1101
    graph_att.g.CopyFrom(new_body)  # pylint: disable=E1101
    return graph_att

218 

219 

def _rename_node_output(onnx_node, old_name, new_name):
    """
    Builds a copy of a node where output *old_name*
    is replaced by *new_name*. A subgraph stored in attribute
    ``body`` is modified with @see fn _rename_graph_output.

    @param      onnx_node   onnx_node
    @param      old_name    old name
    @param      new_name    new name
    @return                 new node
    """
    kept_inputs = list(onnx_node.input)
    renamed_outputs = [_replace(out, old_name, new_name)
                       for out in onnx_node.output]

    if not hasattr(onnx_node, 'attribute'):
        atts = None  # pragma: no cover
    else:
        # every attribute is kept as is except the subgraph 'body'
        atts = [
            _make_att_graph(
                att.name, _rename_graph_output(att.g, old_name, new_name))
            if att.name == 'body' else att
            for att in onnx_node.attribute]

    return _make_node(
        onnx_node.op_type, kept_inputs, renamed_outputs,
        name=onnx_node.name, domain=onnx_node.domain, attributes=atts)