Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief OnnxInferenceNode definition. 

4""" 

5import sys 

6import pprint 

7import numpy 

8from onnx import onnx_pb as onnx_proto 

9from .ops import load_op 

10 

11 

12class OnnxInferenceNode: 

13 """ 

14 A node to execute. 

15 """ 

16 

17 def __init__(self, onnx_node, desc, global_index): 

18 """ 

19 @param onnx_node onnx_node 

20 @param desc internal description 

21 @param global_index it is a function which returns a unique index 

22 for the output this operator generates 

23 """ 

24 if desc is None: 

25 raise ValueError("desc should not be None.") # pragma: no cover 

26 self.desc = desc 

27 self.onnx_node = onnx_node 

28 self._init(global_index) 

29 

30 @property 

31 def name(self): 

32 "Returns the ONNX name." 

33 return "_".join( 

34 [self.desc['domain'], self.onnx_node.op_type]).replace( 

35 ".", "_").replace('__', '_').strip('_') 

36 

37 def _init(self, global_index): 

38 """ 

39 Prepares the node. 

40 """ 

41 self.op_type = self.onnx_node.op_type 

42 self.order = -1 

43 self.variable_to_clean = [] 

44 self.inputs = list(self.onnx_node.input) 

45 self.outputs = list(self.onnx_node.output) 

46 self.inplaces = [] 

47 self.inputs_indices = [global_index(name) for name in self.inputs] 

48 self.outputs_indices = [global_index(name) for name in self.outputs] 

49 self._global_index = global_index 

50 

51 def set_order(self, order): 

52 """ 

53 Defines the order of execution. 

54 """ 

55 self.order = order 

56 

57 def add_variable_to_clean(self, name): 

58 """ 

59 Adds a variable which can be cleaned after the node 

60 execution. 

61 """ 

62 self.variable_to_clean.append(name) 

63 

64 def __str__(self): 

65 "usual" 

66 return "Onnx-{}({}) -> {}{}".format( 

67 self.op_type, ", ".join(self.inputs), ", ".join(self.outputs), 

68 " (name=%r)" % self.onnx_node.name 

69 if self.onnx_node.name else "") 

70 

71 def __repr__(self): 

72 "usual" 

73 return self.__str__() 

74 

75 def setup_runtime(self, runtime=None, variables=None, rt_class=None, 

76 target_opset=None, dtype=None, domain=None, 

77 ir_version=None, runtime_options=None): 

78 """ 

79 Loads runtime. 

80 

81 @param runtime runtime options 

82 @param variables registered variables created by previous operators 

83 @param rt_class runtime class used to compute 

84 prediction of subgraphs 

85 @param target_opset use a specific target opset 

86 @param dtype float computational type 

87 @param domain node domain 

88 @param ir_version if not None, changes the default value 

89 given by :epkg:`ONNX` 

90 @param runtime_options runtime options 

91 """ 

92 if self.desc is None: 

93 raise AttributeError( 

94 "desc should not be None.") # pragma: no cover 

95 self.preprocess_parameters( 

96 runtime, rt_class, ir_version=ir_version, target_opset=target_opset) 

97 options = {'provider': runtime} if runtime else {} 

98 if domain is not None: 

99 options['domain'] = domain 

100 if target_opset is not None: 

101 options['target_opset'] = target_opset 

102 if ir_version is not None: 

103 options['ir_version'] = ir_version 

104 if runtime_options is not None: 

105 options.update(runtime_options) 

106 if runtime == 'onnxruntime2': 

107 self.ops_ = load_op(self.onnx_node, desc=self.desc, 

108 options=options if options else None, 

109 variables=variables, dtype=dtype) 

110 elif runtime in ('python_compiled', 'python_compiled_debug'): 

111 options['provider'] = 'python' 

112 self.ops_ = load_op(self.onnx_node, desc=self.desc, 

113 options=options if options else None, 

114 variables=variables) 

115 else: 

116 self.ops_ = load_op(self.onnx_node, desc=self.desc, 

117 options=options if options else None, 

118 variables=variables) 

119 

120 @staticmethod 

121 def _find_static_inputs(body): 

122 """ 

123 Determines the loop inputs. It is any defined inputs 

124 by the subgraphs + any results used as a constant 

125 in the subgraphs. 

126 """ 

127 inputs_set = set(i.name for i in body.input) 

128 for init in body.initializer: 

129 inputs_set.add(init.name) 

130 for node in body.node: 

131 for i in node.output: 

132 inputs_set.add(i) 

133 add_inputs = [] 

134 for node in body.node: 

135 for i in node.input: 

136 if i not in inputs_set: 

137 # no graph input or output node matches 

138 # it must be a constant from the below graph 

139 add_inputs.append(i) 

140 inputs_set.add(i) 

141 return add_inputs 

142 

143 def preprocess_parameters(self, runtime, rt_class, ir_version=None, 

144 target_opset=None): 

145 """ 

146 Preprocesses the parameters, 

147 loads *GraphProto* 

148 (equivalent to :epkg:`ONNX` graph with 

149 less metadata). 

150 

151 @param runtime runtime options 

152 @param rt_class runtime class used to compute 

