Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Command line about validation of prediction runtime. 

4""" 

5import os 

6from logging import getLogger 

7import warnings 

8import json 

9from multiprocessing import Pool 

10from pandas import DataFrame 

11from sklearn.exceptions import ConvergenceWarning 

12 

13 

14def validate_runtime(verbose=1, opset_min=-1, opset_max="", 

15 check_runtime=True, runtime='python', debug=False, 

16 models=None, out_raw="model_onnx_raw.xlsx", 

17 out_summary="model_onnx_summary.xlsx", 

18 dump_folder=None, dump_all=False, benchmark=False, 

19 catch_warnings=True, assume_finite=True, 

20 versions=False, skip_models=None, 

21 extended_list=True, separate_process=False, 

22 time_kwargs=None, n_features=None, fLOG=print, 

23 out_graph=None, force_return=False, 

24 dtype=None, skip_long_test=False, 

25 number=1, repeat=1, time_kwargs_fact='lin', 

26 time_limit=4, n_jobs=0): 

27 """ 

28 Walks through most of :epkg:`scikit-learn` operators 

29 or model or predictor or transformer, tries to convert 

30 them into :epkg:`ONNX` and computes the predictions 

31 with a specific runtime. 

32 

33 :param verbose: integer from 0 (None) to 2 (full verbose) 

34 :param opset_min: tries every conversion from this minimum opset, 

35 -1 to get the current opset 

36 :param opset_max: tries every conversion up to maximum opset, 

37 -1 to get the current opset 

38 :param check_runtime: to check the runtime 

39 and not only the conversion 

40 :param runtime: runtime to check, python, 

41 onnxruntime1 to check :epkg:`onnxruntime`, 

42 onnxruntime2 to check every *ONNX* node independently 

43 with onnxruntime, many runtime can be checked at the same time 

44 if the value is a comma separated list 

45 :param models: comma separated list of models to test or empty 

46 string to test them all 

47 :param skip_models: models to skip 

48 :param debug: stops whenever an exception is raised, 

49 only if *separate_process* is False 

50 :param out_raw: output raw results into this file (excel format) 

51 :param out_summary: output an aggregated view into this file (excel format) 

52 :param dump_folder: folder where to dump information (pickle) 

53 in case of mismatch 

54 :param dump_all: dumps all models, not only the failing ones 

55 :param benchmark: run benchmark 

56 :param catch_warnings: catch warnings 

57 :param assume_finite: See `config_context 

58 <https://scikit-learn.org/stable/modules/generated/sklearn.config_context.html>`_, 

59 If True, validation for finiteness will be skipped, saving time, but leading 

60 to potential crashes. If False, validation for finiteness will be performed, 

61 avoiding error. 

62 :param versions: add columns with versions of used packages, 

63 :epkg:`numpy`, :epkg:`scikit-learn`, :epkg:`onnx`, :epkg:`onnxruntime`, 

64 :epkg:`sklearn-onnx` 

65 :param extended_list: extends the list of :epkg:`scikit-learn` converters 

66 with converters implemented in this module 

67 :param separate_process: run every model in a separate process, 

68 this option must be used to run all model in one row 

69 even if one of them is crashing 

70 :param time_kwargs: a dictionary which defines the number of rows and 

71 the parameter *number* and *repeat* when benchmarking a model, 

72 the value must follow :epkg:`json` format 

73 :param n_features: change the default number of features for 

74 a specific problem, it can also be a comma separated list 

75 :param force_return: forces the function to return the results, 

76 used when the results are produces through a separate process 

77 :param out_graph: image name, to output a graph which summarizes 

78 a benchmark in case it was run 

79 :param dtype: '32' or '64' or None for both, 

80 limits the test to one specific number types 

81 :param skip_long_test: skips tests for high values of N if 

82 they seem too long 

83 :param number: to multiply number values in *time_kwargs* 

84 :param repeat: to multiply repeat values in *time_kwargs* 

85 :param time_kwargs_fact: to multiply number and repeat in 

86 *time_kwargs* depending on the model 

87 (see :func:`_multiply_time_kwargs <mlprodict.onnxrt.validate.validate_helper._multiply_time_kwargs>`) 

88 :param time_limit: to stop benchmarking after this limit of time 

89 :param n_jobs: force the number of jobs to have this value, 

90 by default, it is equal to the number of CPU 

91 :param fLOG: logging function 

