Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Inspired from skl2onnx, handles two backends.
4"""
5import numpy
6import onnx
7import pandas
8from ...tools.ort_wrapper import OrtInvalidArgument
9from .utils_backend_common import (
10 load_data_and_model, extract_options,
11 ExpectedAssertionError, OnnxBackendAssertionError,
12 OnnxRuntimeMissingNewOnnxOperatorException,
13 _compare_expected, _create_column)
def _onnx_model_text(onx, verbose):
    """
    Returns a textual dump of an ONNX model to append to error messages.

    :param onx: onnx model, a filename or an already loaded model
    :param verbose: the model is only dumped when *verbose* is True
    :return: ``"\\nJSON ONNX\\n" + str(model)`` or an empty string
    """
    if not verbose:
        return ""
    # onnx.load expects a filename (or a file object): an in-memory
    # model is printed as it is instead of failing inside an error
    # handler and hiding the original exception.
    model = onnx.load(onx) if isinstance(onx, str) else onx
    return "\nJSON ONNX\n" + str(model)


def compare_runtime_session(  # pylint: disable=R0912
        cls_session, test, decimal=5, options=None,
        verbose=False, context=None, comparable_outputs=None,
        intermediate_steps=False, classes=None,
        disable_optimisation=False):
    """
    The function compares the expected output (computed with
    the model before being converted to ONNX) and the ONNX output
    produced with module :epkg:`onnxruntime` or :epkg:`mlprodict`.

    :param cls_session: inference session class (like @see cl OnnxInference),
        instantiated as ``cls_session(onx, runtime_options=...)``
    :param test: dictionary with the following keys:
        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options; extracted from the model filename
        when None, the dictionary is copied and never modified;
        option ``'DF'`` means the data is a dataframe with one column
        per input, ``'SklCol'`` is ignored here, option ``'CannotLoad'``
        means a loading failure is expected, the remaining options are
        forwarded to the comparison function
    :param context: specifies custom operators
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable optimisation the runtime may do
    :return: tuple ``(output, lambda_onnx)``: a copy of the outputs
        returned by ``sess.run`` and a function without arguments
        running the predictions again with the same inputs
    :raises ExpectedAssertionError: the failure is expected
        (option ``'CannotLoad'`` or ``'-Fail'`` in the model filename)
    :raises OnnxRuntimeMissingNewOnnxOperatorException: the runtime
        does not implement an operator used by the model
    :raises OnnxBackendAssertionError: the model cannot be loaded,
        cannot be run, or its outputs differ from the expected ones
    """
    lambda_onnx = None
    if context is None:
        context = {}
    load = load_data_and_model(test, **context)
    if verbose:  # pragma no cover
        print("[compare_runtime] test '{}' loaded".format(test['onnx']))

    onx = test['onnx']

    if options is None:
        if isinstance(onx, str):
            options = extract_options(onx)
        else:
            options = {}
    elif not isinstance(options, dict):
        raise TypeError(  # pragma no cover
            "options must be a dictionary.")
    # Options 'DF' and 'SklCol' are popped below: work on a copy so that
    # the caller's dictionary (reused by dump_data_and_model) is intact.
    options = dict(options)

    if verbose:  # pragma no cover
        print("[compare_runtime] InferenceSession('{}')".format(onx))

    runtime_options = dict(disable_optimisation=disable_optimisation)
    try:
        sess = cls_session(onx, runtime_options=runtime_options)
    except TypeError as et:  # pragma: no cover
        raise TypeError(
            "Wrong signature for '{}' ({}).".format(
                cls_session.__name__, et)) from et
    except ExpectedAssertionError:  # pragma no cover
        raise
    except Exception as e:  # pylint: disable=W0703
        if "CannotLoad" in options:  # pragma no cover
            raise ExpectedAssertionError(
                "Unable to load onnx '{0}' due to\n{1}".format(onx, e)) from e
        smodel = _onnx_model_text(onx, verbose)
        if ("NOT_IMPLEMENTED : Could not find an implementation "
                "for the node" in str(e)):
            # onnxruntime does not implement a specific node yet.
            raise OnnxRuntimeMissingNewOnnxOperatorException(
                "{3} does not implement a new operator "
                "'{0}'\n{1}\nONNX\n{2}".format(
                    onx, e, smodel, cls_session)) from e
        if "NOT_IMPLEMENTED : Failed to find kernel" in str(e):
            # onnxruntime does not implement a specific node yet
            # in the kernel included in onnxruntime.
            raise OnnxBackendAssertionError(
                "{3} misses a kernel for operator "
                "'{0}'\n{1}\nONNX\n{2}".format(
                    onx, e, smodel, cls_session)) from e
        raise OnnxBackendAssertionError(
            "Unable to load onnx '{0}'\nONNX\n{1}\n{2}".format(
                onx, smodel, e)) from e

    data = load["data"]
    if options.pop('DF', False):
        # One input per dataframe column, float64 downcast to float32,
        # every column reshaped into a matrix with one column.
        inputs = {c: data[c].values for c in data.columns}
        for k in inputs:
            if inputs[k].dtype == numpy.float64:
                inputs[k] = inputs[k].astype(numpy.float32)
            inputs[k] = inputs[k].reshape((inputs[k].shape[0], 1))
    elif isinstance(data, dict):
        # Copy: list values are converted into arrays below and
        # the test data must not be modified.
        inputs = dict(data)
    elif isinstance(data, (list, numpy.ndarray, pandas.DataFrame)):
        inp = sess.get_inputs()
        outs = sess.get_outputs()
        if len(outs) == 0:
            raise OnnxBackendAssertionError(  # pragma: no cover
                "Wrong number of outputs, onnx='{0}'".format(onx))
        if len(inp) == len(data):
            inputs = {i.name: v for i, v in zip(inp, data)}
        elif len(inp) == 1:
            inputs = {inp[0].name: data}
        elif isinstance(data, numpy.ndarray):
            shape = sum(i.shape[1] if len(i.shape) == 2 else i.shape[0]
                        for i in inp)
            if shape == data.shape[1]:
                inputs = {n.name: data[:, i] for i, n in enumerate(inp)}
            else:
                raise OnnxBackendAssertionError(  # pragma: no cover
                    "Wrong number of inputs onnx {0} != "
                    "original shape {1}, onnx='{2}'"
                    .format(len(inp), data.shape, onx))
        elif isinstance(data, list):
            try:
                array_input = numpy.array(data)
            except Exception as e:  # pragma no cover
                raise OnnxBackendAssertionError(
                    "Wrong number of inputs onnx {0} != "
                    "original {1}, onnx='{2}'"
                    .format(len(inp), len(data), onx)) from e
            shape = sum(i.shape[1] for i in inp)
            if shape == array_input.shape[1]:
                # Consecutive columns are dispatched to consecutive inputs.
                inputs = {}
                c = 0
                for n in inp:
                    d = c + n.shape[1]
                    inputs[n.name] = _create_column(
                        [row[c:d] for row in data], n.type)
                    c = d
            else:
                raise OnnxBackendAssertionError(  # pragma no cover
                    "Wrong number of inputs onnx {0} != "
                    "original shape {1}, onnx='{2}'*"
                    .format(len(inp), array_input.shape, onx))
        elif isinstance(data, pandas.DataFrame):
            try:
                array_input = numpy.array(data)
            except Exception as e:  # pragma no cover
                raise OnnxBackendAssertionError(
                    "Wrong number of inputs onnx {0} != "
                    "original {1}, onnx='{2}'"
                    .format(len(inp), len(data), onx)) from e
            shape = sum(i.shape[1] for i in inp)
            if shape == array_input.shape[1]:
                # Consecutive columns are dispatched to consecutive inputs.
                inputs = {}
                c = 0
                for n in inp:
                    d = c + n.shape[1]
                    inputs[n.name] = _create_column(
                        data.iloc[:, c:d], n.type)
                    c = d
            else:
                raise OnnxBackendAssertionError(  # pragma no cover
                    "Wrong number of inputs onnx {0}={1} columns != "
                    "original shape {2}, onnx='{3}'*"
                    .format(len(inp), shape, array_input.shape, onx))
        else:
            raise OnnxBackendAssertionError(  # pragma no cover
                "Wrong type of inputs onnx {0}, onnx='{1}'".format(
                    type(data), onx))
    else:
        raise OnnxBackendAssertionError(  # pragma no cover
            "Dict or list is expected, not {0}".format(type(data)))

    for k in inputs:
        if isinstance(inputs[k], list):
            inputs[k] = numpy.array(inputs[k])

    options.pop('SklCol', False)  # unused here but in dump_data_and_model

    if verbose:  # pragma no cover
        print("[compare_runtime] type(inputs)={} len={} names={}".format(
            type(data), len(inputs), list(sorted(inputs))))
        run_options = {'verbose': 3 if intermediate_steps else 2,
                       'fLOG': print}
    else:
        run_options = {}
    try:
        try:
            output = sess.run(None, inputs, **run_options)
        except TypeError:  # pragma no cover
            # The runtime does not accept verbose/fLOG arguments.
            output = sess.run(None, inputs)
        lambda_onnx = lambda: sess.run(None, inputs)  # noqa
        if verbose:  # pragma no cover
            import pprint
            pprint.pprint(output)
    except ExpectedAssertionError:  # pragma no cover
        raise
    except (RuntimeError, OrtInvalidArgument) as e:  # pragma no cover
        if intermediate_steps:
            # Best effort: the run is expected to fail again, it is only
            # replayed to display the intermediate steps, the meaningful
            # exception is raised below.
            try:
                sess.run(None, inputs, verbose=3, fLOG=print)
            except Exception:  # pylint: disable=W0703
                pass
        # onx may be a loaded model and not a filename.
        if isinstance(onx, str) and "-Fail" in onx:
            raise ExpectedAssertionError(
                "{1} cannot compute the prediction for '{0}'".
                format(onx, cls_session)) from e
        import pprint
        smodel = _onnx_model_text(onx, verbose)
        raise OnnxBackendAssertionError(
            "{4} cannot compute the predictions"
            " for '{0}' due to {1}{2}\n{3}"
            .format(onx, e, smodel, pprint.pformat(inputs),
                    cls_session)) from e
    except Exception as e:  # pragma no cover
        raise OnnxBackendAssertionError(
            "Unable to run onnx '{0}' due to {1}".format(onx, e)) from e
    if verbose:  # pragma no cover
        print("[compare_runtime] done type={}".format(type(output)))

    output0 = output.copy()

    if comparable_outputs:
        cmp_exp = [load["expected"][o] for o in comparable_outputs]
        cmp_out = [output[o] for o in comparable_outputs]
    else:
        cmp_exp = load["expected"]
        cmp_out = output

    try:
        _compare_expected(cmp_exp, cmp_out, sess, onx,
                          decimal=decimal, verbose=verbose,
                          classes=classes, **options)
    except ExpectedAssertionError:  # pragma no cover
        raise
    except Exception as e:  # pragma no cover
        smodel = _onnx_model_text(onx, verbose)
        raise OnnxBackendAssertionError(
            "Model '{}' has discrepencies with cls='{}'.\n{}: {}{}".format(
                onx, sess.__class__.__name__, type(e), e, smodel)) from e

    return output0, lambda_onnx