Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Command line about validation of prediction runtime.
4"""
5import os
6from logging import getLogger
7import warnings
8import json
9from multiprocessing import Pool
10from pandas import DataFrame
11from sklearn.exceptions import ConvergenceWarning
def validate_runtime(verbose=1, opset_min=-1, opset_max="",
                     check_runtime=True, runtime='python', debug=False,
                     models=None, out_raw="model_onnx_raw.xlsx",
                     out_summary="model_onnx_summary.xlsx",
                     dump_folder=None, dump_all=False, benchmark=False,
                     catch_warnings=True, assume_finite=True,
                     versions=False, skip_models=None,
                     extended_list=True, separate_process=False,
                     time_kwargs=None, n_features=None, fLOG=print,
                     out_graph=None, force_return=False,
                     dtype=None, skip_long_test=False,
                     number=1, repeat=1, time_kwargs_fact='lin',
                     time_limit=4, n_jobs=0):
    """
    Walks through most of :epkg:`scikit-learn` operators
    or model or predictor or transformer, tries to convert
    them into :epkg:`ONNX` and computes the predictions
    with a specific runtime.

    :param verbose: integer from 0 (None) to 2 (full verbose)
    :param opset_min: tries every conversion from this minimum opset,
        -1 to get the current opset
    :param opset_max: tries every conversion up to maximum opset,
        -1 to get the current opset
    :param check_runtime: to check the runtime
        and not only the conversion
    :param runtime: runtime to check, python,
        onnxruntime1 to check :epkg:`onnxruntime`,
        onnxruntime2 to check every *ONNX* node independently
        with onnxruntime, many runtime can be checked at the same time
        if the value is a comma separated list
    :param models: comma separated list of models to test or empty
        string to test them all
    :param skip_models: models to skip
    :param debug: stops whenever an exception is raised,
        only if *separate_process* is False
    :param out_raw: output raw results into this file (excel format)
    :param out_summary: output an aggregated view into this file (excel format)
    :param dump_folder: folder where to dump information (pickle)
        in case of mismatch
    :param dump_all: dumps all models, not only the failing ones
    :param benchmark: run benchmark
    :param catch_warnings: catch warnings
    :param assume_finite: See `config_context
        <https://scikit-learn.org/stable/modules/generated/sklearn.config_context.html>`_,
        If True, validation for finiteness will be skipped, saving time, but leading
        to potential crashes. If False, validation for finiteness will be performed,
        avoiding error.
    :param versions: add columns with versions of used packages,
        :epkg:`numpy`, :epkg:`scikit-learn`, :epkg:`onnx`, :epkg:`onnxruntime`,
        :epkg:`sklearn-onnx`
    :param extended_list: extends the list of :epkg:`scikit-learn` converters
        with converters implemented in this module
    :param separate_process: run every model in a separate process,
        this option must be used to run all model in one row
        even if one of them is crashing
    :param time_kwargs: a dictionary which defines the number of rows and
        the parameter *number* and *repeat* when benchmarking a model,
        the value must follow :epkg:`json` format
    :param n_features: change the default number of features for
        a specific problem, it can also be a comma separated list
    :param force_return: forces the function to return the results,
        used when the results are produced through a separate process
    :param out_graph: image name, to output a graph which summarizes
        a benchmark in case it was run
    :param dtype: '32' or '64' or None for both,
        limits the test to one specific number types
    :param skip_long_test: skips tests for high values of N if
        they seem too long
    :param number: to multiply number values in *time_kwargs*
    :param repeat: to multiply repeat values in *time_kwargs*
    :param time_kwargs_fact: to multiply number and repeat in
        *time_kwargs* depending on the model
        (see :func:`_multiply_time_kwargs <mlprodict.onnxrt.validate.validate_helper._multiply_time_kwargs>`)
    :param time_limit: to stop benchmarking after this limit of time
    :param n_jobs: force the number of jobs to have this value,
        by default, it is equal to the number of CPU
    :param fLOG: logging function

    .. cmdref::
        :title: Validate a runtime against scikit-learn
        :cmd: -m mlprodict validate_runtime --help
        :lid: l-cmd-validate_runtime

        The command walks through all scikit-learn operators,
        tries to convert them, checks the predictions,
        and produces a report.

        Example::

            python -m mlprodict validate_runtime --models LogisticRegression,LinearRegression

        Following example benchmarks models
        :epkg:`sklearn:ensemble:RandomForestRegressor`,
        :epkg:`sklearn:tree:DecisionTreeRegressor`, it compares
        :epkg:`onnxruntime` against :epkg:`scikit-learn` for opset 10.

        ::

            python -m mlprodict validate_runtime -v 1 -o 10 -op 10 -c 1 -r onnxruntime1
                   -m RandomForestRegressor,DecisionTreeRegressor -out bench_onnxruntime.xlsx -b 1

        Parameter ``--time_kwargs`` may be used to reduce or increase
        benchmark precisions. The following value tells the function
        to run a benchmarks with datasets of 1 or 10 number, to repeat
        a given number of time *number* predictions in one row.
        The total time is divided by :math:`number \\times repeat``.
        Parameter ``--time_kwargs_fact`` may be used to increase these
        number for some specific models. ``'lin'`` multiplies
        by 10 number when the model is linear.

        ::

            -t "{\\"1\\":{\\"number\\":10,\\"repeat\\":10},\\"10\\":{\\"number\\":5,\\"repeat\\":5}}"

        The following example dumps every model in the list:

        ::

            python -m mlprodict validate_runtime --out_raw raw.csv --out_summary sum.csv
                   --models LinearRegression,LogisticRegression,DecisionTreeRegressor,DecisionTreeClassifier
                   -r python,onnxruntime1 -o 10 -op 10 -v 1 -b 1 -dum 1
                   -du model_dump -n 20,100,500 --out_graph benchmark.png --dtype 32

        The command line generates a graph produced by function
        :func:`plot_validate_benchmark
        <mlprodict.onnxrt.validate.validate_graph.plot_validate_benchmark>`.
    """
    if separate_process:
        # Delegate to the subprocess-based runner. *number* and *repeat*
        # are forwarded as well: previously they were dropped, so the
        # benchmark multipliers were silently ignored in this mode.
        return _validate_runtime_separate_process(
            verbose=verbose, opset_min=opset_min, opset_max=opset_max,
            check_runtime=check_runtime, runtime=runtime, debug=debug,
            models=models, out_raw=out_raw,
            out_summary=out_summary, dump_all=dump_all,
            dump_folder=dump_folder, benchmark=benchmark,
            catch_warnings=catch_warnings, assume_finite=assume_finite,
            versions=versions, skip_models=skip_models,
            extended_list=extended_list, time_kwargs=time_kwargs,
            n_features=n_features, fLOG=fLOG, force_return=True,
            out_graph=None, dtype=dtype, skip_long_test=skip_long_test,
            time_kwargs_fact=time_kwargs_fact, time_limit=time_limit,
            number=number, repeat=repeat, n_jobs=n_jobs)

    from ..onnxrt.validate import enumerate_validated_operator_opsets  # pylint: disable=E0402

    # Normalize comma-separated strings into lists (None means "all models").
    if not isinstance(models, list):
        models = (None if models in (None, "")
                  else models.strip().split(','))
    if not isinstance(skip_models, list):
        skip_models = ({} if skip_models in (None, "")
                       else skip_models.strip().split(','))
    if verbose <= 1:
        # skl2onnx is noisy below full verbosity.
        logger = getLogger('skl2onnx')
        logger.disabled = True
    if not dump_folder:
        dump_folder = None
    if dump_folder and not os.path.exists(dump_folder):
        os.mkdir(dump_folder)  # pragma: no cover
    if dump_folder and not os.path.exists(dump_folder):
        raise FileNotFoundError(  # pragma: no cover
            "Cannot find dump_folder '{0}'.".format(
                dump_folder))

    # handling parameters coming from the command line as strings
    if opset_max == "":
        opset_max = None  # pragma: no cover
    if isinstance(opset_min, str):
        opset_min = int(opset_min)  # pragma: no cover
    if isinstance(opset_max, str):
        opset_max = int(opset_max)
    if isinstance(verbose, str):
        verbose = int(verbose)  # pragma: no cover
    if isinstance(extended_list, str):
        extended_list = extended_list in (
            '1', 'True', 'true')  # pragma: no cover
    if time_kwargs in (None, ''):
        time_kwargs = None
    if isinstance(time_kwargs, str):
        time_kwargs = json.loads(time_kwargs)
        # json only allows string as keys
        time_kwargs = {int(k): v for k, v in time_kwargs.items()}
    if isinstance(n_jobs, str):
        n_jobs = int(n_jobs)
        if n_jobs == 0:
            n_jobs = None
    if time_kwargs is not None and not isinstance(time_kwargs, dict):
        raise ValueError(  # pragma: no cover
            "time_kwargs must be a dictionary not {}\n{}".format(
                type(time_kwargs), time_kwargs))
    if not isinstance(n_features, list):
        if n_features in (None, ""):
            n_features = None
        elif ',' in n_features:
            n_features = list(map(int, n_features.split(',')))
        else:
            n_features = int(n_features)
    if not isinstance(runtime, list) and ',' in runtime:
        runtime = runtime.split(',')

    def fct_filter_exp(m, s):
        # Keeps a model unless it was explicitly skipped.
        return str(m) not in skip_models

    # Restrict problems to one float precision if requested.
    if dtype in ('', None):
        fct_filter = fct_filter_exp
    elif dtype == '32':
        def fct_filter_exp2(m, p):
            return fct_filter_exp(m, p) and '64' not in p
        fct_filter = fct_filter_exp2
    elif dtype == '64':  # pragma: no cover
        def fct_filter_exp3(m, p):
            return fct_filter_exp(m, p) and '64' in p
        fct_filter = fct_filter_exp3
    else:
        raise ValueError(  # pragma: no cover
            "dtype must be empty, 32, 64 not '{}'.".format(dtype))

    # time_kwargs

    if benchmark:
        if time_kwargs is None:
            from ..onnxrt.validate.validate_helper import default_time_kwargs  # pylint: disable=E0402
            time_kwargs = default_time_kwargs()
        # *number* and *repeat* act as global multipliers.
        for _, v in time_kwargs.items():
            v['number'] *= number
            v['repeat'] *= repeat
        if verbose > 0:
            fLOG("time_kwargs=%r" % time_kwargs)

    # body

    def build_rows(models_):
        # Runs the full validation and materializes the generator.
        rows = list(enumerate_validated_operator_opsets(
            verbose, models=models_, fLOG=fLOG, runtime=runtime, debug=debug,
            dump_folder=dump_folder, opset_min=opset_min, opset_max=opset_max,
            benchmark=benchmark, assume_finite=assume_finite, versions=versions,
            extended_list=extended_list, time_kwargs=time_kwargs, dump_all=dump_all,
            n_features=n_features, filter_exp=fct_filter,
            skip_long_test=skip_long_test, time_limit=time_limit,
            time_kwargs_fact=time_kwargs_fact, n_jobs=n_jobs))
        return rows

    def catch_build_rows(models_):
        # Optionally silences the most common training warnings.
        if catch_warnings:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore",
                                      (UserWarning, ConvergenceWarning,
                                       RuntimeWarning, FutureWarning))
                rows = build_rows(models_)
        else:
            rows = build_rows(models_)  # pragma: no cover
        return rows

    rows = catch_build_rows(models)
    res = _finalize(rows, out_raw, out_summary,
                    verbose, models, out_graph, fLOG)
    return res if (force_return or verbose >= 2) else None
def _finalize(rows, out_raw, out_summary, verbose, models, out_graph, fLOG):
    """
    Turns raw validation rows into raw and summary reports.

    Drops non-serializable entries, saves the raw table and the
    aggregated summary (excel or csv depending on the extension),
    optionally draws a benchmark graph, and returns *rows* unchanged.
    """
    from ..onnxrt.validate import summary_report  # pylint: disable=E0402
    from ..tools.cleaning import clean_error_msg  # pylint: disable=E0402

    # Drops data which cannot be serialized (callables stored as lambdas).
    for row in rows:
        for key in [k for k in row if 'lambda' in k]:
            del row[key]

    df = DataFrame(rows)

    def _export(frame, filename, label):
        # Saves *frame* as excel or cleaned csv depending on the extension.
        if verbose > 0:
            fLOG("Saving {} into '{}'.".format(label, filename))
        if os.path.splitext(filename)[-1] == ".xlsx":
            frame.to_excel(filename, index=False)
        else:
            clean_error_msg(frame).to_csv(filename, index=False)

    if out_raw:
        _export(df, out_raw, "raw_data")

    if df.shape[0] == 0:
        raise RuntimeError("No result produced by the benchmark.")
    piv = summary_report(df)
    if 'optim' not in piv:
        raise RuntimeError(  # pragma: no cover
            "Unable to produce a summary. Missing column in \n{}".format(
                piv.columns))

    if out_summary:
        _export(piv, out_summary, "summary")

    if verbose > 1 and models is not None:
        fLOG(piv.T)
    if out_graph is not None:
        if verbose > 0:
            fLOG("Saving graph into '{}'.".format(out_graph))
        from ..plotting.plotting import plot_validate_benchmark
        fig = plot_validate_benchmark(piv)[0]
        fig.savefig(out_graph)

    return rows
def _validate_runtime_dict(kwargs):
    """Unpacks *kwargs* into :func:`validate_runtime`; used as the
    picklable entry point for the subprocess runner."""
    return validate_runtime(**kwargs)
def _validate_runtime_separate_process(**kwargs):
    """
    Runs :func:`validate_runtime` once per model, each inside its own
    subprocess, so that one crashing model does not abort the whole batch.
    Crashes and timeouts are recorded as ``CRASH`` rows instead.
    """
    model_names = kwargs['models']
    if model_names in (None, ""):
        from ..onnxrt.validate.validate_helper import sklearn_operators  # pragma: no cover
        model_names = [_['name']
                       for _ in sklearn_operators(extended=True)]  # pragma: no cover
    elif not isinstance(model_names, list):
        model_names = model_names.strip().split(',')

    skipped = kwargs['skip_models']
    skipped = {} if skipped in (None, "") else skipped.strip().split(',')

    verbose = kwargs['verbose']
    fLOG = kwargs['fLOG']
    collected = []
    selected = sorted(m for m in model_names if m not in skipped)

    if verbose > 0:
        from tqdm import tqdm
        loop = tqdm(selected)
    else:
        loop = selected  # pragma: no cover

    for name in loop:
        if not isinstance(loop, list):
            loop.set_description("[%s]" % (name + " " * (25 - len(name))))

        def _suffixed(path):
            # Inserts the model name before the file extension.
            base, ext = os.path.splitext(path)
            return "".join([base, "_", name, ext])

        raw_name = _suffixed(kwargs['out_raw']) if kwargs['out_raw'] else None
        sum_name = (_suffixed(kwargs['out_summary'])
                    if kwargs['out_summary'] else None)

        run_kwargs = dict(kwargs)
        run_kwargs.pop('fLOG', None)
        run_kwargs.update(out_raw=raw_name, out_summary=sum_name,
                          models=name,
                          verbose=0,  # tqdm fails
                          out_graph=None)

        with Pool(1) as proc:
            try:
                task = proc.apply_async(_validate_runtime_dict, [run_kwargs])
                collected.extend(task.get(timeout=150))  # timeout fixed to 150s
            except Exception as exc:  # pylint: disable=W0703
                collected.append({  # pragma: no cover
                    'name': name, 'scenario': 'CRASH',
                    'ERROR-msg': str(exc).replace("\n", " -- ")
                })

    return _finalize(collected, kwargs['out_raw'], kwargs['out_summary'],
                     verbose, model_names, kwargs.get('out_graph', None), fLOG)