92 

93 .. cmdref:: 

94 :title: Validate a runtime against scikit-learn 

95 :cmd: -m mlprodict validate_runtime --help 

96 :lid: l-cmd-validate_runtime 

97 

98 The command walks through all scikit-learn operators, 

99 tries to convert them, checks the predictions, 

100 and produces a report. 

101 

102 Example:: 

103 

104 python -m mlprodict validate_runtime --models LogisticRegression,LinearRegression 

105 

106 Following example benchmarks models 

107 :epkg:`sklearn:ensemble:RandomForestRegressor`, 

108 :epkg:`sklearn:tree:DecisionTreeRegressor`, it compares 

109 :epkg:`onnxruntime` against :epkg:`scikit-learn` for opset 10. 

110 

111 :: 

112 

113 python -m mlprodict validate_runtime -v 1 -o 10 -op 10 -c 1 -r onnxruntime1 

114 -m RandomForestRegressor,DecisionTreeRegressor -out bench_onnxruntime.xlsx -b 1 

115 

116 Parameter ``--time_kwargs`` may be used to reduce or increase 

117 bencharmak precisions. The following value tells the function 

118 to run a benchmarks with datasets of 1 or 10 number, to repeat 

119 a given number of time *number* predictions in one row. 

120 The total time is divided by :math:`number \\times repeat``. 

121 Parameter ``--time_kwargs_fact`` may be used to increase these 

122 number for some specific models. ``'lin'`` multiplies 

123 by 10 number when the model is linear. 

124 

125 :: 

126 

127 -t "{\\"1\\":{\\"number\\":10,\\"repeat\\":10},\\"10\\":{\\"number\\":5,\\"repeat\\":5}}" 

128 

129 The following example dumps every model in the list: 

130 

131 :: 

132 

133 python -m mlprodict validate_runtime --out_raw raw.csv --out_summary sum.csv 

134 --models LinearRegression,LogisticRegression,DecisionTreeRegressor,DecisionTreeClassifier 

135 -r python,onnxruntime1 -o 10 -op 10 -v 1 -b 1 -dum 1 

136 -du model_dump -n 20,100,500 --out_graph benchmark.png --dtype 32 

137 

138 The command line generates a graph produced by function 

139 :func:`plot_validate_benchmark 

