Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Validates runtime for many :scikit-learn: operators. 

4The submodule relies on :epkg:`onnxconverter_common`, 

5:epkg:`sklearn-onnx`. 

6""" 

7import pprint 

8from inspect import signature 

9import numpy 

10from numpy.linalg import LinAlgError 

11import sklearn 

12from sklearn import __all__ as sklearn__all__, __version__ as sklearn_version 

13from sklearn.exceptions import ConvergenceWarning 

14from sklearn.utils._testing import ignore_warnings 

15from ... import __version__ as ort_version 

16from ...onnx_conv import to_onnx, register_converters, register_rewritten_operators 

17from ...tools.ort_wrapper import onnxrt_version 

18from ...tools.model_info import analyze_model, set_random_state 

19from ...tools.asv_options_helper import ( 

20 get_opset_number_from_onnx, get_ir_version_from_onnx) 

21from ..onnx_inference import OnnxInference 

22from ...onnx_tools.optim.sklearn_helper import inspect_sklearn_model, set_n_jobs 

23from ...onnx_tools.optim.onnx_helper import onnx_statistics 

24from ...onnx_tools.optim import onnx_optimisations 

25from .validate_problems import find_suitable_problem 

26from .validate_scenarios import _extra_parameters 

27from .validate_difference import measure_relative_difference 

28from .validate_helper import ( 

29 _dispsimple, sklearn_operators, 

30 _measure_time, _shape_exc, dump_into_folder, 

31 default_time_kwargs, RuntimeBadResultsError, 

32 _dictionary2str, _merge_options, _multiply_time_kwargs, 

33 _get_problem_data) 

34from .validate_benchmark import benchmark_fct 

35 

36 

@ignore_warnings(category=(UserWarning, ConvergenceWarning))
def _dofit_model(dofit, obs, inst, X_train, y_train, X_test, y_test,
                 Xort_test, init_types, store_models,
                 debug, verbose, fLOG):
    """
    Fits model *inst* when *dofit* is True and records training
    statistics into dictionary *obs*.

    Returns True when the training succeeded (or was skipped),
    False when fitting raised a handled exception; in that case
    the message is stored under ``obs['_1training_time_exc']``.
    """
    if not dofit:
        # Training skipped on purpose: record a null duration.
        obs["training_time"] = 0.
        if store_models:
            obs['MODEL'] = inst
            obs['init_types'] = init_types
        return True

    if verbose >= 2 and fLOG is not None:
        fLOG("[enumerate_compatible_opset] fit, type: '{}' dtype: {}".format(
            type(X_train), getattr(X_train, 'dtype', '-')))
    try:
        set_random_state(inst)
        # Unsupervised models are fitted without labels.
        if y_train is None:
            fit_call = lambda: inst.fit(X_train)
        else:
            fit_call = lambda: inst.fit(X_train, y_train)
        duration = _measure_time(fit_call)[1]
    except (AttributeError, TypeError, ValueError,
            IndexError, NotImplementedError, MemoryError,
            LinAlgError, StopIteration) as e:
        if debug:
            raise  # pragma: no cover
        obs["_1training_time_exc"] = str(e)
        return False

    obs["training_time"] = duration
    # Statistics on the fitted estimator; not every model is supported.
    try:
        fitted_stats = inspect_sklearn_model(inst)
    except NotImplementedError:
        fitted_stats = {}
    obs.update({'skl_' + key: val for key, val in fitted_stats.items()})

    if store_models:
        obs['MODEL'] = inst
        obs['X_test'] = X_test
        obs['Xort_test'] = Xort_test
        obs['init_types'] = init_types

    return True

79 

80 

def _run_skl_prediction(obs, check_runtime, assume_finite, inst,
                        method_name, predict_kwargs, X_test,
                        benchmark, debug, verbose, time_kwargs,
                        skip_long_test, time_kwargs_fact, fLOG):
    """
    Computes the :epkg:`scikit-learn` prediction for model *inst*.

    Timings and intermediate values are stored in dictionary *obs*.
    Returns the prediction on success, the raised exception when the
    prediction failed (also recorded in *obs*), or *None* when
    *check_runtime* is False.
    """
    if not check_runtime:
        return None  # pragma: no cover
    if verbose >= 2 and fLOG is not None:
        fLOG("[enumerate_compatible_opset] check_runtime SKL {}-{}-{}-{}-{}".format(
            id(inst), method_name, predict_kwargs, time_kwargs,
            time_kwargs_fact))
    with sklearn.config_context(assume_finite=assume_finite):
        # compute sklearn prediction
        obs['ort_version'] = ort_version
        try:
            meth = getattr(inst, method_name)
        except AttributeError as e:
            if debug:
                raise  # pragma: no cover
            # The model does not implement the requested method.
            obs['_2skl_meth_exc'] = str(e)
            return e
        try:
            ypred, t4, ___ = _measure_time(
                lambda: meth(X_test, **predict_kwargs))
            # Keep a (callable, data) pair so the benchmark can replay
            # the same prediction later.
            obs['lambda-skl'] = (lambda xo: meth(xo, **predict_kwargs), X_test)
        except (ValueError, AttributeError, TypeError, MemoryError, IndexError) as e:
            if debug:
                raise  # pragma: no cover
            obs['_3prediction_exc'] = str(e)
            return e
        obs['prediction_time'] = t4
        obs['assume_finite'] = assume_finite
        if benchmark and 'lambda-skl' in obs:
            obs['bench-skl'] = benchmark_fct(
                *obs['lambda-skl'], obs=obs,
                time_kwargs=_multiply_time_kwargs(
                    time_kwargs, time_kwargs_fact, inst),
                skip_long_test=skip_long_test)
        if verbose >= 3 and fLOG is not None:
            fLOG("[enumerate_compatible_opset] scikit-learn prediction")
            _dispsimple(ypred, fLOG)
        if verbose >= 2 and fLOG is not None:
            fLOG("[enumerate_compatible_opset] predictions stored")
    return ypred

124 

125 

def _retrieve_problems_extra(model, verbose, fLOG, extended_list):
    """
    Use by @see fn enumerate_compatible_opset.

    Returns a pair *(problems, extras)* for class *model*: the list of
    problems the model can be validated on and the associated scenarios.
    When no suitable problem exists, the first element is a dictionary
    describing the failure instead of a list.
    """
    extras = None
    if extended_list:
        # Converters implemented by this module may support models the
        # standard scikit-learn problems do not cover.
        from ...onnx_conv.validate_scenarios import find_suitable_problem as fsp_extended
        problems = fsp_extended(model)
        if problems is not None:
            from ...onnx_conv.validate_scenarios import build_custom_scenarios as fsp_scenarios
            extra_parameters = fsp_scenarios()

            if verbose >= 2 and fLOG is not None:
                fLOG(
                    "[enumerate_compatible_opset] found custom for model={}".format(model))
                # NOTE(review): *extras* is only retrieved inside this
                # verbose>=2 branch — confirm this nesting is intentional.
                extras = extra_parameters.get(model, None)
                if extras is not None:
                    fLOG(
                        "[enumerate_compatible_opset] found custom scenarios={}".format(extras))
    else:
        problems = None

    if problems is None:
        # scikit-learn
        extra_parameters = _extra_parameters
        try:
            problems = find_suitable_problem(model)
        except RuntimeError as e:
            # The model cannot be validated: return the reason so the
            # caller can report it instead of raising.
            return {'name': model.__name__, 'skl_version': sklearn_version,
                    '_0problem_exc': e}, extras
        extras = extra_parameters.get(model, [('default', {})])

    # checks existence of random_state
    sig = signature(model.__init__)
    if 'random_state' in sig.parameters:
        # Force determinism: inject random_state=42 into every scenario
        # which does not already set it.
        new_extras = []
        for extra in extras:
            if 'random_state' not in extra[1]:
                ps = extra[1].copy()
                ps['random_state'] = 42
                if len(extra) == 2:
                    extra = (extra[0], ps)
                else:
                    extra = (extra[0], ps) + extra[2:]
            new_extras.append(extra)
        extras = new_extras

    return problems, extras

174 

175 

def enumerate_compatible_opset(model, opset_min=-1, opset_max=-1,  # pylint: disable=R0914
                               check_runtime=True, debug=False,
                               runtime='python', dump_folder=None,
                               store_models=False, benchmark=False,
                               assume_finite=True, node_time=False,
                               fLOG=print, filter_exp=None,
                               verbose=0, time_kwargs=None,
                               extended_list=False, dump_all=False,
                               n_features=None, skip_long_test=True,
                               filter_scenario=None, time_kwargs_fact=None,
                               time_limit=4, n_jobs=None):
    """
    Lists all compatible opsets for a specific model.

    @param model operator class
    @param opset_min starts with this opset
    @param opset_max ends with this opset (None to use
        current onnx opset)
    @param check_runtime checks that runtime can consume the
        model and compute predictions
    @param debug catch exception (True) or not (False)
    @param runtime test a specific runtime, by default ``'python'``
    @param dump_folder dump information to replicate in case of mismatch
    @param dump_all dump all models not only the one which fail
    @param store_models if True, the function
        also stores the fitted model and its conversion
        into :epkg:`ONNX`
    @param benchmark if True, measures the time taken by each function
        to predict for different number of rows
    @param fLOG logging function
    @param filter_exp function which tells if the experiment must be run,
        None to run all, takes *model, problem* as an input
    @param filter_scenario second function which tells if the experiment must be run,
        None to run all, takes *model, problem, scenario, extra, options*
        as an input
    @param node_time collect time for each node in the :epkg:`ONNX` graph
    @param assume_finite See `config_context
        <https://scikit-learn.org/stable/modules/generated/
        sklearn.config_context.html>`_, If True, validation for finiteness
        will be skipped, saving time, but leading to potential crashes.
        If False, validation for finiteness will be performed, avoiding error.
    @param verbose verbosity
    @param extended_list extends the list to custom converters
        and problems
    @param time_kwargs to define a more precise way to measure a model
    @param n_features modifies the shorts datasets used to train the models
        to use exactly this number of features, it can also
        be a list to test multiple datasets
    @param skip_long_test skips tests for high values of N if they seem too long
    @param time_kwargs_fact see :func:`_multiply_time_kwargs <mlprodict.onnxrt.validate.validate_helper._multiply_time_kwargs>`
    @param time_limit to stop benchmarking after this amount of time was spent
    @param n_jobs *n_jobs* is set to the number of CPU by default unless this
        value is changed
    @return dictionaries, each row has the following
        keys: opset, exception if any, conversion time,
        problem chosen to test the conversion...

    The function requires :epkg:`sklearn-onnx`.
    The outcome can be seen at pages references
    by :ref:`l-onnx-availability`.
    The parameter *time_kwargs* is a dictionary which defines the
    number of times to repeat the same predictions in order
    to give more precise figures. The default value (if None) is returned
    by the following code:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.validate.validate_helper import default_time_kwargs
        import pprint
        pprint.pprint(default_time_kwargs())

    Parameter *time_kwargs_fact* multiples these values for some
    specific models. ``'lin'`` multiplies by 10 when the model
    is linear.
    """
    # Resolve default opset bounds to the latest available opset.
    if opset_min == -1:
        opset_min = get_opset_number_from_onnx()  # pragma: no cover
    if opset_max == -1:
        opset_max = get_opset_number_from_onnx()  # pragma: no cover
    if verbose > 0 and fLOG is not None:
        fLOG("[enumerate_compatible_opset] opset in [{}, {}].".format(
            opset_min, opset_max))
    if verbose > 1 and fLOG:
        fLOG("[enumerate_compatible_opset] validate class '{}'.".format(
            model.__name__))
        if verbose > 2:
            fLOG(model)

    if time_kwargs is None:
        time_kwargs = default_time_kwargs()
    problems, extras = _retrieve_problems_extra(
        model, verbose, fLOG, extended_list)
    if isinstance(problems, dict):
        # No suitable problem found: yield the failure description.
        yield problems  # pragma: no cover
        problems = []  # pragma: no cover

    if opset_max is None:
        # None at the end of the list means "current opset".
        opset_max = get_opset_number_from_onnx()  # pragma: no cover
        opsets = list(range(opset_min, opset_max + 1))  # pragma: no cover
        opsets.append(None)  # pragma: no cover
    else:
        opsets = list(range(opset_min, opset_max + 1))

    if extras is None:
        # No scenario for this model: report it as skipped.
        problems = []
        yield {'name': model.__name__, 'skl_version': sklearn_version,
               '_0problem_exc': 'SKIPPED'}

    if not isinstance(n_features, list):
        n_features = [n_features]

    for prob in problems:
        if filter_exp is not None and not filter_exp(model, prob):
            continue
        for n_feature in n_features:
            if verbose >= 2 and fLOG is not None:
                fLOG("[enumerate_compatible_opset] problem={} n_feature={}".format(
                    prob, n_feature))

            # Datasets and conversion metadata associated to the problem.
            (X_train, X_test, y_train,
             y_test, Xort_test,
             init_types, conv_options, method_name,
             output_index, dofit, predict_kwargs) = _get_problem_data(prob, n_feature)

            for scenario_extra in extras:
                subset_problems = None
                optimisations = None
                new_conv_options = None
                # A third element in the scenario restricts the problems
                # and/or carries conversion/optimisation options.
                if len(scenario_extra) > 2:
                    options = scenario_extra[2]
                    if isinstance(options, dict):
                        subset_problems = options.get('subset_problems', None)
                        optimisations = options.get('optim', None)
                        new_conv_options = options.get('conv_options', None)
                    else:
                        subset_problems = options

                if subset_problems and isinstance(subset_problems, (list, set)):
                    if prob not in subset_problems:
                        # Skips unrelated problem for a specific configuration.
                        continue
                elif subset_problems is not None:
                    raise RuntimeError(  # pragma: no cover
                        "subset_problems must be a set or a list not {}.".format(
                            subset_problems))

                try:
                    scenario, extra = scenario_extra[:2]
                except TypeError as e:  # pragma: no cover
                    raise TypeError(
                        "Unable to interpret 'scenario_extra'\n{}".format(
                            scenario_extra)) from e
                if optimisations is None:
                    optimisations = [None]
                if new_conv_options is None:
                    new_conv_options = [{}]

                if (filter_scenario is not None and
                        not filter_scenario(model, prob, scenario,
                                            extra, new_conv_options)):
                    continue

                if verbose >= 2 and fLOG is not None:
                    fLOG("[enumerate_compatible_opset] ##############################")
                    fLOG("[enumerate_compatible_opset] scenario={} optim={} extra={} dofit={} (problem={})".format(
                        scenario, optimisations, extra, dofit, prob))

                # training
                obs = {'scenario': scenario, 'name': model.__name__,
                       'skl_version': sklearn_version, 'problem': prob,
                       'method_name': method_name, 'output_index': output_index,
                       'fit': dofit, 'conv_options': conv_options,
                       'idtype': Xort_test.dtype, 'predict_kwargs': predict_kwargs,
                       'init_types': init_types, 'inst': extra if extra else None,
                       'n_features': X_train.shape[1] if len(X_train.shape) == 2 else 1}
                inst = None
                extra = set_n_jobs(model, extra, n_jobs=n_jobs)
                try:
                    inst = model(**extra)
                except TypeError as e:  # pragma: no cover
                    if debug:  # pragma: no cover
                        raise
                    if "__init__() missing" not in str(e):
                        raise RuntimeError(
                            "Unable to instantiate model '{}'.\nextra=\n{}".format(
                                model.__name__, pprint.pformat(extra))) from e
                    # Missing mandatory argument: report and move on.
                    yield obs.copy()
                    continue

                if not _dofit_model(dofit, obs, inst, X_train, y_train, X_test, y_test,
                                    Xort_test, init_types, store_models,
                                    debug, verbose, fLOG):
                    yield obs.copy()
                    continue

                # statistics about the trained model
                skl_infos = analyze_model(inst)
                for k, v in skl_infos.items():
                    obs['fit_' + k] = v

                # runtime
                ypred = _run_skl_prediction(
                    obs, check_runtime, assume_finite, inst,
                    method_name, predict_kwargs, X_test,
                    benchmark, debug, verbose, time_kwargs,
                    skip_long_test, time_kwargs_fact, fLOG)
                if isinstance(ypred, Exception):
                    yield obs.copy()
                    continue

                # Conversion to ONNX and runtime checks for every opset.
                for run_obs in _call_conv_runtime_opset(
                        obs=obs.copy(), opsets=opsets, debug=debug,
                        new_conv_options=new_conv_options,
                        model=model, prob=prob, scenario=scenario,
                        extra=extra, extras=extras, conv_options=conv_options,
                        init_types=init_types, inst=inst,
                        optimisations=optimisations, verbose=verbose,
                        benchmark=benchmark, runtime=runtime,
                        filter_scenario=filter_scenario,
                        X_test=X_test, y_test=y_test, ypred=ypred,
                        Xort_test=Xort_test, method_name=method_name,
                        check_runtime=check_runtime,
                        output_index=output_index,
                        kwargs=dict(
                            dump_all=dump_all,
                            dump_folder=dump_folder,
                            node_time=node_time,
                            skip_long_test=skip_long_test,
                            store_models=store_models,
                            time_kwargs=_multiply_time_kwargs(
                                time_kwargs, time_kwargs_fact, inst)
                        ),
                        time_limit=time_limit,
                        fLOG=fLOG):
                    yield run_obs

413 

414 

415def _check_run_benchmark(benchmark, stat_onnx, bench_memo, runtime): 

416 unique = set(stat_onnx.items()) 

417 unique.add(runtime) 

418 run_benchmark = benchmark and all( 

419 map(lambda u: unique != u, bench_memo)) 

420 if run_benchmark: 

421 bench_memo.append(unique) 

422 return run_benchmark 

423 

424 

def _call_conv_runtime_opset(
        obs, opsets, debug, new_conv_options,
        model, prob, scenario, extra, extras, conv_options,
        init_types, inst, optimisations, verbose,
        benchmark, runtime, filter_scenario,
        check_runtime, X_test, y_test, ypred, Xort_test,
        method_name, output_index,
        kwargs, time_limit, fLOG):
    """
    Converts the fitted model *inst* into :epkg:`ONNX` for every opset,
    conversion option and optimisation, then checks the runtime on each
    converted model. Yields one observation dictionary per combination.
    """
    # Calls the conversion and runtime for different opets
    if None in opsets:
        # None (current opset) is tried first, then the highest opsets.
        set_opsets = [None] + list(sorted((_ for _ in opsets if _ is not None),
                                          reverse=True))
    else:
        set_opsets = list(sorted(opsets, reverse=True))
    bench_memo = []

    for opset in set_opsets:
        if verbose >= 2 and fLOG is not None:
            fLOG("[enumerate_compatible_opset] opset={} init_types={}".format(
                opset, init_types))
        obs_op = obs.copy()
        if opset is not None:
            obs_op['opset'] = opset

        if len(init_types) != 1:
            raise NotImplementedError(  # pragma: no cover
                "Multiple types are is not implemented: "
                "{}.".format(init_types))

        if not isinstance(runtime, list):
            runtime = [runtime]

        obs_op_0c = obs_op.copy()
        for aoptions in new_conv_options:
            obs_op = obs_op_0c.copy()
            all_conv_options = {} if conv_options is None else conv_options.copy()
            all_conv_options = _merge_options(
                all_conv_options, aoptions)
            obs_op['conv_options'] = all_conv_options

            if (filter_scenario is not None and
                    not filter_scenario(model, prob, scenario,
                                        extra, all_conv_options)):
                continue

            for rt in runtime:
                # Defaults bind the current loop values; rewrite_ops is
                # only enabled for the python runtimes.
                def fct_conv(itt=inst, it=init_types[0][1], ops=opset,
                             options=all_conv_options):
                    return to_onnx(itt, it, target_opset=ops, options=options,
                                   rewrite_ops=rt in ('', None, 'python',
                                                      'python_compiled'))

                if verbose >= 2 and fLOG is not None:
                    fLOG(
                        "[enumerate_compatible_opset] conversion to onnx: {}".format(all_conv_options))
                try:
                    conv, t4 = _measure_time(fct_conv)[:2]
                    obs_op["convert_time"] = t4
                except (RuntimeError, IndexError, AttributeError, TypeError,
                        ValueError, NameError, NotImplementedError) as e:
                    if debug:
                        fLOG(pprint.pformat(obs_op))  # pragma: no cover
                        raise  # pragma: no cover
                    obs_op["_4convert_exc"] = e
                    yield obs_op.copy()
                    continue

                if verbose >= 6 and fLOG is not None:
                    fLOG(  # pragma: no cover
                        "[enumerate_compatible_opset] ONNX:\n{}".format(conv))

                # Sanity check: when the 'cdist' optimisation is requested,
                # the graph must contain CDist nodes instead of Scan nodes.
                if all_conv_options.get('optim', '') == 'cdist':  # pragma: no cover
                    check_cdist = [_ for _ in str(conv).split('\n')
                                   if 'CDist' in _]
                    check_scan = [_ for _ in str(conv).split('\n')
                                  if 'Scan' in _]
                    if len(check_cdist) == 0 and len(check_scan) > 0:
                        raise RuntimeError(
                            "Operator CDist was not used in\n{}"
                            "".format(conv))

                obs_op0 = obs_op.copy()
                for optimisation in optimisations:
                    obs_op = obs_op0.copy()
                    if optimisation is not None:
                        if optimisation == 'onnx':
                            obs_op['optim'] = optimisation
                            if len(aoptions) != 0:
                                obs_op['optim'] += '/' + \
                                    _dictionary2str(aoptions)
                            conv = onnx_optimisations(conv)
                        else:
                            raise ValueError(  # pragma: no cover
                                "Unknown optimisation option '{}' (extra={})"
                                "".format(optimisation, extras))
                    else:
                        obs_op['optim'] = _dictionary2str(aoptions)

                    if verbose >= 3 and fLOG is not None:
                        fLOG("[enumerate_compatible_opset] optim='{}' optimisation={} all_conv_options={}".format(
                            obs_op['optim'], optimisation, all_conv_options))
                    if kwargs['store_models']:
                        obs_op['ONNX'] = conv
                        if verbose >= 2 and fLOG is not None:
                            fLOG(  # pragma: no cover
                                "[enumerate_compatible_opset] onnx nodes: {}".format(
                                    len(conv.graph.node)))
                    stat_onnx = onnx_statistics(conv)
                    obs_op.update(
                        {'onx_' + k: v for k, v in stat_onnx.items()})

                    # opset_domain
                    for op_imp in list(conv.opset_import):
                        obs_op['domain_opset_%s' %
                               op_imp.domain] = op_imp.version

                    # Benchmarks identical graphs only once per runtime.
                    run_benchmark = _check_run_benchmark(
                        benchmark, stat_onnx, bench_memo, rt)

                    # prediction
                    if check_runtime:
                        yield _call_runtime(obs_op=obs_op.copy(), conv=conv,
                                            opset=opset, debug=debug,
                                            runtime=rt, inst=inst,
                                            X_test=X_test, y_test=y_test,
                                            init_types=init_types,
                                            method_name=method_name,
                                            output_index=output_index,
                                            ypred=ypred, Xort_test=Xort_test,
                                            model=model,
                                            dump_folder=kwargs['dump_folder'],
                                            benchmark=run_benchmark,
                                            node_time=kwargs['node_time'],
                                            time_kwargs=kwargs['time_kwargs'],
                                            fLOG=fLOG, verbose=verbose,
                                            store_models=kwargs['store_models'],
                                            dump_all=kwargs['dump_all'],
                                            skip_long_test=kwargs['skip_long_test'],
                                            time_limit=time_limit)
                    else:
                        yield obs_op.copy()  # pragma: no cover

566 

567 

def _call_runtime(obs_op, conv, opset, debug, inst, runtime,
                  X_test, y_test, init_types, method_name, output_index,
                  ypred, Xort_test, model, dump_folder,
                  benchmark, node_time, fLOG,
                  verbose, store_models, time_kwargs,
                  dump_all, skip_long_test, time_limit):
    """
    Private. Serializes the converted model *conv*, loads it with the
    requested *runtime*, computes the predictions on *Xort_test* and
    compares them to the :epkg:`scikit-learn` predictions *ypred*.
    Results, timings and exceptions are recorded in *obs_op*, which
    is returned.
    """
    if 'onnxruntime' in runtime:
        # onnxruntime may not support the latest IR version, it is
        # temporarily downgraded during serialization.
        old = conv.ir_version
        conv.ir_version = get_ir_version_from_onnx()
    else:
        old = None

    ser, t5, ___ = _measure_time(lambda: conv.SerializeToString())
    obs_op['tostring_time'] = t5
    obs_op['runtime'] = runtime

    if old is not None:
        conv.ir_version = old

    # load
    if verbose >= 2 and fLOG is not None:
        fLOG("[enumerate_compatible_opset-R] load onnx")
    try:
        sess, t5, ___ = _measure_time(
            lambda: OnnxInference(ser, runtime=runtime))
        # NOTE(review): this overwrites the serialization timing stored
        # above with the loading time — a dedicated key such as
        # 'load_time' was probably intended; confirm before changing.
        obs_op['tostring_time'] = t5
    except (RuntimeError, ValueError, KeyError, IndexError, TypeError) as e:
        if debug:
            raise  # pragma: no cover
        obs_op['_5ort_load_exc'] = e
        return obs_op

    # compute batch
    if store_models:
        obs_op['OINF'] = sess
    if verbose >= 2 and fLOG is not None:
        fLOG("[enumerate_compatible_opset-R] compute batch with runtime "
             "'{}'".format(runtime))

    def fct_batch(se=sess, xo=Xort_test, it=init_types):  # pylint: disable=W0102
        return se.run({it[0][0]: xo},
                      verbose=max(verbose - 1, 1) if debug else 0, fLOG=fLOG)

    try:
        opred, t5, ___ = _measure_time(fct_batch)
        obs_op['ort_run_time_batch'] = t5
        # Stored for the benchmark which replays the same prediction.
        obs_op['lambda-batch'] = (lambda xo: sess.run(
            {init_types[0][0]: xo}, node_time=node_time), Xort_test)
    except (RuntimeError, TypeError, ValueError, KeyError, IndexError) as e:
        if debug:
            raise RuntimeError("Issue with {}.".format(
                obs_op)) from e  # pragma: no cover
        obs_op['_6ort_run_batch_exc'] = e
    if (benchmark or node_time) and 'lambda-batch' in obs_op:
        try:
            benres = benchmark_fct(*obs_op['lambda-batch'], obs=obs_op,
                                   node_time=node_time, time_kwargs=time_kwargs,
                                   skip_long_test=skip_long_test,
                                   time_limit=time_limit)
            obs_op['bench-batch'] = benres
        except (RuntimeError, TypeError, ValueError) as e:  # pragma: no cover
            if debug:
                raise e  # pragma: no cover
            obs_op['_6ort_run_batch_exc'] = e
            obs_op['_6ort_run_batch_bench_exc'] = e

    # difference
    debug_exc = []
    if verbose >= 2 and fLOG is not None:
        fLOG("[enumerate_compatible_opset-R] differences")
    if '_6ort_run_batch_exc' not in obs_op:
        if isinstance(opred, dict):
            # Normalizes the runtime output into a list of values.
            ch = [(k, v) for k, v in opred.items()]
            opred = [_[1] for _ in ch]

        if output_index != 'all':
            try:
                opred = opred[output_index]
            except IndexError as e:  # pragma: no cover
                if debug:
                    raise IndexError(
                        "Issue with output_index={}/{}".format(
                            output_index, len(opred))) from e
                obs_op['_8max_rel_diff_batch_exc'] = (
                    "Unable to fetch output {}/{} for model '{}'"
                    "".format(output_index, len(opred),
                              model.__name__))
                opred = None

        if opred is not None:
            if store_models:
                obs_op['skl_outputs'] = ypred
                obs_op['ort_outputs'] = opred
            if verbose >= 3 and fLOG is not None:
                fLOG("[_call_runtime] runtime prediction")
                _dispsimple(opred, fLOG)

            if (method_name == "decision_function" and hasattr(opred, 'shape') and
                    hasattr(ypred, 'shape') and len(opred.shape) == 2 and
                    opred.shape[1] == 2 and len(ypred.shape) == 1):
                # decision_function, for binary classification,
                # raw score is a distance
                max_rel_diff = measure_relative_difference(
                    ypred, opred[:, 1])
            else:
                max_rel_diff = measure_relative_difference(
                    ypred, opred)

            if max_rel_diff >= 1e9 and debug:  # pragma: no cover
                _shape = lambda o: o.shape if hasattr(
                    o, 'shape') else 'no shape'
                raise RuntimeError(
                    "Big difference (opset={}, runtime='{}' p='{}' s='{}')"
                    ":\n-------\n{}-{}\n{}\n--------\n{}-{}\n{}".format(
                        opset, runtime, obs_op['problem'], obs_op['scenario'],
                        type(ypred), _shape(ypred), ypred,
                        type(opred), _shape(opred), opred))

            if numpy.isnan(max_rel_diff):
                obs_op['_8max_rel_diff_batch_exc'] = (  # pragma: no cover
                    "Unable to compute differences between"
                    " {}-{}\n{}\n--------\n{}".format(
                        _shape_exc(
                            ypred), _shape_exc(opred),
                        ypred, opred))
                if debug:  # pragma: no cover
                    debug_exc.append(RuntimeError(
                        obs_op['_8max_rel_diff_batch_exc']))
            else:
                obs_op['max_rel_diff_batch'] = max_rel_diff
                if dump_folder and max_rel_diff > 1e-5:
                    # Large mismatch: dump everything needed to replicate.
                    dump_into_folder(dump_folder, kind='batch', obs_op=obs_op,
                                     X_test=X_test, y_test=y_test, Xort_test=Xort_test)
                if debug and max_rel_diff >= 0.1:  # pragma: no cover
                    raise RuntimeError("Two big differences {}\n{}\n{}\n{}".format(
                        max_rel_diff, inst, conv, pprint.pformat(obs_op)))

    if debug and len(debug_exc) == 2:
        raise debug_exc[0]  # pragma: no cover
    if debug and verbose >= 2:  # pragma: no cover
        if verbose >= 3:
            fLOG(pprint.pformat(obs_op))
        else:
            # The lambdas cannot be printed in a readable way.
            obs_op_log = {k: v for k,
                          v in obs_op.items() if 'lambda-' not in k}
            fLOG(pprint.pformat(obs_op_log))
    if verbose >= 2 and fLOG is not None:
        fLOG("[enumerate_compatible_opset-R] next...")
    if dump_all:
        dump = dump_into_folder(dump_folder, kind='batch', obs_op=obs_op,
                                X_test=X_test, y_test=y_test, Xort_test=Xort_test,
                                is_error=len(debug_exc) > 1,
                                onnx_bytes=conv.SerializeToString(),
                                skl_model=inst, ypred=ypred)
        obs_op['dumped'] = dump
    return obs_op

727 

728 

def _enumerate_validated_operator_opsets_ops(extended_list, models, skip_models):
    """
    Builds the list of scikit-learn operators to validate.

    @param extended_list includes operators this module implements
        a converter for
    @param models optional set of model names to restrict the list to
    @param skip_models optional set of model names to exclude
    @return list of dictionaries describing the operators
    @raise ValueError if *models* contains non-string entries or
        matches no known operator
    """
    ops = list(sklearn_operators(extended=extended_list))

    if models is not None:
        if not all(map(lambda m: isinstance(m, str), models)):
            raise ValueError(  # pragma: no cover
                "models must be a set of strings.")
        ops_ = [_ for _ in ops if _['name'] in models]
        # Bug fix: the original tested ``len(ops)`` (the unfiltered list),
        # so a wrong *models* set silently produced an empty selection
        # instead of raising.
        if len(ops_) == 0:
            raise ValueError(  # pragma: no cover
                "Parameter models is wrong: {}\n{}".format(
                    models, ops[0]))
        ops = ops_
    if skip_models is not None:
        ops = [m for m in ops if m['name'] not in skip_models]
    return ops

745 

746 

def _enumerate_validated_operator_opsets_version(runtime):
    """
    Collects the versions of the packages involved in the validation,
    returned as a dictionary with keys prefixed by ``v_``.
    """
    from numpy import __version__ as numpy_version
    from onnx import __version__ as onnx_version
    from scipy import __version__ as scipy_version
    from skl2onnx import __version__ as skl2onnx_version
    add_versions = {
        'v_numpy': numpy_version,
        'v_onnx': onnx_version,
        'v_scipy': scipy_version,
        'v_skl2onnx': skl2onnx_version,
        'v_sklearn': sklearn_version,
        'v_onnxruntime': ort_version,
    }
    if "onnxruntime" in runtime:
        # The onnxruntime wrapper knows the exact runtime version.
        add_versions['v_onnxruntime'] = onnxrt_version
    return add_versions

758 

759 

760def enumerate_validated_operator_opsets(verbose=0, opset_min=-1, opset_max=-1, 

761 check_runtime=True, debug=False, runtime='python', 

762 models=None, dump_folder=None, store_models=False, 

763 benchmark=False, skip_models=None, 

764 assume_finite=True, node_time=False, 

765 fLOG=print, filter_exp=None, 

766 versions=False, extended_list=False, 

767 time_kwargs=None, dump_all=False, 

768 n_features=None, skip_long_test=True, 

769 fail_bad_results=False, 

770 filter_scenario=None, 

771 time_kwargs_fact=None, 

772 time_limit=4, n_jobs=None): 

773 """ 

