Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# pylint: disable=C0302
2"""
3@file
4@brief Implements a class able to compute the predictions
5from on an :epkg:`ONNX` model.
6"""
7from collections import OrderedDict
8from io import BytesIO
9from time import perf_counter
10import warnings
11import textwrap
12import pprint
13import numpy
14from scipy.sparse import coo_matrix
15from onnx import load, load_model, checker, shape_inference
16from onnx import onnx_pb as onnx_proto
17from onnx.helper import make_model
18from ..tools.code_helper import make_callable, print_code
19from ..onnx_tools.onnx2py_helper import (
20 _var_as_dict, numpy_min, numpy_max, guess_numpy_type_from_string)
21from ..onnx_tools.onnx_manipulations import (
22 select_model_inputs_outputs, enumerate_model_node_outputs,
23 overwrite_opset, insert_results_into_onnx)
24from ..onnx_tools.optim import onnx_remove_node_unused
25from .onnx_inference_node import OnnxInferenceNode
26from .onnx_inference_exports import OnnxInferenceExport
27from .shape_object import ShapeObject
28from .type_object import SequenceType
class OnnxInference:
    """
    Loads an :epkg:`ONNX` file or object or stream.
    Computes the output of the :epkg:`ONNX` graph.
    Several runtimes are available.

    * ``'python'``: the runtime implements every onnx operator
      needed to run a :epkg:`scikit-learn` model by using :epkg:`numpy`
      or C++ code.
    * ``'python_compiled'``: it is the same runtime as the previous
      one except every operator is called from a compiled function
      (@see me _build_compile_run) instead of a method going through
      the list of operators
    * ``'onnxruntime1'``: uses :epkg:`onnxruntime`
    * ``'onnxruntime2'``: this mode is mostly used to debug as
      python handles calling every operator but :epkg:`onnxruntime`
      is called for every of them, this process may fail due to
      wrong inference type especially if the graph includes
      custom nodes, in that case, it is better to compute the output
      of intermediate nodes. It is much slower as for every output, every
      node is computed but more robust.

    :param onnx_or_bytes_or_stream: :epkg:`onnx` object,
        bytes, or filename or stream
    :param runtime: runtime options
    :param skip_run: do not build the runtime
    :param inplace: use inplace computation as much as possible
    :param input_inplace: the computation is allowed
        to overwrite the input, see :meth:`_guess_inplace
        <mlprodict.onnxrt.onnx_inference.OnnxInference._guess_inplace>`
    :param ir_version: if not None, overwrite the default version
    :param target_opset: used to overwrite *target_opset*
    :param runtime_options: specific options for the runtime
    :param inside_loop: tells the runtime the graph is meant to
        be repeated multiple times (in that case, inputs and
        outputs may share the same name)
    :param static_inputs: Loop can use static variables,
        variables from the graph which runs the loop
    :param new_outputs: if the loading fails, it might be worth
        cutting the graph, if not None, the graph will
        be cut to have these new_outputs as the final outputs
    :param new_opset: overwrite the main opset and replaces
        by this new one

    Among the possible runtime_options, there are:

    * *enable_profiling*: enables profiling for :epkg:`onnxruntime`
    * *session_options*: an instance of *SessionOptions* from
      :epkg:`onnxruntime`
    * *ir_version*: change ir_version

    .. versionchanged:: 0.7
        Parameters *new_outputs*, *new_opset* were added.
    """
    def __init__(self, onnx_or_bytes_or_stream, runtime=None,
                 skip_run=False, inplace=True,
                 input_inplace=False, ir_version=None,
                 target_opset=None, runtime_options=None,
                 session_options=None, inside_loop=False,
                 static_inputs=None, new_outputs=None, new_opset=None):
        """
        Loads the model and stores the runtime options, then builds
        the runtime through @see me _init. Parameters are documented
        in the class docstring.
        """
        # Accepts serialized bytes, a stream, a filename, anything
        # exposing a ``graph`` attribute (a ModelProto), or a bare
        # GraphProto which is wrapped into a ModelProto.
        if isinstance(onnx_or_bytes_or_stream, bytes):
            self.obj = load_model(BytesIO(onnx_or_bytes_or_stream))
        elif isinstance(onnx_or_bytes_or_stream, BytesIO):
            self.obj = load_model(onnx_or_bytes_or_stream)
        elif isinstance(onnx_or_bytes_or_stream, str):
            self.obj = load(onnx_or_bytes_or_stream)
        elif hasattr(onnx_or_bytes_or_stream, 'graph'):
            self.obj = onnx_or_bytes_or_stream
        elif isinstance(onnx_or_bytes_or_stream, onnx_proto.GraphProto):
            self.obj = make_model(onnx_or_bytes_or_stream,
                                  producer_name='mlprodict')
        else:
            raise TypeError("Unable to handle type {}.".format(  # pragma: no cover
                type(onnx_or_bytes_or_stream)))
        if ir_version is not None:
            self.obj.ir_version = ir_version
        if new_outputs is not None:
            # Cuts the graph so that *new_outputs* become the final outputs.
            self.obj = select_model_inputs_outputs(
                self.obj, outputs=new_outputs, infer_shapes=True)
        if new_opset is not None:
            self.obj = overwrite_opset(self.obj, new_opset)
        # NOTE(review): *session_options* is accepted but never stored or
        # used in this method — confirm whether it should be merged into
        # *runtime_options* or removed.
        self.runtime = runtime
        self.skip_run = skip_run
        self.input_inplace = input_inplace
        self.inplace = inplace
        self.force_target_opset = target_opset
        self.runtime_options = runtime_options
        self.inside_loop = inside_loop
        self.static_inputs = static_inputs
        self._init()
