Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief OnnxInferenceNode definition.
4"""
5import sys
6import pprint
7import numpy
8from onnx import onnx_pb as onnx_proto
9from .ops import load_op
class OnnxInferenceNode:
    """
    A node to execute: wraps one ONNX operator (`onnx_node`), its
    internal description (`desc`), the runtime implementation loaded
    for it (attribute ``ops_``, set by :meth:`setup_runtime`), and the
    integer indices of its inputs and outputs in the inference
    container (set by :meth:`_init`).
    """
17 def __init__(self, onnx_node, desc, global_index):
18 """
19 @param onnx_node onnx_node
20 @param desc internal description
21 @param global_index it is a function which returns a unique index
22 for the output this operator generates
23 """
24 if desc is None:
25 raise ValueError("desc should not be None.") # pragma: no cover
26 self.desc = desc
27 self.onnx_node = onnx_node
28 self._init(global_index)
30 @property
31 def name(self):
32 "Returns the ONNX name."
33 return "_".join(
34 [self.desc['domain'], self.onnx_node.op_type]).replace(
35 ".", "_").replace('__', '_').strip('_')
37 def _init(self, global_index):
38 """
39 Prepares the node.
40 """
41 self.op_type = self.onnx_node.op_type
42 self.order = -1
43 self.variable_to_clean = []
44 self.inputs = list(self.onnx_node.input)
45 self.outputs = list(self.onnx_node.output)
46 self.inplaces = []
47 self.inputs_indices = [global_index(name) for name in self.inputs]
48 self.outputs_indices = [global_index(name) for name in self.outputs]
49 self._global_index = global_index
51 def set_order(self, order):
52 """
53 Defines the order of execution.
54 """
55 self.order = order
57 def add_variable_to_clean(self, name):
58 """
59 Adds a variable which can be cleaned after the node
60 execution.
61 """
62 self.variable_to_clean.append(name)
64 def __str__(self):
65 "usual"
66 return "Onnx-{}({}) -> {}{}".format(
67 self.op_type, ", ".join(self.inputs), ", ".join(self.outputs),
68 " (name=%r)" % self.onnx_node.name
69 if self.onnx_node.name else "")
71 def __repr__(self):
72 "usual"
73 return self.__str__()
75 def setup_runtime(self, runtime=None, variables=None, rt_class=None,
76 target_opset=None, dtype=None, domain=None,
77 ir_version=None, runtime_options=None):
78 """
79 Loads runtime.
81 @param runtime runtime options
82 @param variables registered variables created by previous operators
83 @param rt_class runtime class used to compute
84 prediction of subgraphs
85 @param target_opset use a specific target opset
86 @param dtype float computational type
87 @param domain node domain
88 @param ir_version if not None, changes the default value
89 given by :epkg:`ONNX`
90 @param runtime_options runtime options
91 """
92 if self.desc is None:
93 raise AttributeError(
94 "desc should not be None.") # pragma: no cover
95 self.preprocess_parameters(
96 runtime, rt_class, ir_version=ir_version, target_opset=target_opset)
97 options = {'provider': runtime} if runtime else {}
98 if domain is not None:
99 options['domain'] = domain
100 if target_opset is not None:
101 options['target_opset'] = target_opset
102 if ir_version is not None:
103 options['ir_version'] = ir_version
104 if runtime_options is not None:
105 options.update(runtime_options)
106 if runtime == 'onnxruntime2':
107 self.ops_ = load_op(self.onnx_node, desc=self.desc,
108 options=options if options else None,
109 variables=variables, dtype=dtype)
110 elif runtime in ('python_compiled', 'python_compiled_debug'):
111 options['provider'] = 'python'
112 self.ops_ = load_op(self.onnx_node, desc=self.desc,
113 options=options if options else None,
114 variables=variables)
115 else:
116 self.ops_ = load_op(self.onnx_node, desc=self.desc,
117 options=options if options else None,
118 variables=variables)
120 @staticmethod
121 def _find_static_inputs(body):
122 """
123 Determines the loop inputs. It is any defined inputs
124 by the subgraphs + any results used as a constant
125 in the subgraphs.
126 """
127 inputs_set = set(i.name for i in body.input)
128 for init in body.initializer:
129 inputs_set.add(init.name)
130 for node in body.node:
131 for i in node.output:
132 inputs_set.add(i)
133 add_inputs = []
134 for node in body.node:
135 for i in node.input:
136 if i not in inputs_set:
137 # no graph input or output node matches
138 # it must be a constant from the below graph
139 add_inputs.append(i)
140 inputs_set.add(i)
141 return add_inputs
143 def preprocess_parameters(self, runtime, rt_class, ir_version=None,
144 target_opset=None):
145 """
146 Preprocesses the parameters,
147 loads *GraphProto*
148 (equivalent to :epkg:`ONNX` graph with
149 less metadata).
151 @param runtime runtime options
152 @param rt_class runtime class used to compute
153 prediction of subgraphs
154 @param ir_version if not None, overwrites the default value
155 @param target_opset use a specific target opset
156 """
157 if 'atts' not in self.desc:
158 return # pragma: no cover
159 inside_loop = self.onnx_node.op_type in {'Loop'}
160 for _, v in self.desc['atts'].items():
161 if 'value' not in v:
162 continue # pragma: no cover
163 value = v['value']
164 if isinstance(value, onnx_proto.GraphProto):
165 static_inputs = OnnxInferenceNode._find_static_inputs(value)
166 try:
167 sess = rt_class(v['value'], runtime=runtime,
168 ir_version=ir_version,
169 target_opset=target_opset,
170 inside_loop=inside_loop,
171 static_inputs=static_inputs)
172 except RuntimeError as e: # pragma: no cover
173 raise RuntimeError(
174 "Unable to instantiate a node of type %r and name %r."
175 "" % (self.onnx_node.op_type, self.onnx_node.name)) from e
176 v['value_rt'] = sess
178 def run(self, values):
179 """
180 Runs the node.
181 the function updates values with outputs.
183 @param values list of existing values
184 """
185 # This code takes times if the graph contains many nodes.
186 # Maybe a C++ container would help in that case (to skip GIL).
187 if self.inputs_indices is None:
188 args = list(values[k] for k in self.inputs)
189 else:
190 args = list(values[k] for k in self.inputs_indices)
191 try:
192 if self.ops_.need_context():
193 context = {n: values[self._global_index(n)]
194 for n in self.ops_.additional_inputs}
195 res = self.ops_.run(*args, context=context)
196 else:
197 res = self.ops_.run(*args)
198 except TypeError as e:
199 raise RuntimeError( # pragma: no cover
200 "Unable to run operator %r, inputs=%r."
201 "" % (type(self.ops_), self.inputs)) from e
202 except OverflowError as e:
203 raise RuntimeError( # pragma: no cover
204 "Unable to run operator %r, inputs=%r."
205 "" % (type(self.ops_), self.inputs)) from e
207 if not isinstance(res, tuple):
208 raise RuntimeError( # pragma: no cover
209 "Results of operator %r should be a tuple." % type(self.ops_))
210 if len(self.outputs) != len(res):
211 raise RuntimeError( # pragma: no cover
212 "Mismatch number of outputs got {} for names {}.\n{}".format(
213 len(res), list(sorted(self.outputs)),
214 pprint.pformat(self.desc)))
216 # This code takes times if the graph contains many nodes.
217 # Maybe a C++ container would help in that case (to skip GIL).
218 if self.outputs_indices is None:
219 for name, value in zip(self.outputs, res):
220 values[name] = value
221 else:
222 for i, r in enumerate(res):
223 values[self.outputs_indices[i]] = r
225 def switch_initializers_dtype(self, dtype_in=numpy.float32,
226 dtype_out=numpy.float64):
227 """
228 Switches all initializers to ``numpy.float64``.
229 This only works if the runtime is ``'python'``.
231 @param dtype_in previous type
232 @param dtype_out next type
233 @return done operations
234 """
235 done = []
236 for k, v in self.desc['atts'].items():
237 if 'value_rt' not in v:
238 continue
239 if isinstance(v['value_rt'], numpy.ndarray):
240 if v['value_rt'].dtype == dtype_in:
241 v['value_rt'] = v['value_rt'].astype(dtype_out)
242 done.append(("+", "desc", k, v['value_rt']))
243 else:
244 done.append(("-", "desc", k, v['value_rt']))
245 if hasattr(self, 'ops_') and self.ops_ is not None:
246 res = self.ops_.switch_initializers_dtype(dtype_in, dtype_out)
247 for r in res:
248 done.append(("ops_", ) + r)
249 return done
251 def _set_shape_inference_runtime(self, values):
252 """
253 Updates *values* which shapes of the outputs.
255 :param values: container for shapes
256 """
257 args = [values[k] for k in self.inputs]
258 try:
259 res = self.ops_.infer_shapes(*args)
260 except (TypeError, ValueError) as e: # pragma: no cover
261 raise TypeError(
262 "Unable to call infer_shapes with {} arguments for class"
263 " '{}' ({})".format(len(args), self.ops_.__class__.__name__,
264 self.ops_.infer_shapes)) from e
265 if not isinstance(res, tuple):
266 raise RuntimeError( # pragma: no cover
267 "Results of an operator should be a tuple for operator '{}'"
268 ".".format(type(self.ops_)))
269 if len(self.outputs) != len(res):
270 raise RuntimeError( # pragma: no cover
271 "Mismatch number of outputs got {} != {} for names {} (node='{}')."
272 "\n{}".format(
273 len(res), len(self.outputs), list(self.outputs),
274 self.ops_.__class__.__name__,
275 pprint.pformat(self.desc, depth=2)))
276 for name, value in zip(self.outputs, res):
277 values[name] = value
278 return values
280 def _set_type_inference_runtime(self, values):
281 """
282 Updates *values* which types of the outputs.
284 :param values: container for types
285 """
286 args = [values[k] for k in self.inputs]
287 try:
288 res = self.ops_.infer_types(*args)
289 except (TypeError, ValueError) as e: # pragma: no cover
290 raise TypeError(
291 "Unable to call infer_types with {} arguments for class"
292 " '{}' ({})".format(len(args), self.ops_.__class__.__name__,
293 self.ops_.infer_types)) from e
294 if not isinstance(res, tuple):
295 raise RuntimeError( # pragma: no cover
296 "Results of an operator should be a tuple for operator '{}'"
297 ".".format(type(self.ops_)))
298 if len(self.outputs) != len(res):
299 raise RuntimeError( # pragma: no cover
300 "Mismatch number of outputs got {} != {} for names {} (node='{}')."
301 "\n{}".format(
302 len(res), len(self.outputs), list(self.outputs),
303 self.ops_.__class__.__name__,
304 pprint.pformat(self.desc, depth=2)))
305 for name, value in zip(self.outputs, res):
306 values[name] = value
307 return values
309 def _set_size_inference_runtime(self, values):
310 """
311 Updates *values* which types of the outputs.
313 :param values: container for sizes
314 """
315 args = [values[k] for k in self.inputs]
316 try:
317 if self.ops_.need_context():
318 context = {n: values[n]
319 for n in self.ops_.additional_inputs}
320 res = self.ops_.infer_sizes(*args, context=context)
321 else:
322 res = self.ops_.infer_sizes(*args)
323 except (TypeError, ValueError) as e:
324 raise TypeError(
325 "Unable to call infer_sizes with {} arguments for class"
326 " '{}' ({})".format(len(args), self.ops_.__class__.__name__,
327 self.ops_.infer_sizes)) from e
328 if not isinstance(res, tuple):
329 raise RuntimeError( # pragma: no cover
330 "Results of an operator should be a tuple for operator '{}'"
331 ".".format(type(self.ops_)))
332 if len(self.outputs) + 1 != len(res):
333 raise RuntimeError( # pragma: no cover
334 "Mismatch number of outputs got {} != {} + 1 for names {} "
335 "(node='{}').\n{}".format(
336 len(res), len(self.outputs), list(self.outputs),
337 self.ops_.__class__.__name__,
338 pprint.pformat(self.desc, depth=2)))
339 for name, value in zip(self.outputs, res[1:]):
340 values[name] = value
341 values['#' + self.onnx_node.name] = res[0]
342 return values
344 def enable_inplace_compute(self, name):
345 """
346 Let the node know that one input can be overwritten.
348 @param name input name
349 """
350 self.inplaces.append(name)
351 self.ops_.enable_inplace_compute(self.inputs.index(name))
353 @property
354 def inputs_args(self):
355 """
356 Returns the list of arguments as well as
357 the list of parameters with the default values
358 (close to the signature).
359 """
360 if not hasattr(self, 'ops_'):
361 raise AttributeError(
362 "Attribute 'ops_' is missing.") # pragma: no cover
363 sigs = []
364 mand = self.ops_.args_mandatory
365 if mand is None:
366 mand = self.python_inputs
367 sigs.extend(mand)
368 if len(self.ops_.args_optional) > 0:
369 sigs.extend(self.ops_.args_optional)
370 if sys.version_info[:2] >= (3, 8):
371 sigs.append('/')
372 sigs.extend(self.ops_.args_default)
373 return sigs
375 @property
376 def python_inputs(self):
377 """
378 Returns the python arguments.
379 """
380 if not hasattr(self, 'ops_'):
381 raise AttributeError(
382 "Attribute 'ops_' is missing.") # pragma: no cover
383 if hasattr(self.ops_, 'python_inputs'):
384 return self.ops_.python_inputs
385 return self.inputs
387 @property
388 def modified_args(self):
389 """
390 Returns the list of modified parameters.
391 """
392 if not hasattr(self, 'ops_'):
393 raise AttributeError(
394 "Attribute 'ops_' is missing.") # pragma: no cover
395 return self.ops_.args_default_modified
397 def to_python(self, inputs):
398 """
399 Returns a python code for this operator.
401 @param inputs inputs name
402 @return imports, python code, both as strings
403 """
404 if not hasattr(self, 'ops_'):
405 raise AttributeError(
406 "Attribute 'ops_' is missing.") # pragma: no cover
407 return self.ops_.to_python(inputs)