774 Tests all possible configurations for all possible 

775 operators and returns the results. 

776 

777 :param verbose: integer 0, 1, 2 

778 :param opset_min: checks conversion starting from the opset, -1 

779 to get the last one 

780 :param opset_max: checks conversion up to this opset, 

781 None means :func:`get_opset_number_from_onnx 

782 <mlprodict.tools.asv_options_helper.get_opset_number_from_onnx>` 

783 :param check_runtime: checks the python runtime 

784 :param models: only process a small list of operators, 

785 set of model names 

786 :param debug: stops whenever an exception 

787 is raised 

788 :param runtime: test a specific runtime, by default ``'python'`` 

789 :param dump_folder: dump information to replicate in case of mismatch 

790 :param dump_all: dump all models not only the one which fail 

791 :param store_models: if True, the function 

792 also stores the fitted model and its conversion 

793 into :epkg:`ONNX` 

794 :param benchmark: if True, measures the time taken by each function 

795 to predict for different number of rows 

796 :param filter_exp: function which tells if the experiment must be run, 

797 None to run all, takes *model, problem* as an input 

798 :param filter_scenario: second function which tells if the experiment must be run, 

799 None to run all, takes *model, problem, scenario, extra, options* 

800 as an input 

801 :param skip_models: models to skip 