123 def __getstate__(self):
124 """
125 To pickle the object.
126 """
127 return {'onnx': self.obj.SerializeToString(),
128 'runtime': self.runtime,
129 'runtime_options': self.runtime_options,
130 'skip_run': self.skip_run,
131 'input_inplace': self.input_inplace,
132 'inplace': self.inplace,
133 'force_target_opset': self.force_target_opset,
134 'static_inputs': self.static_inputs,
135 'inside_loop': self.inside_loop}
137 def __setstate__(self, state):
138 """
139 To unpickle the object.
140 """
141 onx = state['onnx']
142 self.obj = load_model(BytesIO(onx))
143 self.runtime = state['runtime']
144 self.runtime_options = state['runtime_options']
145 self.skip_run = state['skip_run']
146 self.input_inplace = state['input_inplace']
147 self.inplace = state['inplace']
148 self.force_target_opset = state['force_target_opset']
149 self.static_inputs = state['static_inputs']
150 self.inside_loop = state['inside_loop']
151 self._init()
    def _init(self):
        """
        Prepares the instance to deliver predictions: converts the model
        into an executable sequence, validates shapes, then wires the
        runtime selected by ``self.runtime``.
        """
        self.graph_ = self.to_sequence()
        if len(self.graph_['sequence']) == 0:
            raise RuntimeError(  # pragma: no cover
                "No runnable nodes was found in the ONNX graph.")
        self.outputs_ = self.graph_['outputs']
        self.inputs_ = self.graph_['inputs']
        # Rejects declared inputs/outputs carrying an explicit empty
        # dimension (dim_value of 0 that is neither unset nor a parameter).
        for ino in [self.obj.graph.input, self.obj.graph.output]:
            for xy in ino:
                shape = xy.type.tensor_type.shape
                for d in shape.dim:
                    if d.dim_value == 0 and "0" in str(d) and 'dim_param' not in str(d):
                        # d.dim_value returns 0 whether is is 0 or empty.
                        # it may be a parameter as well
                        raise RuntimeError(  # pragma: no cover
                            "Wrong ONNX file, one input or output has an empty shape: "
                            "{}.".format(xy))
        self.target_opset_ = self.graph_['targets']
        if self.force_target_opset is not None:
            # A dict maps domains to opsets, a scalar applies to
            # the default domain ''.
            if isinstance(self.force_target_opset, dict):
                self.target_opset_ = self.force_target_opset  # pragma: no cover
            else:
                self.target_opset_ = {'': self.force_target_opset}
        self.ir_version_ = self.graph_['ir_version']
        if not self.skip_run:
            if self.runtime == 'onnxruntime1':
                # Loads the onnx with onnxruntime as a single file.
                del self.graph_
                from .ops_whole.session import OnnxWholeSession
                self._whole = OnnxWholeSession(
                    self.obj, self.runtime, self.runtime_options)
                self._run = self._run_whole_runtime
            else:
                # Python-driven execution: each node gets its own runtime.
                self.sequence_ = self.graph_['sequence']
                self.inits_ = self.graph_['inits']
                self.statics_ = self.graph_['statics']
                dtype = self._guess_input_dtype()
                variables = self.inits_.copy()
                for node in self.sequence_:
                    domain = node.onnx_node.domain
                    target_opset = self.target_opset_.get(domain, None)
                    if self.runtime in ('onnxruntime2', 'empty'):
                        node.setup_runtime(self.runtime, variables, self.__class__,
                                           target_opset=target_opset, dtype=dtype,
                                           domain=domain, ir_version=self.ir_version_,
                                           runtime_options=self.runtime_options)
                    else:
                        node.setup_runtime(self.runtime, variables, self.__class__,
                                           target_opset=target_opset, domain=domain,
                                           ir_version=self.ir_version_,
                                           runtime_options=self.runtime_options)
                    # Propagates typed outputs so the next nodes can use them.
                    if hasattr(node, 'ops_') and hasattr(node.ops_, 'typed_outputs_'):
                        for k, v in node.ops_.typed_outputs_:
                            variables[k] = v
                self._run = self._run_sequence_runtime
        if not self.skip_run and self.runtime in ('python', None):
            self.shapes_ = self._set_shape_inference_runtime()
            if self.inplace:
                self.inplaces_ = self._guess_inplace(self.input_inplace)
        # Exporters to other formats (json, dot, python, text, onnx code).
        self.exporters_ = OnnxInferenceExport(self)
        self.to_json = self.exporters_.to_json
        self.to_dot = self.exporters_.to_dot
        self.to_python = self.exporters_.to_python
        self.to_text = self.exporters_.to_text
        self.to_onnx_code = self.exporters_.to_onnx_code
        if self.runtime in ('python_compiled', 'python_compiled_debug'):
            # switch the inference method to the compiled one
            _, fct, code = self._build_compile_run('debug' in self.runtime)
            setattr(self, '_run_compiled', fct)
            setattr(self, '_run_compiled_code', code)
            self._run = self._run_sequence_runtime_compiled
233 def _run_sequence_runtime_compiled(
234 self, inputs, clean_right_away=False, intermediate=False,
235 verbose=0, node_time=False, fLOG=None):
236 """
237 Executes a compiled version of @see me _run_sequence_runtime,
238 compiled with method @see me _build_compile_run.
239 Every parameter with a default value is ignored.
240 Switch to ``runtime='python'`` to enable those.
241 """
242 try:
243 return self._run_compiled(inputs) # pylint: disable=E1101
244 except NameError as e:
245 raise RuntimeError( # pragma: no cover
246 "Unable to compute prediction due to %r. Code:\n%s"
247 "" % (e, print_code(
248 self._run_compiled_code))) from e # pylint: disable=E1101
250 def _guess_input_dtype(self):
251 for _, v in self.graph_['inputs'].items():
252 if 'type' not in v:
253 continue # pragma: no cover
254 t = v['type']
255 if 'elem' not in t:
256 continue
257 if t['elem'] == 'double':
258 return numpy.float64
259 return numpy.float32
261 def __str__(self):
262 """
263 usual
264 """
265 rows = ['OnnxInference(...)']
266 if hasattr(self, '_run_compiled_code'):
267 rows.append(
268 textwrap.indent(
269 self._run_compiled_code, ' ')) # pylint: disable=E1101
270 else:
271 rows.append(textwrap.indent(str(self.obj), ' '))
272 return "\n".join(rows)
    def __repr__(self):
        """
        usual, intentionally terse (the underlying model can be huge,
        use :meth:`__str__` for details)
        """
        return "OnnxInference(...)"  # pragma: no cover
    def check_model(self):
        """
        Checks the model follows :epkg:`ONNX` conventions.
        Thin wrapper around ``onnx.checker.check_model``, raises
        if the model is invalid.
        """
        checker.check_model(self.obj)
    def shape_inference(self):
        """
        Infers the shape of the outputs
        with :epkg:`onnx` package.

        @return A new :epkg:`ONNX` graph which defines the outputs.
        """
        return shape_inference.infer_shapes(self.obj)