153 prediction of subgraphs 

154 @param ir_version if not None, overwrites the default value 

155 @param target_opset use a specific target opset 

156 """ 

157 if 'atts' not in self.desc: 

158 return # pragma: no cover 

159 inside_loop = self.onnx_node.op_type in {'Loop'} 

160 for _, v in self.desc['atts'].items(): 

161 if 'value' not in v: 

162 continue # pragma: no cover 

163 value = v['value'] 

164 if isinstance(value, onnx_proto.GraphProto): 

165 static_inputs = OnnxInferenceNode._find_static_inputs(value) 

166 try: 

167 sess = rt_class(v['value'], runtime=runtime, 

168 ir_version=ir_version, 

169 target_opset=target_opset, 

170 inside_loop=inside_loop, 

171 static_inputs=static_inputs) 

172 except RuntimeError as e: # pragma: no cover 

173 raise RuntimeError( 

174 "Unable to instantiate a node of type %r and name %r." 

175 "" % (self.onnx_node.op_type, self.onnx_node.name)) from e 

176 v['value_rt'] = sess 

177 

178 def run(self, values): 

179 """ 

180 Runs the node. 

181 the function updates values with outputs. 

182 

183 @param values list of existing values 

184 """ 

185 # This code takes times if the graph contains many nodes. 

186 # Maybe a C++ container would help in that case (to skip GIL). 

187 if self.inputs_indices is None: 

188 args = list(values[k] for k in self.inputs) 

189 else: 

190 args = list(values[k] for k in self.inputs_indices) 

191 try: 

192 if self.ops_.need_context(): 

193 context = {n: values[self._global_index(n)] 

194 for n in self.ops_.additional_inputs} 

195 res = self.ops_.run(*args, context=context) 

196 else: 

197 res = self.ops_.run(*args) 

198 except TypeError as e: 

199 raise RuntimeError( # pragma: no cover 

200 "Unable to run operator %r, inputs=%r." 

201 "" % (type(self.ops_), self.inputs)) from e 

202 except OverflowError as e: 

203 raise RuntimeError( # pragma: no cover 

204 "Unable to run operator %r, inputs=%r." 

205 "" % (type(self.ops_), self.inputs)) from e 

206 

207 if not isinstance(res, tuple): 

208 raise RuntimeError( # pragma: no cover 

209 "Results of operator %r should be a tuple." % type(self.ops_)) 

210 if len(self.outputs) != len(res): 

211 raise RuntimeError( # pragma: no cover 

212 "Mismatch number of outputs got {} for names {}.\n{}".format( 

213 len(res), list(sorted(self.outputs)), 

214 pprint.pformat(self.desc))) 

215 

216 # This code takes times if the graph contains many nodes. 

217 # Maybe a C++ container would help in that case (to skip GIL). 

218 if self.outputs_indices is None: 

219 for name, value in zip(self.outputs, res): 

220 values[name] = value 

221 else: 

222 for i, r in enumerate(res): 

223 values[self.outputs_indices[i]] = r 

224 

225 def switch_initializers_dtype(self, dtype_in=numpy.float32, 

226 dtype_out=numpy.float64): 

227 """ 

228 Switches all initializers to ``numpy.float64``. 

229 This only works if the runtime is ``'python'``. 

230 

231 @param dtype_in previous type 

232 @param dtype_out next type 

233 @return done operations 

234 """ 

235 done = [] 

236 for k, v in self.desc['atts'].items(): 

237 if 'value_rt' not in v: 

238 continue 

239 if isinstance(v['value_rt'], numpy.ndarray): 

240 if v['value_rt'].dtype == dtype_in: 

241 v['value_rt'] = v['value_rt'].astype(dtype_out) 

242 done.append(("+", "desc", k, v['value_rt'])) 

243 else: 

244 done.append(("-", "desc", k, v['value_rt'])) 

245 if hasattr(self, 'ops_') and self.ops_ is not None: 

246 res = self.ops_.switch_initializers_dtype(dtype_in, dtype_out) 

247 for r in res: 

248 done.append(("ops_", ) + r) 

249 return done 

250 

251 def _set_shape_inference_runtime(self, values): 

252 """ 

253 Updates *values* which shapes of the outputs. 

254 

255 :param values: container for shapes 

256 """ 

257 args = [values[k] for k in self.inputs] 

258 try: 

259 res = self.ops_.infer_shapes(*args) 

260 except (TypeError, ValueError) as e: # pragma: no cover 

261 raise TypeError( 

262 "Unable to call infer_shapes with {} arguments for class" 

263 " '{}' ({})".format(len(args), self.ops_.__class__.__name__, 

264 self.ops_.infer_shapes)) from e 

265 if not isinstance(res, tuple): 

266 raise RuntimeError( # pragma: no cover 

267 "Results of an operator should be a tuple for operator '{}'" 

268 ".".format(type(self.ops_))) 

269 if len(self.outputs) != len(res): 

270 raise RuntimeError( # pragma: no cover 

271 "Mismatch number of outputs got {} != {} for names {} (node='{}')." 

272 "\n{}".format( 

273 len(res), len(self.outputs), list(self.outputs), 

274 self.ops_.__class__.__name__, 

275 pprint.pformat(self.desc, depth=2))) 

276 for name, value in zip(self.outputs, res): 

277 values[name] = value 

278 return values 

279 

280 def _set_type_inference_runtime(self, values): 

281 """ 