802 :param assume_finite: See `config_context 

803 <https://scikit-learn.org/stable/modules/generated/ 

804 sklearn.config_context.html>`_, If True, validation for finiteness 

805 will be skipped, saving time, but leading to potential crashes. 

806 If False, validation for finiteness will be performed, avoiding error. 

807 :param node_time: measure time execution for every node in the graph 

808 :param versions: add columns with versions of used packages, 

809 :epkg:`numpy`, :epkg:`scikit-learn`, :epkg:`onnx`, 

810 :epkg:`onnxruntime`, :epkg:`sklearn-onnx` 

811 :param extended_list: also check models this module implements a converter for 

812 :param time_kwargs: to define a more precise way to measure a model 

813 :param n_features: modifies the shorts datasets used to train the models 

814 to use exactly this number of features, it can also 

815 be a list to test multiple datasets 

816 :param skip_long_test: skips tests for high values of N if they seem too long 

817 :param fail_bad_results: fails if the results are aligned with :epkg:`scikit-learn` 

818 :param time_kwargs_fact: see :func:`_multiply_time_kwargs 

819 <mlprodict.onnxrt.validate.validate_helper._multiply_time_kwargs>` 

820 :param time_limit: to skip the rest of the test after this limit (in second) 

821 :param n_jobs: *n_jobs* is set to the number of CPU by default unless this 

