Hot-keys on this page:
- r, m, x, p: toggle line displays
- j, k: next/previous highlighted chunk
- 0 (zero): top of page
- 1 (one): first highlighted chunk
"""
@file
@brief Inspired by skl2onnx, handles two backends.
"""
5from pyquickhelper.pycode import is_travis_or_appveyor
6from .utils_backend_common_compare import compare_runtime_session
7from ...tools.ort_wrapper import (
8 InferenceSession, GraphOptimizationLevel, SessionOptions)
def _capture_output(fct, kind):
    """
    Calls *fct* and, when possible, captures what it prints.

    The capture is skipped when :func:`is_travis_or_appveyor` returns
    True or when *cpyquickhelper* cannot be imported; *fct* is then
    simply called.

    :param fct: function without arguments to call
    :param kind: capture mode, passed unchanged to
        ``cpyquickhelper.io.capture_output``
    :return: ``(fct(), None, None)`` when nothing is captured,
        otherwise whatever ``capture_output(fct, kind)`` returns
    """
    capture = None
    if not is_travis_or_appveyor():
        try:
            # optional dependency, only needed to capture the outputs
            from cpyquickhelper.io import capture_output as capture
        except ImportError:  # pragma: no cover
            capture = None
    if capture is None:
        # no capture available: run the function directly
        return fct(), None, None  # pragma: no cover
    return capture(fct, kind)  # pragma: no cover
class InferenceSession2:
    """
    Wraps class *InferenceSession* to capture
    the standard output and error.

    Attributes:

    * *sess*: the wrapped *InferenceSession*
    * *outi*, *erri*: outputs captured while creating the session
    * *outr*, *errr*: outputs captured by the last call to :meth:`run`,
      None until :meth:`run` is called
    """

    def __init__(self, *args, **kwargs):
        """
        Overwrites the constructor.

        Accepts the same arguments as *InferenceSession* plus the
        keyword *runtime_options*, a dictionary (or None) which may
        contain the boolean key ``'disable_optimisation'``. When that
        key is True, a new *SessionOptions* with graph optimisations
        disabled is passed as *sess_options*. The dictionary given
        by the caller is not modified.

        :raises RuntimeError: if ``'disable_optimisation'`` is True
            and *sess_options* is given as well
        """
        runtime_options = kwargs.pop('runtime_options', None) or {}
        # read with get rather than pop so that the caller's
        # dictionary is left unchanged
        disable_optimisation = runtime_options.get(
            'disable_optimisation', False)
        if disable_optimisation:
            if 'sess_options' in kwargs:
                raise RuntimeError(  # pragma: no cover
                    "Incompatible options, 'disable_optimisation' and "
                    "'sess_options' cannot be specified at the same time.")
            sess_options = SessionOptions()
            sess_options.graph_optimization_level = (
                GraphOptimizationLevel.ORT_DISABLE_ALL)
            kwargs['sess_options'] = sess_options
        # outputs of run, defined up front so that reading them
        # before the first call does not raise AttributeError
        self.outr = None
        self.errr = None
        self.sess, self.outi, self.erri = _capture_output(
            lambda: InferenceSession(*args, **kwargs), 'c')

    def run(self, *args, **kwargs):
        """
        Overwrites method *run*: calls ``self.sess.run`` with the same
        arguments, stores the captured outputs in *outr* and *errr*
        and returns the result of the call.
        """
        res, self.outr, self.errr = _capture_output(
            lambda: self.sess.run(*args, **kwargs), 'c')
        return res

    def get_inputs(self, *args, **kwargs):
        "Overwrites method *get_inputs*, delegates to the wrapped session."
        return self.sess.get_inputs(*args, **kwargs)

    def get_outputs(self, *args, **kwargs):
        "Overwrites method *get_outputs*, delegates to the wrapped session."
        return self.sess.get_outputs(*args, **kwargs)
def compare_runtime(test, decimal=5, options=None,
                    verbose=False, context=None, comparable_outputs=None,
                    intermediate_steps=False, classes=None,
                    disable_optimisation=False):
    """
    The function compares the expected output (computed with
    the model before being converted to ONNX) and the ONNX output
    produced with module :epkg:`onnxruntime` or :epkg:`mlprodict`.
    Sessions are created with :class:`InferenceSession2` so that
    standard output and error can be captured.

    :param test: dictionary with the following keys:
        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param context: specifies custom operators
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable the optimisations
        onnxruntime could do
    :return: the value returned by *compare_runtime_session*,
        expected to be a tuple (output, lambda function to run
        the predictions) -- NOTE(review): confirm against
        compare_runtime_session

    An error is raised if the comparison fails.
    """
    return compare_runtime_session(
        InferenceSession2, test, decimal=decimal, options=options,
        verbose=verbose, context=context,
        comparable_outputs=comparable_outputs,
        intermediate_steps=intermediate_steps,
        classes=classes, disable_optimisation=disable_optimisation)