282 Updates *values* which types of the outputs. 

283 

284 :param values: container for types 

285 """ 

286 args = [values[k] for k in self.inputs] 

287 try: 

288 res = self.ops_.infer_types(*args) 

289 except (TypeError, ValueError) as e: # pragma: no cover 

290 raise TypeError( 

291 "Unable to call infer_types with {} arguments for class" 

292 " '{}' ({})".format(len(args), self.ops_.__class__.__name__, 

293 self.ops_.infer_types)) from e 

294 if not isinstance(res, tuple): 

295 raise RuntimeError( # pragma: no cover 

296 "Results of an operator should be a tuple for operator '{}'" 

297 ".".format(type(self.ops_))) 

298 if len(self.outputs) != len(res): 

299 raise RuntimeError( # pragma: no cover 

300 "Mismatch number of outputs got {} != {} for names {} (node='{}')." 

301 "\n{}".format( 

302 len(res), len(self.outputs), list(self.outputs), 

303 self.ops_.__class__.__name__, 

304 pprint.pformat(self.desc, depth=2))) 

305 for name, value in zip(self.outputs, res): 

306 values[name] = value 

307 return values 

308 

309 def _set_size_inference_runtime(self, values): 

310 """ 

311 Updates *values* which types of the outputs. 

312 

313 :param values: container for sizes 

314 """ 

315 args = [values[k] for k in self.inputs] 

316 try: 

317 if self.ops_.need_context(): 

318 context = {n: values[n] 

319 for n in self.ops_.additional_inputs} 

320 res = self.ops_.infer_sizes(*args, context=context) 

321 else: 

322 res = self.ops_.infer_sizes(*args) 

323 except (TypeError, ValueError) as e: 

324 raise TypeError( 

325 "Unable to call infer_sizes with {} arguments for class" 

326 " '{}' ({})".format(len(args), self.ops_.__class__.__name__, 

327 self.ops_.infer_sizes)) from e 

328 if not isinstance(res, tuple): 

329 raise RuntimeError( # pragma: no cover 

330 "Results of an operator should be a tuple for operator '{}'" 

331 ".".format(type(self.ops_))) 

332 if len(self.outputs) + 1 != len(res): 

333 raise RuntimeError( # pragma: no cover 

334 "Mismatch number of outputs got {} != {} + 1 for names {} " 

335 "(node='{}').\n{}".format( 

336 len(res), len(self.outputs), list(self.outputs), 

337 self.ops_.__class__.__name__, 

338 pprint.pformat(self.desc, depth=2))) 

339 for name, value in zip(self.outputs, res[1:]): 

340 values[name] = value 

341 values['#' + self.onnx_node.name] = res[0] 

342 return values 

343 

344 def enable_inplace_compute(self, name): 

345 """ 

346 Let the node know that one input can be overwritten. 

347 

348 @param name input name 

349 """ 

350 self.inplaces.append(name) 

351 self.ops_.enable_inplace_compute(self.inputs.index(name)) 

352 

353 @property 

354 def inputs_args(self): 

355 """ 

356 Returns the list of arguments as well as 

357 the list of parameters with the default values 

358 (close to the signature). 

359 """ 

360 if not hasattr(self, 'ops_'): 

361 raise AttributeError( 

362 "Attribute 'ops_' is missing.") # pragma: no cover 

363 sigs = [] 

364 mand = self.ops_.args_mandatory 

365 if mand is None: 

366 mand = self.python_inputs 

367 sigs.extend(mand) 

368 if len(self.ops_.args_optional) > 0: 

369 sigs.extend(self.ops_.args_optional) 

370 if sys.version_info[:2] >= (3, 8): 

371 sigs.append('/') 

372 sigs.extend(self.ops_.args_default) 

373 return sigs 

374 

375 @property 

376 def python_inputs(self): 

377 """ 

378 Returns the python arguments. 

379 """ 

380 if not hasattr(self, 'ops_'): 

381 raise AttributeError( 

382 "Attribute 'ops_' is missing.") # pragma: no cover 

383 if hasattr(self.ops_, 'python_inputs'): 

384 return self.ops_.python_inputs 

385 return self.inputs 

386 

387 @property 

388 def modified_args(self): 

389 """ 

390 Returns the list of modified parameters. 

391 """ 

392 if not hasattr(self, 'ops_'): 

393 raise AttributeError( 

394 "Attribute 'ops_' is missing.") # pragma: no cover 

395 return self.ops_.args_default_modified 

396 

397 def to_python(self, inputs): 

398 """ 

399 Returns a python code for this operator. 

400 

401 @param inputs inputs name 

402 @return imports, python code, both as strings 

403 """ 

404 if not hasattr(self, 'ops_'): 

405 raise AttributeError( 

406 "Attribute 'ops_' is missing.") # pragma: no cover 

407 return self.ops_.to_python(inputs)