295 @property
296 def input_names(self):
297 """
298 Returns the names of all inputs.
299 It does not include the optional inputs.
301 .. versionchanged:: 0.6
302 The list does not include optional inputs anymore.
303 """
304 inits = set(_.name for _ in self.obj.graph.initializer)
305 return [_.name for _ in self.obj.graph.input if _.name not in inits]
307 @property
308 def input_names_shapes(self):
309 """
310 Returns the names and shapes of all inputs.
311 This method assumes all inputs are tensors.
312 It does not include the optional inputs.
314 .. versionchanged:: 0.6
315 The list does not include optional inputs anymore.
316 """
317 names = set(self.input_names)
318 return [(_.name, _var_as_dict(_)['type']['shape'])
319 for _ in self.obj.graph.input if _.name in names]
321 @staticmethod
322 def _get_type_property(info, prop):
323 if prop in info:
324 return info[prop]
325 if 'kind' in info and info['kind'] == 'sequence':
326 if prop == 'shape':
327 return ('?', )
328 raise NotImplementedError(
329 "Unable to retrieve property %r from %r."
330 "" % (prop, info))
332 @property
333 def input_names_shapes_types(self):
334 """
335 Returns the names, shapes, types of all inputs.
336 This method assumes all inputs are tensors.
337 It does not include the optional inputs.
339 .. versionchanged:: 0.6
340 The list does not include optional inputs anymore.
341 """
342 f = OnnxInference._get_type_property
343 names = set(self.input_names)
344 return [(_.name, f(_var_as_dict(_)['type'], 'shape'),
345 'tensor(%s)' % f(_var_as_dict(_)['type'], 'elem'))
346 for _ in self.obj.graph.input if _.name in names]
348 @property
349 def output_names(self):
350 """
351 Returns the names of all outputs.
352 """
353 return [_.name for _ in self.obj.graph.output]
355 @property
356 def output_names_shapes(self):
357 """
358 Returns the names and shapes of all outputs.
359 This method assumes all inputs are tensors.
360 """
361 f = OnnxInference._get_type_property
362 return [(_.name, f(_var_as_dict(_)['type'], 'shape'))
363 for _ in self.obj.graph.output]
365 @property
366 def output_names_shapes_types(self):
367 """
368 Returns the names, shapes, types of all outputs.
369 This method assumes all inputs are tensors.
370 It does not include the optional outputs.
372 .. versionadd:: 0.7
373 """
374 names = set(self.output_names)
375 f = OnnxInference._get_type_property
376 return [(_.name, f(_var_as_dict(_)['type'], 'shape'),
377 'tensor(%s)' % f(_var_as_dict(_)['type'], 'elem'))
378 for _ in self.obj.graph.output if _.name in names]
380 def global_index(self, name):
381 """
382 Maps every name to one integer to avoid using dictionaries
383 when running the predictions.
385 @param name outputs name
386 @return integer
387 """
388 if not hasattr(self, '_global_index'):
389 self._global_index = {}
390 if name in self._global_index:
391 return self._global_index[name]
392 self._global_index[name] = len(self._global_index)
393 return self._global_index[name]
395 def to_sequence(self):
396 """
397 Produces a graph to facilitate the execution.
399 One example:
401 .. exref::
402 :title: Convert ONNX into graph
404 An example on how to convert an :epkg:`ONNX`
405 graph into a graph.
407 .. runpython::
408 :showcode:
409 :warningout: DeprecationWarning
411 import pprint
412 import numpy
413 from skl2onnx.algebra.onnx_ops import OnnxLinearRegressor
414 from skl2onnx.common.data_types import FloatTensorType
415 from mlprodict.onnxrt import OnnxInference
417 pars = dict(coefficients=numpy.array([1., 2.]),
418 intercepts=numpy.array([1.]),
419 post_transform='NONE')
420 onx = OnnxLinearRegressor('X', output_names=['Y'], **pars)
421 model_def = onx.to_onnx({'X': pars['coefficients'].astype(numpy.float32)},
422 outputs=[('Y', FloatTensorType([1]))],
423 target_opset=12)
424 oinf = OnnxInference(model_def)
425 pprint.pprint(oinf.to_sequence())
427 See an example of representation in notebook
428 :ref:`onnxvisualizationrst`.
429 """
430 inits = {}
431 variables = {}
432 outputs = {}
433 nodes = {}
434 statics = {}
435 targets = {}
436 for o in self.obj.opset_import:
437 targets[o.domain] = o.version
439 # static variables
440 if self.static_inputs is not None:
441 for n in self.static_inputs:
442 statics[n] = {'name': n}
443 self.global_index(n)
445 # inputs
446 for obj in self.obj.graph.input:
447 variables[obj.name] = _var_as_dict(obj)
448 self.global_index(obj.name)
450 # outputs
451 for obj in self.obj.graph.output:
452 if hasattr(obj, 'type') and str(obj.type) != '':
453 outputs[obj.name] = _var_as_dict(obj)
454 else:
455 outputs[obj.name] = {'name': obj.name}
456 self.global_index(obj.name)
458 # initializer
459 for obj in self.obj.graph.initializer:
460 init_obj = _var_as_dict(obj)
461 if init_obj is None:
462 raise RuntimeError( # pragma: no cover
463 "Unable to convert an initializer\n{}".format(obj))
464 inits[obj.name] = init_obj
465 self.global_index(obj.name)
466 if 'value' not in inits[obj.name]:
467 raise RuntimeError( # pragma: no cover
468 "One initializer has no value: '{}'\n{}\n{}".format(
469 obj.name, inits[obj.name], obj))
471 # nodes
472 for node in self.obj.graph.node:
473 dobj = _var_as_dict(node)
474 if dobj is None:
475 raise RuntimeError( # pragma: no cover
476 "Unable to convert a node\n{}".format(node))
477 if 'atts' in dobj:
478 atts = dobj['atts']
479 for k, v in atts.items():
480 if not isinstance(v, dict) or 'value' not in v:
481 raise RuntimeError( # pragma: no cover
482 "A parameter has no (sparse) value '{}' "
483 "for node '{}'\nv={}\ndobj=[{}]".format(
484 k, node.name, v, node))
485 if node.name in nodes: # pragma: no cover
486 i = 2
487 while True:
488 new_name = "%s_n%i" % (node.name, i)
489 if new_name not in nodes:
490 break
491 i += 1
492 else:
493 new_name = node.name
494 nodes[new_name] = OnnxInferenceNode(node, dobj, self.global_index)
496 # names
497 names = {}
498 for k, v in statics.items():
499 if (k, 0) in names:
500 raise RuntimeError( # pragma: no cover
501 "Static variables '{}' already exists (tag='{}').".format(
502 k, names[k, 0][0]))
503 names[k, 0] = ('S', v)
504 for k, v in inits.items():
505 if (k, 0) in names:
506 raise RuntimeError( # pragma: no cover
507 "Initializer '{}' already exists (tag='{}').".format(
508 k, names[k, 0][0]))
509 names[k, 0] = ('C', v)
510 for k, v in variables.items():
511 if (k, 0) in names:
512 if k in inits:
513 # Kind of default value for an input
514 continue
515 raise RuntimeError( # pragma: no cover
516 "Variable '{}' already exists (tag='{}').".format(
517 k, names[k, 0][0]))
518 names[k, 0] = ('I', v)
519 for k, v in outputs.items():
520 if (k, 0) in names and self.runtime != 'empty':
521 if not self.inside_loop or names[k, 0][0] != 'I':
522 raise RuntimeError( # pragma: no cover
523 "Output '{}' already exists (tag='{}').".format(
524 k, names[k, 0][0]))
525 else:
526 # For input, output sharing the same name, we marked the name
527 # as an input.
528 continue
529 names[k, 0] = ('O', v)
530 for k, v in nodes.items():
531 if (k, 1) in names:
532 raise RuntimeError( # pragma: no cover
533 "Node '{}' already exists (tag='{}'). "
534 "Use inside_loop=True to bypass this exception.".format(
535 k, names[k, 0][0]))
536 names[k, 1] = ('N', v)
538 # ordering
539 order = {}
540 modif = 1
541 intermediate = {}
542 while modif > 0:
543 modif = 0
544 for (k, _), v in names.items():
545 if (k, 1) in order:
546 # The operator node is already processed.
547 continue
548 if v[0] in {'I', 'C', 'S'}:
549 if (k, 0) not in order:
550 order[k, 0] = len(order) # A data node.
551 modif += 1
552 continue
553 if v[0] == 'O':
554 continue
555 if all((inp, 0) in order for inp in v[1].inputs):
556 # If all inputs are available,
557 # We tell the operator node is processed.
558 order[k, 1] = len(order)
559 modif += 1
560 for o in v[1].outputs:
561 if (o, 0) in order:
562 raise RuntimeError( # pragma: no cover
563 "Two nodes share the same output '{}' "
564 "or an operator and an output "
565 "share the same name. "
566 "(node: {}).".format(o, v[1]))
567 # We add a data node.
568 order[o, 0] = len(order)
569 intermediate[o] = None
570 modif += 1
572 # compute
573 rev = [(v, k[0], k[1]) for k, v in order.items()]
574 rev.sort()
575 sequence = []
576 for _, name, node_kind in rev:
577 if name not in nodes:
578 continue
579 if node_kind == 0:
580 # It is an output which shares the same name
581 # as a node.
582 continue
583 node = nodes[name]
584 node.set_order(len(sequence))
585 sequence.append(node)
587 if len(sequence) == 0:
588 raise RuntimeError( # pragma: no cover
589 "No runnable nodes was found in the ONNX graph"
590 "\n--rev--\n{}"
591 "\n--order--\n{}"
592 "\n--nodes--\n{}"
593 "\n---".format(
594 "\n".join([str(_) for _ in names.items()]),
595 "\n".join([str(_) for _ in order.items()]),
596 "\n".join([str(_) for _ in nodes.items()])))
598 # defines where an intermediare output is not needed
599 last_used = {}
600 for node in sequence:
601 for inp in node.inputs:
602 last_used[inp] = node.order
603 for k, ord in last_used.items():
604 sequence[ord].add_variable_to_clean(k)
606 results = dict(inits=inits, inputs=variables, outputs=outputs,
607 nodes=nodes, sequence=sequence,
608 intermediate=intermediate,
609 targets=targets, ir_version=self.obj.ir_version,
610 statics=statics)
611 if len(sequence) < len(nodes):
612 # Not all node will be executed.
613 raise RuntimeError(
614 "Unable to run all nodes.\n--Nodes--\n%s\n--Sequence--\n%s"
615 "\n--Inputs--\n%s\n--Inits--\n%s\n--Statics\n%s"
616 "" % (pprint.pformat(nodes), pprint.pformat(sequence),
617 pprint.pformat(list(variables)),
618 pprint.pformat(list(inits)),
619 pprint.pformat(list(statics))))
620 return results
622 def run(self, inputs, clean_right_away=False,
623 intermediate=False, verbose=0, node_time=False,
624 overwrite_types=None, fLOG=None):
625 """
626 Computes the predictions for this :epkg:`onnx` graph.
628 :param inputs: inputs as dictionary or a dataframe
629 :param clean_right_away: clean the intermediate outputs
630 as soon as they are not needed
631 :param intermediate: returns a dictionary of intermediate
632 variables instead of the results only
633 :param verbose: display information while predicting
634 :param node_time: measure time of each node
635 :param overwrite_types: shape inference does not work all the time,
636 this allows to force types when building intermediate
637 results, see @see fn select_model_inputs_outputs
638 :param fLOG: logging function if *verbose > 0*
639 :return: outputs as dictionary
640 and a second dictionary of the time spent
641 in each node if *node_time* is True
643 .. exref::
644 :title: Computes predictions with any runtime
646 The following example compares predictions
647 between :epkg:`scikit-learn` and this runtime
648 for the python runtime.
650 .. runpython::
651 :showcode:
652 :warningout: DeprecationWarning
654 import numpy
655 from sklearn.linear_model import LinearRegression
656 from sklearn.datasets import load_iris
657 from sklearn.model_selection import train_test_split
658 from mlprodict.onnxrt import OnnxInference
659 from mlprodict.onnx_conv import to_onnx
661 iris = load_iris()
662 X, y = iris.data, iris.target
663 X_train, X_test, y_train, _ = train_test_split(X, y)
664 clr = LinearRegression()
665 clr.fit(X_train, y_train)
667 exp = clr.predict(X_test[:5])
668 print(exp)
670 model_def = to_onnx(clr, X_train.astype(numpy.float32),
671 target_opset=12)
672 oinf = OnnxInference(model_def)
673 y = oinf.run({'X': X_test[:5]})
674 print(y)
676 The function returns all intermediate outputs
677 if *intermediate* is True. In case of runtime
678 *onnxruntime1*, if intermediate is True,
679 the first class builds all :epkg:`ONNX` cut out
680 to keep the one output and converted into
681 *OnnxInference*.
682 """
683 def retype(col_array):
684 if (hasattr(col_array, 'categories') and
685 hasattr(col_array, 'from_codes')):
686 # isinstance(col_array, pandas.Categorical):
687 return col_array.astype(numpy.int64)
688 return col_array
690 if hasattr(inputs, 'columns') and hasattr(inputs, 'iloc'):
691 # == isinstance(inputs, pandas.DataFrame)
692 inputs = OrderedDict((
693 name, retype(numpy.expand_dims(inputs[name].values, axis=1)))
694 for name in inputs.columns)
695 if intermediate:
696 if self.inplace:
697 raise RuntimeError( # pragma: no cover
698 "inplace must be False if intermediate is True, a container "
699 "might be used by several nodes.")
700 return self._run(inputs, clean_right_away=False,
701 intermediate=intermediate,
702 verbose=verbose, node_time=node_time,
703 overwrite_types=overwrite_types,
704 fLOG=fLOG)
705 if overwrite_types is not None:
706 raise RuntimeError( # pragma: no cover
707 "overwrite_types is not used if intermediate is False.")
708 return self._run(inputs, clean_right_away=False,
709 intermediate=intermediate,
710 verbose=verbose, node_time=node_time,
711 fLOG=fLOG)
713 def run2onnx(self, inputs, verbose=0, fLOG=None,
714 as_parameter=True, suffix='_DBG',
715 param_name=None, node_type='DEBUG',
716 domain='DEBUG', domain_opset=1):
717 """
718 Executes the graphs with the given inputs, then adds the intermediate
719 results into ONNX nodes in the original graph. Once saved, it can be
720 looked with a tool such as :epkg:`netron`.
722 :param inputs: inputs as dictionary or a dataframe
723 :param verbose: display information while predicting
724 :param fLOG: logging function if *verbose > 0*
725 :param as_parameter: add new nodes with results as one parameter
726 (True) or as initializer (False)
727 :param suffix: suffix to add to new results
728 :param param_name: name of the parameter to add
729 (by default the result name), it can be a function
730 `param_name(reult_name) -> parameter_name`
731 :param node_type: type of the new node
732 :param domain: domain the new node
733 :param domain_opset: opset for *domain*
734 :return: outputs as dictionary
735 and the onnx graph with new nodes
737 The following example shows how to use it.
739 .. gdot::
740 :script: DOT-SECTION
742 from sklearn.linear_model import LinearRegression
743 from sklearn.datasets import load_iris
744 from mlprodict.onnxrt import OnnxInference
745 import numpy
747 iris = load_iris()
748 X = iris.data[:, :2]
749 y = iris.target
750 lr = LinearRegression()
751 lr.fit(X, y)
753 from mlprodict.onnx_conv import to_onnx
754 model_onnx = to_onnx(lr, X.astype(numpy.float32))
755 oinf = OnnxInference(model_onnx, inplace=False)
757 model_onnx_debug = oinf.run2onnx({'X': X[:3].astype(numpy.float32)})
758 oinf_debug = OnnxInference(model_onnx_debug[1])
760 print("DOT-SECTION", oinf_debug.to_dot())
762 .. versionadded:: 0.7
763 """
764 intermediate = self.run(inputs, verbose=verbose, fLOG=fLOG,
765 intermediate=True)
766 for name in self.input_names:
767 del intermediate[name]
768 new_onx = insert_results_into_onnx(
769 self.obj, intermediate, as_parameter=as_parameter,
770 suffix=suffix, param_name=param_name, node_type=node_type,
771 domain=domain, domain_opset=domain_opset)
772 return intermediate, new_onx
774 def display_sequence(self, verbose=1):
775 """
776 Shows the sequence of nodes to run if ``runtime=='python'``.
777 """
778 rows = []
779 rows.append("#node: {}".format(len(self.sequence_)))
780 for i, node in enumerate(self.sequence_):
781 if verbose >= 1:
782 rows.append("{}: {}".format(i, str(node)))
783 return "\n".join(rows)
    def _run_sequence_runtime(self, inputs, clean_right_away=False,
                              intermediate=False, verbose=0, node_time=False,
                              overwrite_types=None, fLOG=None):
        """
        Executes the nodes one by one with the python runtime.
        See :meth:`run` for the meaning of the parameters;
        *overwrite_types* and *clean_right_away* are not supported here.
        """
        if overwrite_types is not None:
            raise NotImplementedError(  # pragma: no cover
                "overwrite_types != None not implemented.")
        if clean_right_away:
            raise NotImplementedError(  # pragma: no cover
                "clean_right_away=true not implemented.")

        if node_time:
            mtime = []
        if verbose >= 1 and fLOG is not None:
            printed = set()

        # *values* is a flat list indexed by global_index, faster than a
        # dictionary during execution.
        if hasattr(self, "_values_init"):
            values = self._values_init.copy()  # pylint: disable=E0203
        else:
            values = [None] * len(self._global_index)
            if verbose >= 1 and fLOG is not None:
                for k, v in self.inits_.items():
                    values[self._global_index[k]] = v['value']
                    if verbose < 3:
                        fLOG("+ki='{}': {} (dtype={} min={} max={})".format(
                            k, v['value'].shape, v['value'].dtype,
                            numpy_min(v['value']), numpy_max(v['value'])))
                    else:
                        fLOG("+ki='{}': {} (dtype={} min={} max={}\n{}".format(
                            k, v['value'].shape, v['value'].dtype,
                            numpy_min(v['value']), numpy_max(v['value']),
                            v['value']))
                    printed.add(k)
            else:
                for k, v in self.inits_.items():
                    values[self._global_index[k]] = v['value']
                # stores the array to skip initialing a second time
                if verbose == 0 or fLOG is None:
                    self._values_init = values.copy()

        for name, value in inputs.items():
            values[self._global_index[name]] = value

        if verbose == 0 or fLOG is None:
            # fast path: no logging
            if node_time:
                for i, node in enumerate(self.sequence_):
                    t = perf_counter()
                    node.run(values)
                    t2 = perf_counter()
                    mtime.append(dict(i=i, name=node.onnx_node.name,
                                      op_type=node.onnx_node.op_type,
                                      time=t2 - t))
            else:
                for node in self.sequence_:
                    node.run(values)
        else:
            # verbose path: logs every new result after each node
            def dispsimple(arr):
                if hasattr(arr, 'shape'):
                    if len(arr.shape) <= 1:
                        threshold = 8
                    else:
                        threshold = min(
                            50, min(50 // max(arr.shape[1], 1), 8) * arr.shape[1])
                    if hasattr(arr, 'todense'):
                        fLOG(  # pragma: no cover
                            numpy.array2string(arr.todense(), max_line_width=120,
                                               suppress_small=True, threshold=threshold))
                    else:
                        fLOG(numpy.array2string(arr, max_line_width=120,
                                                suppress_small=True,
                                                threshold=threshold))
                else:  # pragma: no cover
                    s = str(arr)
                    if len(s) > 50:
                        s = s[:50] + "..."
                    fLOG(s)

            if verbose >= 2:
                for k in sorted(self._global_index):
                    if values[self._global_index[k]] is None:
                        continue
                    obj = values[self._global_index[k]]
                    if k not in printed:
                        printed.add(k)
                        if hasattr(obj, 'shape'):
                            fLOG("-kv='{}' shape={} dtype={} min={} max={}{}".format(
                                k, obj.shape, obj.dtype, numpy_min(obj),
                                numpy_max(obj),
                                ' (sparse)' if isinstance(obj, coo_matrix) else ''))
                        elif (isinstance(obj, list) and len(obj) > 0 and
                                not isinstance(obj[0], dict)):  # pragma: no cover
                            fLOG("-kv='{}' list len={}".format(k, len(obj)))
                            if verbose >= 3 and len(obj) > 0:
                                fLOG("first={} last={}".format(
                                    obj[0], obj[-1]))
                        else:  # pragma: no cover
                            fLOG("-kv='{}' type={}".format(k, type(obj)))

            keys = set(k for k in range(len(values)) if values[k] is not None)
            if verbose >= 1:
                fLOG("-- OnnxInference: run {} nodes".format(len(self.sequence_)))
            for i, node in enumerate(self.sequence_):
                if verbose >= 1:
                    fLOG(node)
                if node_time:
                    t = perf_counter()
                    node.run(values)
                    t2 = perf_counter()
                    mtime.append(dict(i=i, name=node.onnx_node.name,
                                      op_type=node.onnx_node.op_type,
                                      time=t2 - t))
                else:
                    node.run(values)
                added = 0
                for k in range(len(values)):  # pylint: disable=C0200
                    if values[k] is None:
                        continue
                    if k not in keys and k not in printed:
                        added += 1
                        printed.add(k)
                        name = list(
                            name for name in self._global_index  # pylint: disable=C0206
                            if self._global_index[name] == k)
                        if isinstance(values[k], (numpy.ndarray, coo_matrix)):
                            name = name[0]
                            mini = numpy_min(values[k])
                            maxi = numpy_max(values[k])
                            fLOG("+kr{}'{}': {} (dtype={} min={} max={}{})".format(
                                "=" if len(values[k].shape) == 0 or min(
                                    values[k].shape) > 0 else "*",
                                name, values[k].shape, values[k].dtype,
                                mini, maxi,
                                ' sparse' if isinstance(values[k], coo_matrix) else ''))
                            if verbose >= 3:
                                dispsimple(values[k])
                        else:
                            fLOG("+kr='{}': {}".format(
                                name, type(values[k])))
                            if verbose >= 3:  # pragma: no cover
                                dispsimple(values[k])
                if added == 0:
                    fLOG("? no new result")

        if intermediate:
            # returns every computed value, sorted by global index
            values = [(v, k, values[v]) for k, v in self._global_index.items()]
            values.sort()
            values = OrderedDict((k, v) for _, k, v in values)
            return (values, mtime) if node_time else values

        try:
            res = {k: values[self._global_index[k]] for k in self.outputs_}
        except KeyError as e:  # pragma: no cover
            raise RuntimeError("Unable to find one output [{}]\n in [{}]"
                               ".".format(", ".join(sorted(self.outputs_)),
                                          ", ".join(sorted(values)))) from e
        return (res, mtime) if node_time else res
941 def build_intermediate(self, outputs=None, verbose=0, overwrite_types=None,
942 fLOG=None):
943 """
944 Builds every possible :epkg:`ONNX` file
945 which computes one specific intermediate output
946 from the inputs.
948 :param outputs: subsets of outputs to get,
949 None to get all outputs,
950 :param overwrite_types: shape inference does not work all the time,
951 this allows to force types when building intermediate
952 results, see @see fn select_model_inputs_outputs
953 :param verbose: displays intermediate information
954 :param fLOG: logging function
955 :return: :epkg:`*py:collections:OrderedDict`
957 .. versionchanged: 0.6
958 """
959 if verbose > 0:
960 fLOG('[build_intermediate] BEGIN.')
961 if outputs is not None:
962 if isinstance(outputs, str):
963 outputs = [outputs]
964 if not isinstance(outputs, set):
965 outputs = set(outputs)
966 ord = OrderedDict()
967 for output in enumerate_model_node_outputs(self.obj, order=True):
968 if outputs is not None and output not in outputs:
969 continue
970 subonx = select_model_inputs_outputs(
971 self.obj, outputs=output, infer_shapes=True,
972 overwrite=overwrite_types)
973 subonx = onnx_remove_node_unused(subonx)
974 if verbose > 0:
975 fLOG('[build_intermediate] + {}'.format(output))
976 ord[output] = OnnxInference(subonx, runtime=self.runtime,
977 skip_run=self.skip_run,
978 runtime_options=self.runtime_options,
979 inplace=self.inplace,
980 input_inplace=self.input_inplace)
981 if verbose > 0:
982 fLOG('[build_intermediate] END.')
983 return ord
    def _run_whole_runtime(self, inputs, clean_right_away=False,
                           intermediate=False, verbose=0, node_time=False,
                           overwrite_types=None, fLOG=None):
        """
        Runs the inference with a backend which executes the whole
        graph in one call (stored in ``self._whole``).
        When *intermediate* is True, one sub-model per intermediate
        output is run instead so that every intermediate result
        can be returned.

        :param inputs: dictionary of inputs
        :param clean_right_away: unsupported here, raises if True
        :param intermediate: returns all intermediate results if True
        :param verbose: verbosity, only effective when *intermediate* is True
        :param node_time: unused by this runtime
        :param overwrite_types: forced types for intermediate results,
            see @see me build_intermediate
        :param fLOG: logging function
        :return: dictionary of outputs, or of every intermediate result
            when *intermediate* is True
        """
        # node_time is unused
        if clean_right_away:
            raise RuntimeError(  # pragma: no cover
                "clean_right_away=true does not work with this runtime.")
        if intermediate:
            # The intermediate sub-models are built once and cached
            # on the instance for later calls.
            if hasattr(self, "intermediate_onnx_inference_"):
                inter_run = self.intermediate_onnx_inference_  # pylint: disable=E0203
            else:
                if verbose > 0:
                    fLOG("-- OnnxInference: build intermediate")
                inter_run = self.build_intermediate(
                    verbose=verbose, fLOG=fLOG, overwrite_types=overwrite_types)
                self.intermediate_onnx_inference_ = inter_run
                graph = self.to_sequence()
                self.inits_ = graph['inits']
            if verbose >= 1:
                fLOG("-- OnnxInference: run {} nodes".format(
                    len(self.intermediate_onnx_inference_)))
            # start from the inputs and the initializers
            values = OrderedDict(inputs)
            for k, v in self.inits_.items():
                values[k] = v['value']
            if verbose >= 2:  # pragma: no cover
                for k in sorted(values):
                    fLOG("-k='{}' shape={} dtype={}".format(
                        k, values[k].shape, values[k].dtype))
            # every sub-model computes exactly one intermediate result
            for node, oinf in self.intermediate_onnx_inference_.items():
                if verbose >= 4:
                    fLOG('[intermediate] %r' % node)
                if verbose >= 5:  # pragma: no cover
                    fLOG(oinf.obj)
                output = oinf.run(inputs)[node]
                values[node] = output
                if verbose >= 1:
                    if verbose >= 4:
                        for k, v in inputs.items():
                            if isinstance(output, numpy.ndarray):
                                fLOG("-i='{}': {} (dtype={}) {}".format(
                                    k, v.shape, v.dtype, v.ravel().tolist()))
                            else:
                                fLOG("-i='{}': {} (dtype={}) - ?".format(
                                    k, v.shape, v.dtype))
                    if isinstance(output, numpy.ndarray):
                        fLOG("+k='{}': {} (dtype={})".format(
                            node, output.shape, output.dtype))
                        if verbose >= 2:
                            fLOG(output)
                    else:
                        fLOG("+k='{}': {}".format(  # pragma: no cover
                            node, type(output)))
                        if verbose >= 2:
                            fLOG(output)
            return values

        if verbose != 0:
            warnings.warn(
                "verbose option not implemented if runtime is 'onnxruntime1'")
        # single call to the whole-graph backend
        res = self._whole.run(inputs)
        return {k: v for k, v in zip(self.outputs_, res)}
1048 def __getitem__(self, item):
1049 """
1050 Returns the ONNX verions of a node.
1051 """
1052 if isinstance(item, tuple):
1053 node_name, att_name = item
1054 else:
1055 node_name = item
1056 att_name = None
1058 node_ = None
1059 for node in self.obj.graph.node:
1060 if node.name == node_name:
1061 node_ = node
1062 break
1064 if node_ is None:
1065 raise IndexError( # pragma: no cover
1066 "Unable to get node name '{}'.\n{}".format(
1067 node_name, "\n".join(node.name for node in self.obj.graph.node)))
1069 if att_name is None:
1070 return node_
1072 for att in node_.attribute:
1073 if att.name == att_name:
1074 return att
1076 raise IndexError( # pragma: no cover
1077 "Unable to find attribute '{}' from node "
1078 "'{}'.".format(att_name, node_name))
    def switch_initializers_dtype(self, model=None,
                                  dtype_in=numpy.float32,
                                  dtype_out=numpy.float64):
        """
        Switches all initializers to ``numpy.float64``. If *model*
        is None, a simple cast is done. Otherwise, the function assumes
        the model is a :epkg:`scikit-learn` pipeline.
        This only works if the runtime is ``'python'``.

        @param      model       :epkg:`scikit-learn` model or None
        @param      dtype_in    previous type
        @param      dtype_out   next type
        @return     done operations, list of tuples describing every cast
        """
        from ..onnx_tools.optim.sklearn_helper import enumerate_fitted_arrays, pairwise_array_distances
        if self.runtime != 'python':  # pragma: no cover
            raise RuntimeError("Initializers can be casted only if the "
                               "runtime is 'python' not '{}'.".format(self.runtime))
        # drop the cached initializer values, they hold the old dtype
        if hasattr(self, '_values_init'):
            del self._values_init
        # first pass: simple cast
        done = []
        initializer = self.inits_
        for k, v in initializer.items():
            if isinstance(v['value'], numpy.ndarray):
                if v['value'].dtype == dtype_in:
                    v['value'] = v['value'].astype(dtype_out)
                    done.append(("pass1", "+", "init", k, v['value']))
                else:
                    done.append(("pass1", "-", "init", k,
                                 v['value']))  # pragma: no cover
        # propagate the cast recursively to nodes and subgraphs
        for k, v in self.graph_['nodes'].items():
            res = v.switch_initializers_dtype(dtype_in=dtype_in,
                                              dtype_out=dtype_out)
            for r in res:
                done.append(("pass1", "node", k) + r)
        for k, v in self.graph_['intermediate'].items():
            if v is None:
                continue
            res = v.switch_initializers_dtype(dtype_in=dtype_in,
                                              dtype_out=dtype_out)
            for r in res:
                done.append(("pass1", "sub", k) + r)
        if model is not None:
            # Second pass, we compare all arrays from the model
            # to the arrays in the converted models.
            def dist(a):
                # distance between an array and its round-trip cast
                cast = a.astype(dtype_in).astype(dtype_out)
                d = pairwise_array_distances([cast], [a])[0, 0]
                return d
            done_ = [(c, c[-1]) for c in done]
            moda_ = [(a, a[-2][-1]) for a in enumerate_fitted_arrays(model)
                     if dist(a[-2][-1]) > 0]
            aconv = [_[-1] for _ in done_]
            amoda = [_[-1] for _ in moda_]
            distances = pairwise_array_distances(aconv, amoda)
            for i in range(distances.shape[0]):
                j = numpy.argmin(distances[i])
                d = distances[i, j]
                if d < 0.1:
                    # close enough: overwrite the converted array in place
                    # with the fitted array from the model
                    numpy.copyto(aconv[i], amoda[j])
                    done.append(("pass2", d) + done_[i][0])
        return done
1151 def _set_shape_inference_runtime(self):
1152 """
1153 Set shapes based on shape inference
1154 relying on the runtime.
1155 The values are stored in every node.
1156 """
1157 if not hasattr(self, 'sequence_') or not hasattr(self, 'inputs_'):
1158 raise RuntimeError( # pragma: no cover
1159 "This method only works if the runtime is 'python' not "
1160 "'{}'.".format(self.runtime))
1161 values = OrderedDict()
1162 for k, v in self.inputs_.items():
1163 # The function assumes the first dimension is unknown
1164 # and is the batch size.
1165 try:
1166 values[k] = ShapeObject(v, use_n1=True, name=k)
1167 except TypeError as e:
1168 raise TypeError(
1169 "Unable to guess shape for %r (shape=%r)." % (k, v)) from e
1171 impossible = False
1172 for k, v in self.statics_.items():
1173 # static inputs should be known.
1174 try:
1175 values[k] = ShapeObject(v)
1176 except TypeError:
1177 # default value is wrong
1178 impossible = True
1179 values[k] = None
1181 for k, v in self.inits_.items():
1182 values[k] = ShapeObject(v['value'], name=k)
1183 last = None
1184 for i, node in enumerate(self.sequence_):
1185 try:
1186 s = node._set_shape_inference_runtime(values)
1187 last = s
1188 except (IndexError, TypeError, KeyError,
1189 AttributeError) as e: # pragma: no cover
1190 rows = []
1191 if last is not None:
1192 for k, v in last.items():
1193 rows.append("{}: {}".format(k, v))
1194 for k in range(i + 1):
1195 rows.append("{} --> {}".format(k, self.sequence_[k]))
1196 if not impossible:
1197 raise RuntimeError("Unable to infer shape of node {}\n{}".format(
1198 i, '\n'.join(rows))) from e
1199 return values
1201 def infer_shapes(self):
1202 """
1203 Computes expected shapes.
1205 :return: dictionary of shapes
1206 """
1207 return self._set_shape_inference_runtime()
1209 def _set_type_inference_runtime(self):
1210 """
1211 Set types based on type inference
1212 relying on the runtime.
1213 The values are stored in every node.
1214 """
1215 if not hasattr(self, 'sequence_') or not hasattr(self, 'inputs_'):
1216 raise RuntimeError( # pragma: no cover
1217 "This method only works if the runtime is 'python' not "
1218 "'{}'.".format(self.runtime))
1219 values = OrderedDict()
1220 for k, v in self.statics_.items():
1221 values[k] = None
1222 for k, v in self.inputs_.items():
1223 # The function assumes the first dimension is unknown
1224 # and is the batch size.
1225 if isinstance(v['type']['elem'], dict):
1226 # sequence
1227 values[k] = SequenceType()
1228 else:
1229 values[k] = guess_numpy_type_from_string(v['type']['elem'])
1230 for k, v in self.inits_.items():
1231 values[k] = v['value'].dtype
1232 last = None
1233 for i, node in enumerate(self.sequence_):
1234 try:
1235 s = node._set_type_inference_runtime(values)
1236 last = s
1237 except IndexError as e: # pragma: no cover
1238 rows = []
1239 if last is not None:
1240 for k, v in last.items():
1241 rows.append("{}: {}".format(k, v))
1242 for k in range(i + 1):
1243 rows.append("{} --> {}".format(k, self.sequence_[k]))
1244 raise RuntimeError("Unable to infer type of node {}\n{}".format(
1245 i, '\n'.join(rows))) from e
1246 return values
1248 def infer_types(self):
1249 """
1250 Computes expected shapes.
1252 :return: dictionary of types
1253 """
1254 return self._set_type_inference_runtime()
1256 def _set_size_inference_runtime(self, inputs, context=None):
1257 """
1258 Set sizes allocated during inference
1259 relying on the runtime.
1260 The values are stored in every node.
1261 """
1262 if not hasattr(self, 'sequence_') or not hasattr(self, 'inputs_'):
1263 raise RuntimeError( # pragma: no cover
1264 "This method only works if the runtime is 'python' not "
1265 "'{}'.".format(self.runtime))
1266 values = OrderedDict()
1267 for k, v in self.statics_.items():
1268 if context is None:
1269 raise RuntimeError( # pragma: no cover
1270 "static variable but context is None.")
1271 values[k] = context[k]
1272 for k, v in self.inits_.items():
1273 values[k] = v['value']
1274 for k, v in self.inputs_.items():
1275 if k in inputs:
1276 values[k] = inputs[k]
1278 last = None
1279 for i, node in enumerate(self.sequence_):
1280 try:
1281 s = node._set_size_inference_runtime(values)
1282 last = s
1283 except IndexError as e: # pragma: no cover
1284 rows = []
1285 if last is not None:
1286 for k, v in last.items():
1287 rows.append("{}: {}".format(k, v))
1288 for k in range(i + 1):
1289 rows.append("{} --> {}".format(k, self.sequence_[k]))
1290 raise RuntimeError("Unable to infer size of node {}\n{}".format(
1291 i, '\n'.join(rows))) from e
1292 return values
1294 def infer_sizes(self, inputs, context=None):
1295 """
1296 Computes expected sizes.
1298 :param inputs: inputs as a dictionary
1299 :return: dictionary of dictionary of sizes
1300 """
1301 res = self._set_size_inference_runtime(inputs, context=context)
1302 return {k: v for k, v in res.items() if k.startswith('#')}
1304 def _guess_inplace(self, input_inplace=False):
1305 """
1306 Looks into every node of the graph to see
1307 if there is a way to do the computation
1308 inplace. By default (*input_inplace=False*),
1309 the function assumes inputs cannot be modified
1310 so the first node cannot do inplace computation.
1311 This function only works with the python runtime.
1313 @param input_inplace the computation is allowed
1314 to overwrite the input
1316 This function checks that one node is used only
1317 once and then can be modified by the next node.
1318 Nodes `A`, `C` can be overwritten by the computation.
1319 Node `B` cannot as it is used by two nodes.
1321 .. blockdiag::
1323 diagram {
1324 A -> B -> C -> E;
1325 B -> D;
1326 }
1328 It does not handle specific case such node `B` being
1329 overwritten by node `C` but without changing its shape
1330 and node `D` only needs the shape of `B`. Then `B` could
1331 be overwritten as well.
1332 """
1333 forbid = {}
1334 values = OrderedDict()
1335 for k in self.statics_:
1336 values[k] = dict(inplace=False, to=[], fr=[])
1337 for k in self.inputs_:
1338 values[k] = dict(inplace=input_inplace, to=[], fr=[])
1339 for k in self.inits_:
1340 values[k] = dict(inplace=False, to=[], fr=[])
1341 for node in self.sequence_:
1342 for n in node.inputs:
1343 values[n]['to'].append(node)
1344 for n in node.outputs:
1345 if node.op_type == 'Constant':
1346 # We cannot modify constant.
1347 forbid[n] = node
1348 if n not in values:
1349 values[n] = dict(inplace=None, to=[], fr=[])
1350 values[n]['fr'].append(node)
1352 # checks the number of outputs
1353 outputs = set(self.output_names)
1354 modif = 1
1355 while modif > 0:
1356 modif = 0
1357 for n, v in values.items():
1358 if v['inplace'] is not None:
1359 continue
1360 if n in forbid:
1361 continue
1362 if len(v['to']) == 1:
1363 v['inplace'] = True
1364 modif += 1
1366 # convey the information to every node
1367 inplaces = {}
1368 for n, v in values.items():
1369 if v['inplace']:
1370 inplaces[n] = v
1371 for node in v['to']:
1372 if n in outputs:
1373 continue
1374 node.enable_inplace_compute(n)
1376 return inplaces
    def _build_compile_run(self, debug=False):
        """
        Rewrite the run function in python,
        compiles it, and adds it as a method.

        @param      debug       insert debugging code
        @return     method name, callable object, generated source code

        .. exref::
            :title: Run a model with runtime 'python_compiled'

            The following code trains a model and compute
            the predictions with runtime ``'python_compiled'``.
            It converts the onnx graph into a python function
            which calls every operator. Its code is printed
            below.

            .. runpython::
                :showcode:
                :warningout: DeprecationWarning

                import numpy
                from sklearn.datasets import load_iris
                from sklearn.model_selection import train_test_split
                from sklearn.ensemble import AdaBoostClassifier
                from sklearn.tree import DecisionTreeClassifier
                from skl2onnx import to_onnx
                from mlprodict.onnxrt import OnnxInference

                iris = load_iris()
                X, y = iris.data, iris.target
                X_train, X_test, y_train, __ = train_test_split(X, y, random_state=11)
                y_train = y_train.astype(numpy.float32)
                clr = AdaBoostClassifier(
                    base_estimator=DecisionTreeClassifier(max_depth=3),
                    n_estimators=3)
                clr.fit(X_train, y_train)

                model_def = to_onnx(clr, X_train.astype(numpy.float32),
                                    target_opset=12)

                oinf2 = OnnxInference(model_def, runtime='python_compiled')
                print(oinf2.run({'X': X_test[:5]}))

                # prints out the python function equivalent
                # to the onnx graph
                print(oinf2)
        """
        def clean_name(name):
            # produces a valid python identifier from a result name
            return name.replace(":", "_").replace('.', '_').replace('/', '_')

        # inits
        inputs = self.input_names
        code = ['def compiled_run(dict_inputs):']
        if debug:
            code.append("    printed = {}")

        context = {}

        # static variables
        for k in sorted(self.statics_):
            code.append("    # static: {0}".format(k))
            code.append("    {0} = dict_inputs['{1}']".format(
                clean_name(k), k))
            if debug:
                code.append(
                    "    debug_print('i.{0}', {1}, printed)".format(
                        clean_name(k), k))

        # initializers
        for k, v in sorted(self.inits_.items()):
            if k.startswith("_OPT_"):
                raise RuntimeError(  # pragma: no cover
                    "The runtime cannot handle any constant name "
                    "starting with '_OPT_': '{}'.".format(k))
            if k in inputs:
                # an initializer sharing a name with an input becomes
                # the default value of an optional input
                context["_OPT_" + clean_name(k)] = v['value']
                code.append("    # init: _OPT_{0} ({1})".format(
                    clean_name(k), k))
                if debug:
                    code.append(
                        "    debug_print('c.[_OPT_{0}]', _OPT_{1}, printed)".format(
                            clean_name(k), k))
            else:
                context[clean_name(k)] = v['value']
                code.append("    # init: {0} ({1})".format(
                    clean_name(k), k))
                if debug:
                    code.append(
                        "    debug_print('c.[{0}]', {1}, printed)".format(
                            clean_name(k), k))

        # method signature
        code.append("    # inputs")
        for inp in inputs:
            if '_OPT_' + inp in context:
                # optional inputs
                code.append(
                    "    {0} = dict_inputs.get('{1}', _OPT_{0})".format(
                        clean_name(inp), inp))
            else:
                code.append("    {0} = dict_inputs['{1}']".format(
                    clean_name(inp), inp))
            if debug:
                code.append(
                    "    debug_print('i.{0}', {1}, printed)".format(
                        clean_name(inp), inp))

        # code
        for i, node in enumerate(self.sequence_):
            name = "n{}_{}".format(i, node.ops_.__class__.__name__.lower())
            context[name] = node.ops_._run
            if (node.ops_.__class__.__name__ == 'Loop' and
                    node.ops_.need_context()):
                # Adding context.
                ctx = "{%s}" % ", ".join(
                    "'%s': %s" % (n, n) for n in node.ops_.additional_inputs)
                code.append('    ({1}, ) = {2}({0}, context={3})'.format(
                    ', '.join(map(clean_name, node.inputs)),
                    ', '.join(map(clean_name, node.outputs)),
                    name, ctx))
            else:
                code.append('    ({1}, ) = {2}({0})'.format(
                    ', '.join(map(clean_name, node.inputs)),
                    ', '.join(map(clean_name, node.outputs)),
                    name))
            if debug:
                code.append("    print('''# {}''')".format(code[-1][4:]))
                for o in node.outputs:
                    code.append(
                        "    debug_print('o.{0}', {1}, printed)".format(
                            clean_name(o), o))

        # return
        code.append('    return {')
        for out in self.output_names:
            code.append("        '{1}': {0},".format(
                clean_name(out), out))
        code.append('    }')
        final_code = '\n'.join(code)

        # compile the outcome
        context['self'] = self
        try:
            obj = compile(final_code, "<string>", 'exec')
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError(
                "Unable to compile\n#####\n{}".format(final_code)) from e
        # the first non-trivial constant of the compiled module
        # is the code object of compiled_run itself
        fcts_obj = [_ for _ in obj.co_consts
                    if _ is not None and not isinstance(_, (bool, str, int))]
        fct = make_callable(
            "compiled_run", fcts_obj[0], final_code, context, debug)

        # end
        return "compiled_run", fct, final_code
1533 def reduce_size(self, pickable=False):
1534 """
1535 Reduces the memory footprint as much as possible.
1537 @param pickable keeps a pickle object?
1538 """
1539 import gc
1540 del self.graph_
1541 if not pickable:
1542 del self.obj
1543 if self.runtime in ('python_compiled', 'python_compiled_debug'):
1544 del self.sequence_
1545 gc.collect()
1547 def get_profiling(self, as_df=False):
1548 """
1549 Returns the profiling after a couple of execution.
1551 :param as_df: return the results as a dataframe (True)
1552 :return: dataframe or list of dictionaries
1554 .. versionadded:: 0.6
1555 """
1556 if (self.runtime_options is None or
1557 not self.runtime_options.get('enable_profiling', False)):
1558 raise RuntimeError(
1559 "Profiling is available if options 'enable_profiling' "
1560 "is set to true in 'runtime_options' but is %r." % self.runtime_options)
1561 prof = None
1562 if hasattr(self, '_whole'):
1563 prof = self._whole.get_profiling()
1564 if prof is None:
1565 raise NotImplementedError( # pragma: no cover
1566 "profiling is only implemented for runtime 'onnxruntime1'.")
1567 if as_df:
1568 import pandas
1569 return pandas.DataFrame(prof)
1570 return prof
1572 def get_execution_order(self):
1573 """
1574 This function returns a dictionary `{(kind, name): (order, op)}`,
1575 *name* can be a node name or a result name. In that case,
1576 it gets the execution order than the node which created it.
1577 The function returns None if the order is not available
1578 (the selected runtime does not return it). *kind* is either
1579 `'node'` or `'node'`. If two nodes have the same name,
1580 returned order is the last one. Initializers gets an execution
1581 order equal to -1, inputs to 0, all others results are >= 1.
1583 .. versionadded:: 0.7
1584 """
1585 if not hasattr(self, "sequence_"):
1586 return None
1588 res = {}
1589 for k, v in self.inits_.items():
1590 res['res', k] = (-1, v)
1591 for name, shape in self.input_names_shapes:
1592 res['res', name] = (0, shape)
1594 for i, node in enumerate(self.sequence_):
1595 key = ('node', node.onnx_node.name)
1596 res[key] = (i + 1, node)
1597 for out in node.onnx_node.output:
1598 key = ('res', out)
1599 if key in res:
1600 raise RuntimeError(
1601 "Output %r of node name %r already registered."
1602 "" % (out, node.onnx_node.name))
1603 res[key] = (i + 1, None)
1605 return res