Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Validates runtime for many :scikit-learn: operators.
4The submodule relies on :epkg:`onnxconverter_common`,
5:epkg:`sklearn-onnx`.
6"""
7import pprint
8from inspect import signature
9import numpy
10from numpy.linalg import LinAlgError
11import sklearn
12from sklearn import __all__ as sklearn__all__, __version__ as sklearn_version
13from sklearn.exceptions import ConvergenceWarning
14from sklearn.utils._testing import ignore_warnings
15from ... import __version__ as ort_version
16from ...onnx_conv import to_onnx, register_converters, register_rewritten_operators
17from ...tools.ort_wrapper import onnxrt_version
18from ...tools.model_info import analyze_model, set_random_state
19from ...tools.asv_options_helper import (
20 get_opset_number_from_onnx, get_ir_version_from_onnx)
21from ..onnx_inference import OnnxInference
22from ...onnx_tools.optim.sklearn_helper import inspect_sklearn_model, set_n_jobs
23from ...onnx_tools.optim.onnx_helper import onnx_statistics
24from ...onnx_tools.optim import onnx_optimisations
25from .validate_problems import find_suitable_problem
26from .validate_scenarios import _extra_parameters
27from .validate_difference import measure_relative_difference
28from .validate_helper import (
29 _dispsimple, sklearn_operators,
30 _measure_time, _shape_exc, dump_into_folder,
31 default_time_kwargs, RuntimeBadResultsError,
32 _dictionary2str, _merge_options, _multiply_time_kwargs,
33 _get_problem_data)
34from .validate_benchmark import benchmark_fct
@ignore_warnings(category=(UserWarning, ConvergenceWarning))
def _dofit_model(dofit, obs, inst, X_train, y_train, X_test, y_test,
                 Xort_test, init_types, store_models,
                 debug, verbose, fLOG):
    """
    Fits estimator *inst* on the training data when *dofit* is True.

    Updates dictionary *obs* in place with the training time, the
    statistics returned by @see fn inspect_sklearn_model (prefixed with
    ``skl_``) and, when *store_models* is True, the model and data
    themselves. Returns True on success, False when the training
    raised one of the handled exceptions (recorded in
    ``obs['_1training_time_exc']``).
    """
    if not dofit:
        # Training was skipped: record a null duration, optionally the
        # untrained instance, and report success.
        obs["training_time"] = 0.
        if store_models:
            obs['MODEL'] = inst
            obs['init_types'] = init_types
        return True

    if verbose >= 2 and fLOG is not None:
        fLOG("[enumerate_compatible_opset] fit, type: '{}' dtype: {}".format(
            type(X_train), getattr(X_train, 'dtype', '-')))
    try:
        # Fix the seed first so the benchmark is reproducible.
        set_random_state(inst)
        if y_train is None:
            duration = _measure_time(lambda: inst.fit(X_train))[1]
        else:
            duration = _measure_time(
                lambda: inst.fit(X_train, y_train))[1]
    except (AttributeError, TypeError, ValueError,
            IndexError, NotImplementedError, MemoryError,
            LinAlgError, StopIteration) as e:
        if debug:
            raise  # pragma: no cover
        obs["_1training_time_exc"] = str(e)
        return False

    obs["training_time"] = duration
    # Statistics about the fitted model; not every estimator is supported.
    try:
        model_stats = inspect_sklearn_model(inst)
    except NotImplementedError:
        model_stats = {}
    for key, value in model_stats.items():
        obs['skl_' + key] = value

    if store_models:
        obs['MODEL'] = inst
        obs['X_test'] = X_test
        obs['Xort_test'] = Xort_test
        obs['init_types'] = init_types
    return True
def _run_skl_prediction(obs, check_runtime, assume_finite, inst,
                        method_name, predict_kwargs, X_test,
                        benchmark, debug, verbose, time_kwargs,
                        skip_long_test, time_kwargs_fact, fLOG):
    """
    Computes the :epkg:`scikit-learn` prediction used as the reference.

    Fills *obs* in place with the prediction time, a callable reused by
    the benchmark (``obs['lambda-skl']``) and, when *benchmark* is True,
    the benchmark results (``obs['bench-skl']``). Returns the
    predictions on success, the caught exception on failure, or None
    when *check_runtime* is False.
    """
    if not check_runtime:
        return None  # pragma: no cover
    if verbose >= 2 and fLOG is not None:
        fLOG("[enumerate_compatible_opset] check_runtime SKL {}-{}-{}-{}-{}".format(
            id(inst), method_name, predict_kwargs, time_kwargs,
            time_kwargs_fact))
    with sklearn.config_context(assume_finite=assume_finite):
        obs['ort_version'] = ort_version
        # Resolve the prediction method on the fitted estimator.
        try:
            meth = getattr(inst, method_name)
        except AttributeError as e:
            if debug:
                raise  # pragma: no cover
            obs['_2skl_meth_exc'] = str(e)
            return e
        # Measure one prediction and keep a callable for the benchmark.
        try:
            ypred, pred_time, _ = _measure_time(
                lambda: meth(X_test, **predict_kwargs))
            obs['lambda-skl'] = (lambda xo: meth(xo, **predict_kwargs), X_test)
        except (ValueError, AttributeError, TypeError, MemoryError, IndexError) as e:
            if debug:
                raise  # pragma: no cover
            obs['_3prediction_exc'] = str(e)
            return e
        obs['prediction_time'] = pred_time
        obs['assume_finite'] = assume_finite
        if benchmark and 'lambda-skl' in obs:
            obs['bench-skl'] = benchmark_fct(
                *obs['lambda-skl'], obs=obs,
                time_kwargs=_multiply_time_kwargs(
                    time_kwargs, time_kwargs_fact, inst),
                skip_long_test=skip_long_test)
        if verbose >= 3 and fLOG is not None:
            fLOG("[enumerate_compatible_opset] scikit-learn prediction")
            _dispsimple(ypred, fLOG)
        if verbose >= 2 and fLOG is not None:
            fLOG("[enumerate_compatible_opset] predictions stored")
        return ypred
def _retrieve_problems_extra(model, verbose, fLOG, extended_list):
    """
    Used by @see fn enumerate_compatible_opset.

    Returns the list of problems a model can be tested on and the
    associated scenarios (constructor parameters). The lookup first goes
    through the custom converters (when *extended_list* is True) then
    falls back on the :epkg:`scikit-learn` list of problems.

    @param      model           model class
    @param      verbose         verbosity
    @param      fLOG            logging function
    @param      extended_list   also considers the converters
                                this module implements
    @return     *(problems, extras)*, or *(error dictionary, extras)*
                when no problem could be associated with the model
    """
    extras = None
    if extended_list:
        from ...onnx_conv.validate_scenarios import find_suitable_problem as fsp_extended
        problems = fsp_extended(model)
        if problems is not None:
            from ...onnx_conv.validate_scenarios import build_custom_scenarios as fsp_scenarios
            extra_parameters = fsp_scenarios()
            # The custom scenarios must be retrieved regardless of the
            # verbosity level, only the logging is optional (previously
            # the lookup was nested below the verbosity check).
            extras = extra_parameters.get(model, None)
            if verbose >= 2 and fLOG is not None:
                fLOG(
                    "[enumerate_compatible_opset] found custom for model={}".format(model))
                if extras is not None:
                    fLOG(
                        "[enumerate_compatible_opset] found custom scenarios={}".format(extras))
    else:
        problems = None

    if problems is None:
        # scikit-learn
        extra_parameters = _extra_parameters
        try:
            problems = find_suitable_problem(model)
        except RuntimeError as e:
            return {'name': model.__name__, 'skl_version': sklearn_version,
                    '_0problem_exc': e}, extras
        extras = extra_parameters.get(model, [('default', {})])

    # checks existence of random_state: force a fixed seed so that
    # every validated run is reproducible.
    sig = signature(model.__init__)
    if 'random_state' in sig.parameters:
        new_extras = []
        for extra in extras:
            if 'random_state' not in extra[1]:
                ps = extra[1].copy()
                ps['random_state'] = 42
                if len(extra) == 2:
                    extra = (extra[0], ps)
                else:
                    extra = (extra[0], ps) + extra[2:]
            new_extras.append(extra)
        extras = new_extras

    return problems, extras
def enumerate_compatible_opset(model, opset_min=-1, opset_max=-1,  # pylint: disable=R0914
                               check_runtime=True, debug=False,
                               runtime='python', dump_folder=None,
                               store_models=False, benchmark=False,
                               assume_finite=True, node_time=False,
                               fLOG=print, filter_exp=None,
                               verbose=0, time_kwargs=None,
                               extended_list=False, dump_all=False,
                               n_features=None, skip_long_test=True,
                               filter_scenario=None, time_kwargs_fact=None,
                               time_limit=4, n_jobs=None):
    """
    Lists all compatible opsets for a specific model.

    @param      model           operator class
    @param      opset_min       starts with this opset
    @param      opset_max       ends with this opset (None to use
                                current onnx opset)
    @param      check_runtime   checks that runtime can consume the
                                model and compute predictions
    @param      debug           catch exception (True) or not (False)
    @param      runtime         test a specific runtime, by default ``'python'``
    @param      dump_folder     dump information to replicate in case of mismatch
    @param      dump_all        dump all models not only the one which fail
    @param      store_models    if True, the function
                                also stores the fitted model and its conversion
                                into :epkg:`ONNX`
    @param      benchmark       if True, measures the time taken by each function
                                to predict for different number of rows
    @param      fLOG            logging function
    @param      filter_exp      function which tells if the experiment must be run,
                                None to run all, takes *model, problem* as an input
    @param      filter_scenario second function which tells if the experiment must be run,
                                None to run all, takes *model, problem, scenario, extra, options*
                                as an input
    @param      node_time       collect time for each node in the :epkg:`ONNX` graph
    @param      assume_finite   See `config_context
                                <https://scikit-learn.org/stable/modules/generated/
                                sklearn.config_context.html>`_, If True, validation for finiteness
                                will be skipped, saving time, but leading to potential crashes.
                                If False, validation for finiteness will be performed, avoiding error.
    @param      verbose         verbosity
    @param      extended_list   extends the list to custom converters
                                and problems
    @param      time_kwargs     to define a more precise way to measure a model
    @param      n_features      modifies the shorts datasets used to train the models
                                to use exactly this number of features, it can also
                                be a list to test multiple datasets
    @param      skip_long_test  skips tests for high values of N if they seem too long
    @param      time_kwargs_fact    see :func:`_multiply_time_kwargs <mlprodict.onnxrt.validate.validate_helper._multiply_time_kwargs>`
    @param      time_limit      to stop benchmarking after this amount of time was spent
    @param      n_jobs          *n_jobs* is set to the number of CPU by default unless this
                                value is changed
    @return                     dictionaries, each row has the following
                                keys: opset, exception if any, conversion time,
                                problem chosen to test the conversion...

    The function requires :epkg:`sklearn-onnx`.
    The outcome can be seen at pages references
    by :ref:`l-onnx-availability`.
    The parameter *time_kwargs* is a dictionary which defines the
    number of times to repeat the same predictions in order
    to give more precise figures. The default value (if None) is returned
    by the following code:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.validate.validate_helper import default_time_kwargs
        import pprint
        pprint.pprint(default_time_kwargs())

    Parameter *time_kwargs_fact* multiples these values for some
    specific models. ``'lin'`` multiplies by 10 when the model
    is linear.
    """
    # Resolve default opset boundaries (-1 means "current onnx opset").
    if opset_min == -1:
        opset_min = get_opset_number_from_onnx()  # pragma: no cover
    if opset_max == -1:
        opset_max = get_opset_number_from_onnx()  # pragma: no cover
    if verbose > 0 and fLOG is not None:
        fLOG("[enumerate_compatible_opset] opset in [{}, {}].".format(
            opset_min, opset_max))
    if verbose > 1 and fLOG:
        fLOG("[enumerate_compatible_opset] validate class '{}'.".format(
            model.__name__))
        if verbose > 2:
            fLOG(model)
    if time_kwargs is None:
        time_kwargs = default_time_kwargs()
    # Problems the model can be tested on and the scenarios
    # (constructor parameters) to try. *problems* is a dictionary
    # when the lookup itself failed; yield it as an error row.
    problems, extras = _retrieve_problems_extra(
        model, verbose, fLOG, extended_list)
    if isinstance(problems, dict):
        yield problems  # pragma: no cover
        problems = []  # pragma: no cover
    if opset_max is None:
        # None as last opset means "no target_opset" at conversion time.
        opset_max = get_opset_number_from_onnx()  # pragma: no cover
        opsets = list(range(opset_min, opset_max + 1))  # pragma: no cover
        opsets.append(None)  # pragma: no cover
    else:
        opsets = list(range(opset_min, opset_max + 1))
    if extras is None:
        # No scenario could be associated with the model: skip it.
        problems = []
        yield {'name': model.__name__, 'skl_version': sklearn_version,
               '_0problem_exc': 'SKIPPED'}
    if not isinstance(n_features, list):
        n_features = [n_features]
    # One iteration per (problem, dataset size, scenario).
    for prob in problems:
        if filter_exp is not None and not filter_exp(model, prob):
            continue
        for n_feature in n_features:
            if verbose >= 2 and fLOG is not None:
                fLOG("[enumerate_compatible_opset] problem={} n_feature={}".format(
                    prob, n_feature))
            # Synthetic dataset and conversion metadata for this problem.
            (X_train, X_test, y_train,
             y_test, Xort_test,
             init_types, conv_options, method_name,
             output_index, dofit, predict_kwargs) = _get_problem_data(prob, n_feature)
            for scenario_extra in extras:
                subset_problems = None
                optimisations = None
                new_conv_options = None
                # A third element in the scenario restricts the problems
                # it applies to and/or carries conversion options.
                if len(scenario_extra) > 2:
                    options = scenario_extra[2]
                    if isinstance(options, dict):
                        subset_problems = options.get('subset_problems', None)
                        optimisations = options.get('optim', None)
                        new_conv_options = options.get('conv_options', None)
                    else:
                        subset_problems = options
                if subset_problems and isinstance(subset_problems, (list, set)):
                    if prob not in subset_problems:
                        # Skips unrelated problem for a specific configuration.
                        continue
                elif subset_problems is not None:
                    raise RuntimeError(  # pragma: no cover
                        "subset_problems must be a set or a list not {}.".format(
                            subset_problems))
                try:
                    scenario, extra = scenario_extra[:2]
                except TypeError as e:  # pragma: no cover
                    raise TypeError(
                        "Unable to interpret 'scenario_extra'\n{}".format(
                            scenario_extra)) from e
                if optimisations is None:
                    optimisations = [None]
                if new_conv_options is None:
                    new_conv_options = [{}]
                if (filter_scenario is not None and
                        not filter_scenario(model, prob, scenario,
                                            extra, new_conv_options)):
                    continue
                if verbose >= 2 and fLOG is not None:
                    fLOG("[enumerate_compatible_opset] ##############################")
                    fLOG("[enumerate_compatible_opset] scenario={} optim={} extra={} dofit={} (problem={})".format(
                        scenario, optimisations, extra, dofit, prob))
                # training: base observation shared by all rows of this run
                obs = {'scenario': scenario, 'name': model.__name__,
                       'skl_version': sklearn_version, 'problem': prob,
                       'method_name': method_name, 'output_index': output_index,
                       'fit': dofit, 'conv_options': conv_options,
                       'idtype': Xort_test.dtype, 'predict_kwargs': predict_kwargs,
                       'init_types': init_types, 'inst': extra if extra else None,
                       'n_features': X_train.shape[1] if len(X_train.shape) == 2 else 1}
                inst = None
                extra = set_n_jobs(model, extra, n_jobs=n_jobs)
                try:
                    inst = model(**extra)
                except TypeError as e:  # pragma: no cover
                    if debug:  # pragma: no cover
                        raise
                    # Only tolerate missing mandatory constructor
                    # arguments; anything else is a real failure.
                    if "__init__() missing" not in str(e):
                        raise RuntimeError(
                            "Unable to instantiate model '{}'.\nextra=\n{}".format(
                                model.__name__, pprint.pformat(extra))) from e
                    yield obs.copy()
                    continue
                # Training; on failure the partial observation is yielded.
                if not _dofit_model(dofit, obs, inst, X_train, y_train, X_test, y_test,
                                    Xort_test, init_types, store_models,
                                    debug, verbose, fLOG):
                    yield obs.copy()
                    continue
                # statistics about the trained model
                skl_infos = analyze_model(inst)
                for k, v in skl_infos.items():
                    obs['fit_' + k] = v
                # runtime: reference scikit-learn prediction
                ypred = _run_skl_prediction(
                    obs, check_runtime, assume_finite, inst,
                    method_name, predict_kwargs, X_test,
                    benchmark, debug, verbose, time_kwargs,
                    skip_long_test, time_kwargs_fact, fLOG)
                if isinstance(ypred, Exception):
                    yield obs.copy()
                    continue
                # Conversion to ONNX and runtime validation, one row
                # per (opset, conversion option, optimisation, runtime).
                for run_obs in _call_conv_runtime_opset(
                        obs=obs.copy(), opsets=opsets, debug=debug,
                        new_conv_options=new_conv_options,
                        model=model, prob=prob, scenario=scenario,
                        extra=extra, extras=extras, conv_options=conv_options,
                        init_types=init_types, inst=inst,
                        optimisations=optimisations, verbose=verbose,
                        benchmark=benchmark, runtime=runtime,
                        filter_scenario=filter_scenario,
                        X_test=X_test, y_test=y_test, ypred=ypred,
                        Xort_test=Xort_test, method_name=method_name,
                        check_runtime=check_runtime,
                        output_index=output_index,
                        kwargs=dict(
                            dump_all=dump_all,
                            dump_folder=dump_folder,
                            node_time=node_time,
                            skip_long_test=skip_long_test,
                            store_models=store_models,
                            time_kwargs=_multiply_time_kwargs(
                                time_kwargs, time_kwargs_fact, inst)
                        ),
                        time_limit=time_limit,
                        fLOG=fLOG):
                    yield run_obs
415def _check_run_benchmark(benchmark, stat_onnx, bench_memo, runtime):
416 unique = set(stat_onnx.items())
417 unique.add(runtime)
418 run_benchmark = benchmark and all(
419 map(lambda u: unique != u, bench_memo))
420 if run_benchmark:
421 bench_memo.append(unique)
422 return run_benchmark
def _call_conv_runtime_opset(
        obs, opsets, debug, new_conv_options,
        model, prob, scenario, extra, extras, conv_options,
        init_types, inst, optimisations, verbose,
        benchmark, runtime, filter_scenario,
        check_runtime, X_test, y_test, ypred, Xort_test,
        method_name, output_index,
        kwargs, time_limit, fLOG):
    """
    Calls the conversion and runtime for different opsets.

    Converts the fitted model *inst* into ONNX for every opset in
    *opsets*, every conversion option in *new_conv_options*, every
    optimisation in *optimisations* and every runtime in *runtime*,
    then yields one observation dictionary per combination (built on
    top of a copy of *obs*).
    """
    # Highest opsets first; None (no target_opset) stays in front.
    if None in opsets:
        set_opsets = [None] + list(sorted((_ for _ in opsets if _ is not None),
                                          reverse=True))
    else:
        set_opsets = list(sorted(opsets, reverse=True))
    # Remembers which (graph statistics, runtime) pairs were already
    # benchmarked, see _check_run_benchmark.
    bench_memo = []
    for opset in set_opsets:
        if verbose >= 2 and fLOG is not None:
            fLOG("[enumerate_compatible_opset] opset={} init_types={}".format(
                opset, init_types))
        obs_op = obs.copy()
        if opset is not None:
            obs_op['opset'] = opset
        if len(init_types) != 1:
            raise NotImplementedError(  # pragma: no cover
                "Multiple types are is not implemented: "
                "{}.".format(init_types))
        if not isinstance(runtime, list):
            runtime = [runtime]
        # Snapshot before the conversion-option loop mutates obs_op.
        obs_op_0c = obs_op.copy()
        for aoptions in new_conv_options:
            obs_op = obs_op_0c.copy()
            all_conv_options = {} if conv_options is None else conv_options.copy()
            all_conv_options = _merge_options(
                all_conv_options, aoptions)
            obs_op['conv_options'] = all_conv_options
            if (filter_scenario is not None and
                    not filter_scenario(model, prob, scenario,
                                        extra, all_conv_options)):
                continue
            for rt in runtime:
                # Defaults bind the loop variables; *rt* is read when
                # the conversion actually runs.
                def fct_conv(itt=inst, it=init_types[0][1], ops=opset,
                             options=all_conv_options):
                    return to_onnx(itt, it, target_opset=ops, options=options,
                                   rewrite_ops=rt in ('', None, 'python',
                                                      'python_compiled'))
                if verbose >= 2 and fLOG is not None:
                    fLOG(
                        "[enumerate_compatible_opset] conversion to onnx: {}".format(all_conv_options))
                try:
                    conv, t4 = _measure_time(fct_conv)[:2]
                    obs_op["convert_time"] = t4
                except (RuntimeError, IndexError, AttributeError, TypeError,
                        ValueError, NameError, NotImplementedError) as e:
                    if debug:
                        fLOG(pprint.pformat(obs_op))  # pragma: no cover
                        raise  # pragma: no cover
                    # Conversion failed: record and move to next case.
                    obs_op["_4convert_exc"] = e
                    yield obs_op.copy()
                    continue
                if verbose >= 6 and fLOG is not None:
                    fLOG(  # pragma: no cover
                        "[enumerate_compatible_opset] ONNX:\n{}".format(conv))
                # Sanity check: the cdist optimisation must produce a
                # CDist node whenever a Scan node would otherwise appear.
                if all_conv_options.get('optim', '') == 'cdist':  # pragma: no cover
                    check_cdist = [_ for _ in str(conv).split('\n')
                                   if 'CDist' in _]
                    check_scan = [_ for _ in str(conv).split('\n')
                                  if 'Scan' in _]
                    if len(check_cdist) == 0 and len(check_scan) > 0:
                        raise RuntimeError(
                            "Operator CDist was not used in\n{}"
                            "".format(conv))
                obs_op0 = obs_op.copy()
                for optimisation in optimisations:
                    obs_op = obs_op0.copy()
                    if optimisation is not None:
                        if optimisation == 'onnx':
                            obs_op['optim'] = optimisation
                            if len(aoptions) != 0:
                                obs_op['optim'] += '/' + \
                                    _dictionary2str(aoptions)
                            conv = onnx_optimisations(conv)
                        else:
                            raise ValueError(  # pragma: no cover
                                "Unknown optimisation option '{}' (extra={})"
                                "".format(optimisation, extras))
                    else:
                        obs_op['optim'] = _dictionary2str(aoptions)
                    if verbose >= 3 and fLOG is not None:
                        fLOG("[enumerate_compatible_opset] optim='{}' optimisation={} all_conv_options={}".format(
                            obs_op['optim'], optimisation, all_conv_options))
                    if kwargs['store_models']:
                        obs_op['ONNX'] = conv
                        if verbose >= 2 and fLOG is not None:
                            fLOG(  # pragma: no cover
                                "[enumerate_compatible_opset] onnx nodes: {}".format(
                                    len(conv.graph.node)))
                    # Graph statistics, stored with an 'onx_' prefix.
                    stat_onnx = onnx_statistics(conv)
                    obs_op.update(
                        {'onx_' + k: v for k, v in stat_onnx.items()})
                    # opset_domain
                    for op_imp in list(conv.opset_import):
                        obs_op['domain_opset_%s' %
                               op_imp.domain] = op_imp.version
                    run_benchmark = _check_run_benchmark(
                        benchmark, stat_onnx, bench_memo, rt)
                    # prediction
                    if check_runtime:
                        yield _call_runtime(obs_op=obs_op.copy(), conv=conv,
                                            opset=opset, debug=debug,
                                            runtime=rt, inst=inst,
                                            X_test=X_test, y_test=y_test,
                                            init_types=init_types,
                                            method_name=method_name,
                                            output_index=output_index,
                                            ypred=ypred, Xort_test=Xort_test,
                                            model=model,
                                            dump_folder=kwargs['dump_folder'],
                                            benchmark=run_benchmark,
                                            node_time=kwargs['node_time'],
                                            time_kwargs=kwargs['time_kwargs'],
                                            fLOG=fLOG, verbose=verbose,
                                            store_models=kwargs['store_models'],
                                            dump_all=kwargs['dump_all'],
                                            skip_long_test=kwargs['skip_long_test'],
                                            time_limit=time_limit)
                    else:
                        yield obs_op.copy()  # pragma: no cover
def _call_runtime(obs_op, conv, opset, debug, inst, runtime,
                  X_test, y_test, init_types, method_name, output_index,
                  ypred, Xort_test, model, dump_folder,
                  benchmark, node_time, fLOG,
                  verbose, store_models, time_kwargs,
                  dump_all, skip_long_test, time_limit):
    """
    Private. Loads the converted model *conv* into the requested
    runtime, runs one prediction on *Xort_test*, optionally benchmarks
    it and measures the discrepancy against the scikit-learn
    predictions *ypred*. Returns the completed observation *obs_op*
    (mutated in place); intermediate failures are recorded under
    ``_5*``/``_6*``/``_8*`` keys instead of raising unless *debug*.
    """
    if 'onnxruntime' in runtime:
        # onnxruntime may not support the latest IR version; lower it
        # temporarily and restore it after serialization.
        old = conv.ir_version
        conv.ir_version = get_ir_version_from_onnx()
    else:
        old = None
    ser, t5, ___ = _measure_time(lambda: conv.SerializeToString())
    obs_op['tostring_time'] = t5
    obs_op['runtime'] = runtime
    if old is not None:
        conv.ir_version = old
    # load
    if verbose >= 2 and fLOG is not None:
        fLOG("[enumerate_compatible_opset-R] load onnx")
    try:
        sess, t5, ___ = _measure_time(
            lambda: OnnxInference(ser, runtime=runtime))
        obs_op['tostring_time'] = t5
    except (RuntimeError, ValueError, KeyError, IndexError, TypeError) as e:
        if debug:
            raise  # pragma: no cover
        obs_op['_5ort_load_exc'] = e
        return obs_op
    # compute batch
    if store_models:
        obs_op['OINF'] = sess
    if verbose >= 2 and fLOG is not None:
        fLOG("[enumerate_compatible_opset-R] compute batch with runtime "
             "'{}'".format(runtime))
    def fct_batch(se=sess, xo=Xort_test, it=init_types):  # pylint: disable=W0102
        return se.run({it[0][0]: xo},
                      verbose=max(verbose - 1, 1) if debug else 0, fLOG=fLOG)
    try:
        opred, t5, ___ = _measure_time(fct_batch)
        obs_op['ort_run_time_batch'] = t5
        # Callable reused by the benchmark below.
        obs_op['lambda-batch'] = (lambda xo: sess.run(
            {init_types[0][0]: xo}, node_time=node_time), Xort_test)
    except (RuntimeError, TypeError, ValueError, KeyError, IndexError) as e:
        if debug:
            raise RuntimeError("Issue with {}.".format(
                obs_op)) from e  # pragma: no cover
        obs_op['_6ort_run_batch_exc'] = e
    if (benchmark or node_time) and 'lambda-batch' in obs_op:
        try:
            benres = benchmark_fct(*obs_op['lambda-batch'], obs=obs_op,
                                   node_time=node_time, time_kwargs=time_kwargs,
                                   skip_long_test=skip_long_test,
                                   time_limit=time_limit)
            obs_op['bench-batch'] = benres
        except (RuntimeError, TypeError, ValueError) as e:  # pragma: no cover
            if debug:
                raise e  # pragma: no cover
            obs_op['_6ort_run_batch_exc'] = e
            obs_op['_6ort_run_batch_bench_exc'] = e
    # difference between scikit-learn and the ONNX runtime outputs
    debug_exc = []
    if verbose >= 2 and fLOG is not None:
        fLOG("[enumerate_compatible_opset-R] differences")
    # *opred* is only defined when the batch run succeeded.
    if '_6ort_run_batch_exc' not in obs_op:
        if isinstance(opred, dict):
            # Normalize dictionary outputs into a list of values.
            ch = [(k, v) for k, v in opred.items()]
            opred = [_[1] for _ in ch]
        if output_index != 'all':
            try:
                opred = opred[output_index]
            except IndexError as e:  # pragma: no cover
                if debug:
                    raise IndexError(
                        "Issue with output_index={}/{}".format(
                            output_index, len(opred))) from e
                obs_op['_8max_rel_diff_batch_exc'] = (
                    "Unable to fetch output {}/{} for model '{}'"
                    "".format(output_index, len(opred),
                              model.__name__))
                opred = None
        if opred is not None:
            if store_models:
                obs_op['skl_outputs'] = ypred
                obs_op['ort_outputs'] = opred
            if verbose >= 3 and fLOG is not None:
                fLOG("[_call_runtime] runtime prediction")
                _dispsimple(opred, fLOG)
            if (method_name == "decision_function" and hasattr(opred, 'shape') and
                    hasattr(ypred, 'shape') and len(opred.shape) == 2 and
                    opred.shape[1] == 2 and len(ypred.shape) == 1):
                # decision_function, for binary classification,
                # raw score is a distance
                max_rel_diff = measure_relative_difference(
                    ypred, opred[:, 1])
            else:
                max_rel_diff = measure_relative_difference(
                    ypred, opred)
            if max_rel_diff >= 1e9 and debug:  # pragma: no cover
                _shape = lambda o: o.shape if hasattr(
                    o, 'shape') else 'no shape'
                raise RuntimeError(
                    "Big difference (opset={}, runtime='{}' p='{}' s='{}')"
                    ":\n-------\n{}-{}\n{}\n--------\n{}-{}\n{}".format(
                        opset, runtime, obs_op['problem'], obs_op['scenario'],
                        type(ypred), _shape(ypred), ypred,
                        type(opred), _shape(opred), opred))
            if numpy.isnan(max_rel_diff):
                obs_op['_8max_rel_diff_batch_exc'] = (  # pragma: no cover
                    "Unable to compute differences between"
                    " {}-{}\n{}\n--------\n{}".format(
                        _shape_exc(
                            ypred), _shape_exc(opred),
                        ypred, opred))
                if debug:  # pragma: no cover
                    debug_exc.append(RuntimeError(
                        obs_op['_8max_rel_diff_batch_exc']))
            else:
                obs_op['max_rel_diff_batch'] = max_rel_diff
                if dump_folder and max_rel_diff > 1e-5:
                    # Keep the data needed to replicate the mismatch.
                    dump_into_folder(dump_folder, kind='batch', obs_op=obs_op,
                                     X_test=X_test, y_test=y_test, Xort_test=Xort_test)
                if debug and max_rel_diff >= 0.1:  # pragma: no cover
                    raise RuntimeError("Two big differences {}\n{}\n{}\n{}".format(
                        max_rel_diff, inst, conv, pprint.pformat(obs_op)))
    if debug and len(debug_exc) == 2:
        raise debug_exc[0]  # pragma: no cover
    if debug and verbose >= 2:  # pragma: no cover
        if verbose >= 3:
            fLOG(pprint.pformat(obs_op))
        else:
            # Lambdas are not printable in a useful way; drop them.
            obs_op_log = {k: v for k,
                          v in obs_op.items() if 'lambda-' not in k}
            fLOG(pprint.pformat(obs_op_log))
    if verbose >= 2 and fLOG is not None:
        fLOG("[enumerate_compatible_opset-R] next...")
    if dump_all:
        dump = dump_into_folder(dump_folder, kind='batch', obs_op=obs_op,
                                X_test=X_test, y_test=y_test, Xort_test=Xort_test,
                                is_error=len(debug_exc) > 1,
                                onnx_bytes=conv.SerializeToString(),
                                skl_model=inst, ypred=ypred)
        obs_op['dumped'] = dump
    return obs_op
def _enumerate_validated_operator_opsets_ops(extended_list, models, skip_models):
    """
    Returns the list of operators to validate.

    @param      extended_list   includes the converters implemented
                                by this module
    @param      models          restricts the list to this set of model
                                names, None to keep them all
    @param      skip_models     model names removed from the result
    @return     list of dictionaries, one per operator
    @raise      ValueError      if *models* is not a set of strings or
                                filters out every operator
    """
    ops = [_ for _ in sklearn_operators(extended=extended_list)]

    if models is not None:
        if not all(map(lambda m: isinstance(m, str), models)):
            raise ValueError(  # pragma: no cover
                "models must be a set of strings.")
        ops_ = [_ for _ in ops if _['name'] in models]
        # Bug fix: the emptiness check used to test the unfiltered
        # list *ops*, so a wrong *models* filter silently produced an
        # empty result instead of raising.
        if len(ops_) == 0:
            raise ValueError(  # pragma: no cover
                "Parameter models is wrong: {}\n{}".format(
                    models, ops[0]))
        ops = ops_
    if skip_models is not None:
        ops = [m for m in ops if m['name'] not in skip_models]
    return ops
def _enumerate_validated_operator_opsets_version(runtime):
    """
    Collects the versions of the packages involved in the validation.

    @param      runtime     runtime name (used to pick the right
                            :epkg:`onnxruntime` version)
    @return     dictionary mapping ``v_<package>`` to a version string
    """
    from numpy import __version__ as numpy_version
    from onnx import __version__ as onnx_version
    from scipy import __version__ as scipy_version
    from skl2onnx import __version__ as skl2onnx_version
    versions = {
        'v_numpy': numpy_version,
        'v_onnx': onnx_version,
        'v_scipy': scipy_version,
        'v_skl2onnx': skl2onnx_version,
        'v_sklearn': sklearn_version,
        'v_onnxruntime': ort_version,
    }
    if "onnxruntime" in runtime:
        # The onnxruntime runtime is requested: report the version of
        # the installed onnxruntime package instead of the wrapper's.
        versions['v_onnxruntime'] = onnxrt_version
    return versions
def enumerate_validated_operator_opsets(verbose=0, opset_min=-1, opset_max=-1,
                                        check_runtime=True, debug=False, runtime='python',
                                        models=None, dump_folder=None, store_models=False,
                                        benchmark=False, skip_models=None,
                                        assume_finite=True, node_time=False,
                                        fLOG=print, filter_exp=None,
                                        versions=False, extended_list=False,
                                        time_kwargs=None, dump_all=False,
                                        n_features=None, skip_long_test=True,
                                        fail_bad_results=False,
                                        filter_scenario=None,
                                        time_kwargs_fact=None,
                                        time_limit=4, n_jobs=None):
    """
    Tests all possible configurations for all possible
    operators and returns the results.

    :param verbose: integer 0, 1, 2
    :param opset_min: checks conversion starting from the opset, -1
        to get the last one
    :param opset_max: checks conversion up to this opset,
        None means :func:`get_opset_number_from_onnx
        <mlprodict.tools.asv_options_helper.get_opset_number_from_onnx>`
    :param check_runtime: checks the python runtime
    :param models: only process a small list of operators,
        set of model names
    :param debug: stops whenever an exception
        is raised
    :param runtime: test a specific runtime, by default ``'python'``
    :param dump_folder: dump information to replicate in case of mismatch
    :param dump_all: dump all models not only the one which fail
    :param store_models: if True, the function
        also stores the fitted model and its conversion
        into :epkg:`ONNX`
    :param benchmark: if True, measures the time taken by each function
        to predict for different number of rows
    :param filter_exp: function which tells if the experiment must be run,
        None to run all, takes *model, problem* as an input
    :param filter_scenario: second function which tells if the experiment must be run,
        None to run all, takes *model, problem, scenario, extra, options*
        as an input
    :param skip_models: models to skip
    :param assume_finite: See `config_context
        <https://scikit-learn.org/stable/modules/generated/
        sklearn.config_context.html>`_, If True, validation for finiteness
        will be skipped, saving time, but leading to potential crashes.
        If False, validation for finiteness will be performed, avoiding error.
    :param node_time: measure time execution for every node in the graph
    :param versions: add columns with versions of used packages,
        :epkg:`numpy`, :epkg:`scikit-learn`, :epkg:`onnx`,
        :epkg:`onnxruntime`, :epkg:`sklearn-onnx`
    :param extended_list: also check models this module implements a converter for
    :param time_kwargs: to define a more precise way to measure a model
    :param n_features: modifies the shorts datasets used to train the models
        to use exactly this number of features, it can also
        be a list to test multiple datasets
    :param skip_long_test: skips tests for high values of N if they seem too long
    :param fail_bad_results: fails if the results are aligned with :epkg:`scikit-learn`
    :param time_kwargs_fact: see :func:`_multiply_time_kwargs
        <mlprodict.onnxrt.validate.validate_helper._multiply_time_kwargs>`
    :param time_limit: to skip the rest of the test after this limit (in second)
    :param n_jobs: *n_jobs* is set to the number of CPU by default unless this
        value is changed
    :param fLOG: logging function
    :return: list of dictionaries

    The function is available through command line
    :ref:`validate_runtime <l-cmd-validate_runtime>`.
    The default for *time_kwargs* is the following:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.validate.validate_helper import default_time_kwargs
        import pprint
        pprint.pprint(default_time_kwargs())
    """
    # Register the custom converters before any conversion happens.
    register_converters()
    register_rewritten_operators()
    # List of operators to validate, already filtered.
    ops = _enumerate_validated_operator_opsets_ops(
        extended_list, models, skip_models)
    if verbose > 0:
        # Verbose mode: wrap the iteration to log progress, with tqdm
        # when it is available.
        def iterate():
            for i, row in enumerate(ops):  # pragma: no cover
                fLOG("{}/{} - {}".format(i + 1, len(ops), row))
                yield row
        if verbose >= 11:
            verbose -= 10  # pragma: no cover
            loop = iterate()  # pragma: no cover
        else:
            try:
                from tqdm import trange
                def iterate_tqdm():
                    with trange(len(ops)) as t:
                        for i in t:
                            row = ops[i]
                            # Pad the name so the progress bar is stable.
                            disp = row['name'] + " " * (28 - len(row['name']))
                            t.set_description("%s" % disp)
                            yield row
                loop = iterate_tqdm()
            except ImportError:  # pragma: no cover
                loop = iterate()
    else:
        loop = ops
    if versions:
        add_versions = _enumerate_validated_operator_opsets_version(runtime)
    else:
        add_versions = {}
    current_opset = get_opset_number_from_onnx()
    if opset_min == -1:
        opset_min = get_opset_number_from_onnx()
    if opset_max == -1:
        opset_max = get_opset_number_from_onnx()
    if verbose > 0 and fLOG is not None:
        fLOG("[enumerate_validated_operator_opsets] opset in [{}, {}].".format(
            opset_min, opset_max))
    for row in loop:
        model = row['cl']
        if verbose > 1:
            fLOG("[enumerate_validated_operator_opsets] - model='{}'".format(model))
        for obs in enumerate_compatible_opset(
                model, opset_min=opset_min, opset_max=opset_max,
                check_runtime=check_runtime, runtime=runtime,
                debug=debug, dump_folder=dump_folder,
                store_models=store_models, benchmark=benchmark,
                fLOG=fLOG, filter_exp=filter_exp,
                assume_finite=assume_finite, node_time=node_time,
                verbose=verbose, extended_list=extended_list,
                time_kwargs=time_kwargs, dump_all=dump_all,
                n_features=n_features, skip_long_test=skip_long_test,
                filter_scenario=filter_scenario,
                time_kwargs_fact=time_kwargs_fact,
                time_limit=time_limit, n_jobs=n_jobs):
            # Sanity check: complete rows must carry the mandatory keys
            # (skipped when the row reports a problem-lookup failure).
            for mandkey in ('inst', 'method_name', 'problem',
                            'scenario'):
                if '_0problem_exc' in obs:
                    continue
                if mandkey not in obs:
                    raise ValueError("Missing key '{}' in\n{}".format(
                        mandkey, pprint.pformat(obs)))  # pragma: no cover
            if verbose > 1:
                fLOG('[enumerate_validated_operator_opsets] - OBS')
                if verbose > 2:
                    fLOG(" ", obs)
                else:
                    # Lambdas are not printable in a useful way.
                    obs_log = {k: v for k,
                               v in obs.items() if 'lambda-' not in k}
                    fLOG(" ", obs_log)
            elif verbose > 0 and "_0problem_exc" in obs:
                fLOG(" ???", obs)  # pragma: no cover
            # Summarize the discrepancy into the 'available' column.
            diff = obs.get('max_rel_diff_batch', None)
            batch = 'max_rel_diff_batch' in obs and diff is not None
            op1 = obs.get('domain_opset_', '')
            op2 = obs.get('domain_opset_ai.onnx.ml', '')
            op = '{}/{}'.format(op1, op2)
            obs['available'] = "?"
            if diff is not None:
                if diff < 1e-5:
                    obs['available'] = 'OK'
                elif diff < 0.0001:
                    obs['available'] = 'e<0.0001'
                elif diff < 0.001:
                    obs['available'] = 'e<0.001'
                elif diff < 0.01:
                    obs['available'] = 'e<0.01'  # pragma: no cover
                elif diff < 0.1:
                    obs['available'] = 'e<0.1'
                else:
                    obs['available'] = "ERROR->=%1.1f" % diff
                obs['available'] += '-' + op
                if not batch:
                    obs['available'] += "-NOBATCH"  # pragma: no cover
                if fail_bad_results and 'e<' in obs['available']:
                    raise RuntimeBadResultsError(
                        "Wrong results '{}'.".format(obs['available']), obs)  # pragma: no cover
            # Keep the first recorded exception (keys end with '_exc').
            excs = []
            for k, v in sorted(obs.items()):
                if k.endswith('_exc'):
                    excs.append((k, v))
                    break
            if 'opset' not in obs:
                # It fails before the conversion happens.
                obs['opset'] = current_opset
            if obs['opset'] == current_opset and len(excs) > 0:
                k, v = excs[0]
                obs['available'] = 'ERROR-%s' % k
                obs['available-ERROR'] = v
            # Ratios between the ONNX and scikit-learn benchmarks.
            if 'bench-skl' in obs:
                b1 = obs['bench-skl']
                if 'bench-batch' in obs:
                    b2 = obs['bench-batch']
                else:
                    b2 = None
                if b1 is not None and b2 is not None:
                    for k in b1:
                        if k in b2 and b2[k] is not None and b1[k] is not None:
                            key = 'time-ratio-N=%d' % k
                            obs[key] = b2[k]['average'] / b1[k]['average']
                            key = 'time-ratio-N=%d-min' % k
                            obs[key] = b2[k]['min_exec'] / b1[k]['max_exec']
                            key = 'time-ratio-N=%d-max' % k
                            obs[key] = b2[k]['max_exec'] / b1[k]['min_exec']
            obs.update(row)
            obs.update(add_versions)
            yield obs.copy()