822 value is changed 

823 :param fLOG: logging function 

824 :return: list of dictionaries 

825 

826 The function is available through command line 

827 :ref:`validate_runtime <l-cmd-validate_runtime>`. 

828 The default for *time_kwargs* is the following: 

829 

830 .. runpython:: 

831 :showcode: 

832 :warningout: DeprecationWarning 

833 

834 from mlprodict.onnxrt.validate.validate_helper import default_time_kwargs 

835 import pprint 

836 pprint.pprint(default_time_kwargs()) 

837 """ 

838 register_converters() 

839 register_rewritten_operators() 

840 ops = _enumerate_validated_operator_opsets_ops( 

841 extended_list, models, skip_models) 

842 

843 if verbose > 0: 

844 

def iterate():
    """Yield every row of *ops*, logging progress through *fLOG*.

    Prints ``i/total - row`` for each element before yielding it,
    so a caller consuming this generator sees one log line per model.
    """
    total = len(ops)
    for position, item in enumerate(ops, start=1):  # pragma: no cover
        fLOG("{}/{} - {}".format(position, total, item))
        yield item

849 

850 if verbose >= 11: 

851 verbose -= 10 # pragma: no cover 

852 loop = iterate() # pragma: no cover 

853 else: 

854 try: 

855 from tqdm import trange 

856 

def iterate_tqdm():
    """Yield every row of *ops* while driving a :epkg:`tqdm` progress bar.

    The bar's description is set to the current model name, padded to a
    fixed width of 28 characters so the display stays aligned.
    """
    with trange(len(ops)) as bar:
        for index in bar:
            item = ops[index]
            label = item['name'] + " " * (28 - len(item['name']))
            bar.set_description("%s" % label)
            yield item

864 

865 loop = iterate_tqdm() 

866 

867 except ImportError: # pragma: no cover 

868 loop = iterate() 

869 else: 

870 loop = ops 

871 

872 if versions: 

873 add_versions = _enumerate_validated_operator_opsets_version(runtime) 

874 else: 

875 add_versions = {} 

876 

877 current_opset = get_opset_number_from_onnx() 

878 if opset_min == -1: 

879 opset_min = get_opset_number_from_onnx() 

880 if opset_max == -1: 

881 opset_max = get_opset_number_from_onnx() 

882 if verbose > 0 and fLOG is not None: 

883 fLOG("[enumerate_validated_operator_opsets] opset in [{}, {}].".format( 

884 opset_min, opset_max)) 

885 for row in loop: 

886 

887 model = row['cl'] 

888 if verbose > 1: 

889 fLOG("[enumerate_validated_operator_opsets] - model='{}'".format(model)) 

890 

891 for obs in enumerate_compatible_opset( 

892 model, opset_min=opset_min, opset_max=opset_max, 

893 check_runtime=check_runtime, runtime=runtime, 

894 debug=debug, dump_folder=dump_folder, 

895 store_models=store_models, benchmark=benchmark, 

896 fLOG=fLOG, filter_exp=filter_exp, 

897 assume_finite=assume_finite, node_time=node_time, 

898 verbose=verbose, extended_list=extended_list, 

899 time_kwargs=time_kwargs, dump_all=dump_all, 

900 n_features=n_features, skip_long_test=skip_long_test, 

901 filter_scenario=filter_scenario, 

902 time_kwargs_fact=time_kwargs_fact, 

903 time_limit=time_limit, n_jobs=n_jobs): 

904 

905 for mandkey in ('inst', 'method_name', 'problem', 

906 'scenario'): 

907 if '_0problem_exc' in obs: 

908 continue 

909 if mandkey not in obs: 

910 raise ValueError("Missing key '{}' in\n{}".format( 

911 mandkey, pprint.pformat(obs))) # pragma: no cover 

912 if verbose > 1: 

913 fLOG('[enumerate_validated_operator_opsets] - OBS') 

914 if verbose > 2: 

915 fLOG(" ", obs) 

916 else: 

917 obs_log = {k: v for k, 

918 v in obs.items() if 'lambda-' not in k} 

919 fLOG(" ", obs_log) 

920 elif verbose > 0 and "_0problem_exc" in obs: 

921 fLOG(" ???", obs) # pragma: no cover 

922 

923 diff = obs.get('max_rel_diff_batch', None) 

924 batch = 'max_rel_diff_batch' in obs and diff is not None 

925 op1 = obs.get('domain_opset_', '') 

926 op2 = obs.get('domain_opset_ai.onnx.ml', '') 

927 op = '{}/{}'.format(op1, op2) 

928 

929 obs['available'] = "?" 

930 if diff is not None: 

931 if diff < 1e-5: 

932 obs['available'] = 'OK' 

933 elif diff < 0.0001: 

934 obs['available'] = 'e<0.0001' 

935 elif diff < 0.001: 

936 obs['available'] = 'e<0.001' 

937 elif diff < 0.01: 

938 obs['available'] = 'e<0.01' # pragma: no cover 

939 elif diff < 0.1: 

940 obs['available'] = 'e<0.1' 

941 else: 

942 obs['available'] = "ERROR->=%1.1f" % diff 

943 obs['available'] += '-' + op 

944 if not batch: 

945 obs['available'] += "-NOBATCH" # pragma: no cover 

946 if fail_bad_results and 'e<' in obs['available']: 

947 raise RuntimeBadResultsError( 

948 "Wrong results '{}'.".format(obs['available']), obs) # pragma: no cover 

949 

950 excs = [] 

951 for k, v in sorted(obs.items()): 

952 if k.endswith('_exc'): 

953 excs.append((k, v)) 

954 break 

955 if 'opset' not in obs: 

956 # It fails before the conversion happens. 

957 obs['opset'] = current_opset 

958 if obs['opset'] == current_opset and len(excs) > 0: 

959 k, v = excs[0] 

960 obs['available'] = 'ERROR-%s' % k 

961 obs['available-ERROR'] = v 

962 

963 if 'bench-skl' in obs: 

964 b1 = obs['bench-skl'] 

965 if 'bench-batch' in obs: 

966 b2 = obs['bench-batch'] 

967 else: 

968 b2 = None 

969 if b1 is not None and b2 is not None: 

970 for k in b1: 

971 if k in b2 and b2[k] is not None and b1[k] is not None: 

972 key = 'time-ratio-N=%d' % k 

973 obs[key] = b2[k]['average'] / b1[k]['average'] 

974 key = 'time-ratio-N=%d-min' % k 

975 obs[key] = b2[k]['min_exec'] / b1[k]['max_exec'] 

976 key = 'time-ratio-N=%d-max' % k 

977 obs[key] = b2[k]['max_exec'] / b1[k]['min_exec'] 

978 

979 obs.update(row) 

980 obs.update(add_versions) 

981 yield obs.copy()