140 <mlprodict.onnxrt.validate.validate_graph.plot_validate_benchmark>`. 

141 """ 

142 if separate_process: 

143 return _validate_runtime_separate_process( 

144 verbose=verbose, opset_min=opset_min, opset_max=opset_max, 

145 check_runtime=check_runtime, runtime=runtime, debug=debug, 

146 models=models, out_raw=out_raw, 

147 out_summary=out_summary, dump_all=dump_all, 

148 dump_folder=dump_folder, benchmark=benchmark, 

149 catch_warnings=catch_warnings, assume_finite=assume_finite, 

150 versions=versions, skip_models=skip_models, 

151 extended_list=extended_list, time_kwargs=time_kwargs, 

152 n_features=n_features, fLOG=fLOG, force_return=True, 

153 out_graph=None, dtype=dtype, skip_long_test=skip_long_test, 

154 time_kwargs_fact=time_kwargs_fact, time_limit=time_limit, 

155 n_jobs=n_jobs) 

156 

157 from ..onnxrt.validate import enumerate_validated_operator_opsets # pylint: disable=E0402 

158 

159 if not isinstance(models, list): 

160 models = (None if models in (None, "") 

161 else models.strip().split(',')) 

162 if not isinstance(skip_models, list): 

163 skip_models = ({} if skip_models in (None, "") 

164 else skip_models.strip().split(',')) 

165 if verbose <= 1: 

166 logger = getLogger('skl2onnx') 

167 logger.disabled = True 

168 if not dump_folder: 

169 dump_folder = None 

170 if dump_folder and not os.path.exists(dump_folder): 

171 os.mkdir(dump_folder) # pragma: no cover 

172 if dump_folder and not os.path.exists(dump_folder): 

173 raise FileNotFoundError( # pragma: no cover 

174 "Cannot find dump_folder '{0}'.".format( 

175 dump_folder)) 

176 

177 # handling parameters 

178 if opset_max == "": 

179 opset_max = None # pragma: no cover 

180 if isinstance(opset_min, str): 

181 opset_min = int(opset_min) # pragma: no cover 

182 if isinstance(opset_max, str): 

183 opset_max = int(opset_max) 

184 if isinstance(verbose, str): 

185 verbose = int(verbose) # pragma: no cover 

186 if isinstance(extended_list, str): 

187 extended_list = extended_list in ( 

188 '1', 'True', 'true') # pragma: no cover 

189 if time_kwargs in (None, ''): 

190 time_kwargs = None 

191 if isinstance(time_kwargs, str): 

192 time_kwargs = json.loads(time_kwargs) 

193 # json only allows string as keys 

194 time_kwargs = {int(k): v for k, v in time_kwargs.items()} 

195 if isinstance(n_jobs, str): 

196 n_jobs = int(n_jobs) 

197 if n_jobs == 0: 

198 n_jobs = None 

199 if time_kwargs is not None and not isinstance(time_kwargs, dict): 

200 raise ValueError( # pragma: no cover 

201 "time_kwargs must be a dictionary not {}\n{}".format( 

202 type(time_kwargs), time_kwargs)) 

203 if not isinstance(n_features, list): 

204 if n_features in (None, ""): 

205 n_features = None 

206 elif ',' in n_features: 

207 n_features = list(map(int, n_features.split(','))) 

208 else: 

209 n_features = int(n_features) 

210 if not isinstance(runtime, list) and ',' in runtime: 

211 runtime = runtime.split(',') 

212 

213 def fct_filter_exp(m, s): 

214 return str(m) not in skip_models 

215 

216 if dtype in ('', None): 

217 fct_filter = fct_filter_exp 

218 elif dtype == '32': 

219 def fct_filter_exp2(m, p): 

220 return fct_filter_exp(m, p) and '64' not in p 

221 fct_filter = fct_filter_exp2 

222 elif dtype == '64': # pragma: no cover 

223 def fct_filter_exp3(m, p): 

224 return fct_filter_exp(m, p) and '64' in p 

225 fct_filter = fct_filter_exp3 

226 else: 

227 raise ValueError( # pragma: no cover 

228 "dtype must be empty, 32, 64 not '{}'.".format(dtype)) 

229 

230 # time_kwargs 

231 

232 if benchmark: 

233 if time_kwargs is None: 

234 from ..onnxrt.validate.validate_helper import default_time_kwargs # pylint: disable=E0402 

235 time_kwargs = default_time_kwargs() 

236 for _, v in time_kwargs.items(): 

237 v['number'] *= number 

238 v['repeat'] *= repeat 

239 if verbose > 0: 

240 fLOG("time_kwargs=%r" % time_kwargs) 

241 

242 # body 

243 

244 def build_rows(models_): 

245 rows = list(enumerate_validated_operator_opsets( 

246 verbose, models=models_, fLOG=fLOG, runtime=runtime, debug=debug, 

247 dump_folder=dump_folder, opset_min=opset_min, opset_max=opset_max, 

248 benchmark=benchmark, assume_finite=assume_finite, versions=versions, 

249 extended_list=extended_list, time_kwargs=time_kwargs, dump_all=dump_all, 

250 n_features=n_features, filter_exp=fct_filter, 

251 skip_long_test=skip_long_test, time_limit=time_limit, 

252 time_kwargs_fact=time_kwargs_fact, n_jobs=n_jobs)) 

253 return rows 

254 

255 def catch_build_rows(models_): 

256 if catch_warnings: 

257 with warnings.catch_warnings(): 

258 warnings.simplefilter("ignore", 

259 (UserWarning, ConvergenceWarning, 

260 RuntimeWarning, FutureWarning)) 

261 rows = build_rows(models_) 

262 else: 

263 rows = build_rows(models_) # pragma: no cover 

264 return rows 

265 

266 rows = catch_build_rows(models) 

267 res = _finalize(rows, out_raw, out_summary, 

268 verbose, models, out_graph, fLOG) 

269 return res if (force_return or verbose >= 2) else None 

270 

271 

272def _finalize(rows, out_raw, out_summary, verbose, models, out_graph, fLOG): 

273 from ..onnxrt.validate import summary_report # pylint: disable=E0402 

274 from ..tools.cleaning import clean_error_msg # pylint: disable=E0402 

275 

276 # Drops data which cannot be serialized. 

277 for row in rows: 

278 keys = [] 

279 for k in row: 

280 if 'lambda' in k: 

281 keys.append(k) 

282 for k in keys: 

283 del row[k] 

284 

285 df = DataFrame(rows) 

286 

287 if out_raw: 

288 if verbose > 0: 

289 fLOG("Saving raw_data into '{}'.".format(out_raw)) 

290 if os.path.splitext(out_raw)[-1] == ".xlsx": 

291 df.to_excel(out_raw, index=False) 

292 else: 

293 clean_error_msg(df).to_csv(out_raw, index=False) 

294 

295 if df.shape[0] == 0: 

296 raise RuntimeError("No result produced by the benchmark.") 

297 piv = summary_report(df) 

298 if 'optim' not in piv: 

299 raise RuntimeError( # pragma: no cover 

300 "Unable to produce a summary. Missing column in \n{}".format( 

301 piv.columns)) 

302 

303 if out_summary: 

304 if verbose > 0: 

305 fLOG("Saving summary into '{}'.".format(out_summary)) 

306 if os.path.splitext(out_summary)[-1] == ".xlsx": 

307 piv.to_excel(out_summary, index=False) 

308 else: 

309 clean_error_msg(piv).to_csv(out_summary, index=False) 

310 

311 if verbose > 1 and models is not None: 

312 fLOG(piv.T) 

313 if out_graph is not None: 

314 if verbose > 0: 

315 fLOG("Saving graph into '{}'.".format(out_graph)) 

316 from ..plotting.plotting import plot_validate_benchmark 

317 fig = plot_validate_benchmark(piv)[0] 

318 fig.savefig(out_graph) 

319 

320 return rows 

321 

322 

323def _validate_runtime_dict(kwargs): 

324 return validate_runtime(**kwargs) 

325 

326 

327def _validate_runtime_separate_process(**kwargs): 

328 models = kwargs['models'] 

329 if models in (None, ""): 

330 from ..onnxrt.validate.validate_helper import sklearn_operators # pragma: no cover 

331 models = [_['name'] 

332 for _ in sklearn_operators(extended=True)] # pragma: no cover 

333 elif not isinstance(models, list): 

334 models = models.strip().split(',') 

335 

336 skip_models = kwargs['skip_models'] 

337 skip_models = {} if skip_models in ( 

338 None, "") else skip_models.strip().split(',') 

339 

340 verbose = kwargs['verbose'] 

341 fLOG = kwargs['fLOG'] 

342 all_rows = [] 

343 skls = [m for m in models if m not in skip_models] 

344 skls.sort() 

345 

346 if verbose > 0: 

347 from tqdm import tqdm 

348 pbar = tqdm(skls) 

349 else: 

350 pbar = skls # pragma: no cover 

351 

352 for op in pbar: 

353 if not isinstance(pbar, list): 

354 pbar.set_description("[%s]" % (op + " " * (25 - len(op)))) 

355 

356 if kwargs['out_raw']: 

357 out_raw = os.path.splitext(kwargs['out_raw']) 

358 out_raw = "".join([out_raw[0], "_", op, out_raw[1]]) 

359 else: 

360 out_raw = None # pragma: no cover 

361 

362 if kwargs['out_summary']: 

363 out_summary = os.path.splitext(kwargs['out_summary']) 

364 out_summary = "".join([out_summary[0], "_", op, out_summary[1]]) 

365 else: 

366 out_summary = None # pragma: no cover 

367 

368 new_kwargs = kwargs.copy() 

369 if 'fLOG' in new_kwargs: 

370 del new_kwargs['fLOG'] 

371 new_kwargs['out_raw'] = out_raw 

372 new_kwargs['out_summary'] = out_summary 

373 new_kwargs['models'] = op 

374 new_kwargs['verbose'] = 0 # tqdm fails 

375 new_kwargs['out_graph'] = None 

376 

377 with Pool(1) as p: 

378 try: 

379 result = p.apply_async(_validate_runtime_dict, [new_kwargs]) 

380 lrows = result.get(timeout=150) # timeout fixed to 150s 

381 all_rows.extend(lrows) 

382 except Exception as e: # pylint: disable=W0703 

383 all_rows.append({ # pragma: no cover 

384 'name': op, 'scenario': 'CRASH', 

385 'ERROR-msg': str(e).replace("\n", " -- ") 

386 }) 

387 

388 return _finalize(all_rows, kwargs['out_raw'], kwargs['out_summary'], 

389 verbose, models, kwargs.get('out_graph', None), fLOG)