Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

"""
@file Functions to create a benchmark based on :epkg:`asv`
for many regressors and classifiers.
"""

5import os 

6import sys 

7import json 

8import textwrap 

9import warnings 

10import re 

11from pyquickhelper.pycode.code_helper import remove_extra_spaces_and_pep8 

12try: 

13 from ._create_asv_helper import ( 

14 default_asv_conf, 

15 flask_helper, 

16 pyspy_template, 

17 _handle_init_files, 

18 _asv_class_name, 

19 _read_patterns, 

20 _select_pattern_problem, 

21 _display_code_lines, 

22 add_model_import_init, 

23 find_missing_sklearn_imports) 

24except ImportError: # pragma: no cover 

25 from mlprodict.asv_benchmark._create_asv_helper import ( 

26 default_asv_conf, 

27 flask_helper, 

28 pyspy_template, 

29 _handle_init_files, 

30 _asv_class_name, 

31 _read_patterns, 

32 _select_pattern_problem, 

33 _display_code_lines, 

34 add_model_import_init, 

35 find_missing_sklearn_imports) 

36 

37try: 

38 from ..tools.asv_options_helper import ( 

39 get_opset_number_from_onnx, shorten_onnx_options) 

40 from ..onnxrt.validate.validate_helper import sklearn_operators 

41 from ..onnxrt.validate.validate import ( 

42 _retrieve_problems_extra, _get_problem_data, _merge_options) 

43except (ValueError, ImportError): # pragma: no cover 

44 from mlprodict.tools.asv_options_helper import get_opset_number_from_onnx 

45 from mlprodict.onnxrt.validate.validate_helper import sklearn_operators 

46 from mlprodict.onnxrt.validate.validate import ( 

47 _retrieve_problems_extra, _get_problem_data, _merge_options) 

48 from mlprodict.tools.asv_options_helper import shorten_onnx_options 

49try: 

50 from ..testing.verify_code import verify_code 

51except (ValueError, ImportError): # pragma: no cover 

52 from mlprodict.testing.verify_code import verify_code 

53 

54# exec function does not import models but potentially 

55# requires all specific models used to define scenarios 

56try: 

57 from ..onnxrt.validate.validate_scenarios import * # pylint: disable=W0614,W0401 

58except (ValueError, ImportError): # pragma: no cover 

59 # Skips this step if used in a benchmark. 

60 pass 

61 

62 

def create_asv_benchmark(
        location, opset_min=-1, opset_max=None,
        runtime=('scikit-learn', 'python_compiled'), models=None,
        skip_models=None, extended_list=True,
        dims=(1, 10, 100, 10000),
        n_features=(4, 20), dtype=None,
        verbose=0, fLOG=print, clean=True,
        conf_params=None, filter_exp=None,
        filter_scenario=None, flat=False,
        exc=False, build=None, execute=False,
        add_pyspy=False, env=None,
        matrix=None):
    """
    Creates an :epkg:`asv` benchmark in a folder
    but does not run it.

    :param location: root folder where all benchmark files are written
        (configuration, ``benches`` subfolder, helper scripts)
    :param n_features: number of features to try; it also overrides
        the default number of features for a specific problem and
        can be a comma separated list
    :param dims: number of observations to try
    :param verbose: integer from 0 (None) to 2 (full verbose)
    :param opset_min: tries every conversion from this minimum opset,
        -1 to get the current opset defined by module :epkg:`onnx`
    :param opset_max: tries every conversion up to maximum opset,
        -1 to get the current opset defined by module :epkg:`onnx`
    :param runtime: runtime to check, *scikit-learn*, *python*,
        *python_compiled* compiles the graph structure
        and is more efficient when the number of observations is
        small, *onnxruntime1* to check :epkg:`onnxruntime`,
        *onnxruntime2* to check every ONNX node independently
        with onnxruntime, many runtime can be checked at the same time
        if the value is a comma separated list
    :param models: list of models to test or empty
        string to test them all
    :param skip_models: models to skip
    :param extended_list: extends the list of :epkg:`scikit-learn` converters
        with converters implemented in this module
    :param dtype: '32' or '64' or None for both,
        limits the test to one specific number types
    :param fLOG: logging function
    :param clean: clean the folder first, otherwise overwrites the content
    :param conf_params: to overwrite some of the configuration parameters
    :param filter_exp: function which tells if the experiment must be run,
        None to run all, takes *model, problem* as an input
    :param filter_scenario: second function which tells if the experiment must be run,
        None to run all, takes *model, problem, scenario, extra*
        as an input
    :param flat: one folder for all files or subfolders
    :param exc: if False, raises warnings instead of exceptions
        whenever possible
    :param build: where to put the outputs
    :param execute: execute each script to make sure
        imports are correct
    :param add_pyspy: add an extra folder with code to profile
        each configuration
    :param env: None to use the default configuration or ``same`` to use
        the current one
    :param matrix: specifies versions for a module,
        example: ``{'onnxruntime': ['1.1.1', '1.1.2']}``,
        if a package name starts with `'~'`, the package is removed
    :return: created files

    The default configuration is the following:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import pprint
        from mlprodict.asv_benchmark.create_asv import default_asv_conf

        pprint.pprint(default_asv_conf)

    The benchmark does not seem to work well with setting
    ``-environment existing:same``. The publishing fails.
    """
    # Resolve the sentinel -1 to the opset currently supported by onnx.
    if opset_min == -1:
        opset_min = get_opset_number_from_onnx()
    if opset_max == -1:
        opset_max = get_opset_number_from_onnx()  # pragma: no cover
    if verbose > 0 and fLOG is not None:  # pragma: no cover
        fLOG("[create_asv_benchmark] opset in [{}, {}].".format(
            opset_min, opset_max))

    # creates the folder if it does not exist.
    if not os.path.exists(location):
        if verbose > 0 and fLOG is not None:  # pragma: no cover
            fLOG("[create_asv_benchmark] create folder '{}'.".format(location))
        os.makedirs(location)  # pragma: no cover

    # The generated benchmark classes live in a 'benches' subfolder.
    location_test = os.path.join(location, 'benches')
    if not os.path.exists(location_test):
        if verbose > 0 and fLOG is not None:
            fLOG("[create_asv_benchmark] create folder '{}'.".format(location_test))
        os.mkdir(location_test)

    # Cleans the content of the folder (files only, subfolders are kept).
    created = []
    if clean:
        for name in os.listdir(location_test):
            full_name = os.path.join(location_test, name)  # pragma: no cover
            if os.path.isfile(full_name):  # pragma: no cover
                os.remove(full_name)

    # configuration: start from the default asv configuration and
    # layer user overrides (conf_params), build paths, and env/matrix on top.
    conf = default_asv_conf.copy()
    if conf_params is not None:
        for k, v in conf_params.items():
            conf[k] = v
    if build is not None:
        for fi in ['env_dir', 'results_dir', 'html_dir']:  # pragma: no cover
            conf[fi] = os.path.join(build, conf[fi])
    if env == 'same':
        if matrix is not None:
            raise ValueError(  # pragma: no cover
                "Parameter matrix must be None if env is 'same'.")
        conf['pythons'] = ['same']
        conf['matrix'] = {}
    elif matrix is not None:
        # Keys prefixed with '~' request removal of that package
        # from both the user matrix and the default one.
        drop_keys = set(p for p in matrix if p.startswith('~'))
        matrix = {k: v for k, v in matrix.items() if k not in drop_keys}
        conf['matrix'] = {k: v for k,
                          v in conf['matrix'].items() if k not in drop_keys}
        conf['matrix'].update(matrix)
    elif env is not None:
        raise ValueError(  # pragma: no cover
            "Unable to handle env='{}'.".format(env))
    dest = os.path.join(location, "asv.conf.json")
    created.append(dest)
    with open(dest, "w", encoding='utf-8') as f:
        json.dump(conf, f, indent=4)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create 'asv.conf.json'.")

    # __init__.py: empty files so both folders are importable packages.
    dest = os.path.join(location, "__init__.py")
    with open(dest, "w", encoding='utf-8') as f:
        pass
    created.append(dest)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create '__init__.py'.")
    dest = os.path.join(location_test, '__init__.py')
    with open(dest, "w", encoding='utf-8') as f:
        pass
    created.append(dest)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create 'benches/__init__.py'.")

    # flask_server: writes the helper script serving published results.
    tool_dir = os.path.join(location, 'tools')
    if not os.path.exists(tool_dir):
        os.mkdir(tool_dir)
    fl = os.path.join(tool_dir, 'flask_serve.py')
    with open(fl, "w", encoding='utf-8') as f:
        f.write(flask_helper)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create 'flask_serve.py'.")

    # command line: shell script running asv, publishing the results
    # and converting them into a flat CSV file.
    if sys.platform.startswith("win"):
        run_bash = os.path.join(tool_dir, 'run_asv.bat')  # pragma: no cover
    else:
        run_bash = os.path.join(tool_dir, 'run_asv.sh')
    with open(run_bash, 'w') as f:
        f.write(textwrap.dedent("""
            echo --BENCHRUN--
            python -m asv run --show-stderr --config ./asv.conf.json
            echo --PUBLISH--
            python -m asv publish --config ./asv.conf.json -o ./html
            echo --CSV--
            python -m mlprodict asv2csv -f ./results -o ./data_bench.csv
            """))

    # pyspy: optional folder receiving profiling scripts per configuration.
    if add_pyspy:
        dest_pyspy = os.path.join(location, 'pyspy')
        if not os.path.exists(dest_pyspy):
            os.mkdir(dest_pyspy)
    else:
        dest_pyspy = None

    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create all tests.")

    # Generates one benchmark file per (model, problem, scenario, ...)
    # combination; the generator yields each created file path.
    created.extend(list(_enumerate_asv_benchmark_all_models(
        location_test, opset_min=opset_min, opset_max=opset_max,
        runtime=runtime, models=models,
        skip_models=skip_models, extended_list=extended_list,
        n_features=n_features, dtype=dtype,
        verbose=verbose, filter_exp=filter_exp,
        filter_scenario=filter_scenario,
        dims=dims, exc=exc, flat=flat,
        fLOG=fLOG, execute=execute,
        dest_pyspy=dest_pyspy)))

    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] done.")
    return created

261 

262 

def _enumerate_asv_benchmark_all_models(  # pylint: disable=R0914
        location, opset_min=10, opset_max=None,
        runtime=('scikit-learn', 'python'), models=None,
        skip_models=None, extended_list=True,
        n_features=None, dtype=None,
        verbose=0, filter_exp=None,
        dims=None, filter_scenario=None,
        exc=True, flat=False, execute=False,
        dest_pyspy=None, fLOG=print):
    """
    Loops over all possible models and fills a folder
    with benchmarks following :epkg:`asv` concepts.

    :param location: folder which receives the generated benchmark files
    :param n_features: number of features to try; it also overrides
        the default number of features for a specific problem and
        can be a comma separated list
    :param dims: number of observations to try
    :param verbose: integer from 0 (None) to 2 (full verbose)
    :param opset_min: tries every conversion from this minimum opset
    :param opset_max: tries every conversion up to maximum opset
    :param runtime: runtime to check, *scikit-learn*, *python*,
        *onnxruntime1* to check :epkg:`onnxruntime`,
        *onnxruntime2* to check every ONNX node independently
        with onnxruntime, many runtime can be checked at the same time
        if the value is a comma separated list
    :param models: list of models to test or empty
        string to test them all
    :param skip_models: models to skip
    :param extended_list: extends the list of :epkg:`scikit-learn` converters
        with converters implemented in this module
    :param dtype: '32' or '64' or None for both,
        limits the test to one specific number types
    :param fLOG: logging function
    :param filter_exp: function which tells if the experiment must be run,
        None to run all, takes *model, problem* as an input
    :param filter_scenario: second function which tells if the experiment must be run,
        None to run all, takes *model, problem, scenario, extra*
        as an input
    :param exc: if False, raises warnings instead of exceptions
        whenever possible
    :param flat: one folder for all files or subfolders
    :param execute: execute each script to make sure
        imports are correct
    :param dest_pyspy: add a file to profile the prediction
        function with :epkg:`pyspy`
    :return: generator yielding the paths of every created file
    """

    # Candidate operators (scikit-learn converters, optionally extended).
    ops = [_ for _ in sklearn_operators(extended=extended_list)]
    patterns = _read_patterns()

    if models is not None:
        if not all(map(lambda m: isinstance(m, str), models)):
            raise ValueError(
                "models must be a set of strings.")  # pragma: no cover
        ops_ = [_ for _ in ops if _['name'] in models]
        # Bug fix: the emptiness check must look at the *filtered* list
        # (ops_), not the unfiltered one (ops), otherwise an unknown model
        # name silently produces an empty benchmark instead of an error.
        if len(ops_) == 0:
            raise ValueError("Parameter models is wrong: {}\n{}".format(  # pragma: no cover
                models, ops[0] if ops else 'no candidate operator found'))
        ops = ops_
    if skip_models is not None:
        ops = [m for m in ops if m['name'] not in skip_models]

    if verbose > 0:

        def iterate():
            # Plain logging fallback when tqdm is unavailable or disabled.
            for i, row in enumerate(ops):  # pragma: no cover
                fLOG("{}/{} - {}".format(i + 1, len(ops), row))
                yield row

        if verbose >= 11:
            verbose -= 10  # pragma: no cover
            loop = iterate()  # pragma: no cover
        else:
            try:
                from tqdm import trange

                def iterate_tqdm():
                    # Progress bar labelled with the current model name.
                    with trange(len(ops)) as t:
                        for i in t:
                            row = ops[i]
                            disp = row['name'] + " " * (28 - len(row['name']))
                            t.set_description("%s" % disp)
                            yield row

                loop = iterate_tqdm()

            except ImportError:  # pragma: no cover
                loop = iterate()
    else:
        loop = ops

    if opset_max is None:
        opset_max = get_opset_number_from_onnx()
    opsets = list(range(opset_min, opset_max + 1))
    # Guards against two configurations producing the same file name.
    all_created = set()

    # loop on all models
    for row in loop:

        model = row['cl']

        problems, extras = _retrieve_problems_extra(
            model, verbose, fLOG, extended_list)
        if extras is None or problems is None:
            # Not tested yet.
            continue  # pragma: no cover

        # flat or not flat
        created, location_model, prefix_import, dest_pyspy_model = _handle_init_files(
            model, flat, location, verbose, dest_pyspy, fLOG)
        for init in created:
            yield init

        # loops on problems
        for prob in problems:
            if filter_exp is not None and not filter_exp(model, prob):
                continue

            (X_train, X_test, y_train,
             y_test, Xort_test,
             init_types, conv_options, method_name,
             output_index, dofit, predict_kwargs) = _get_problem_data(prob, None)

            for scenario_extra in extras:
                subset_problems = None
                optimisations = None
                new_conv_options = None

                # A third element in the scenario tuple restricts the
                # problems it applies to and/or carries converter options.
                if len(scenario_extra) > 2:
                    options = scenario_extra[2]
                    if isinstance(options, dict):
                        subset_problems = options.get('subset_problems', None)
                        optimisations = options.get('optim', None)
                        new_conv_options = options.get('conv_options', None)
                    else:
                        subset_problems = options

                if subset_problems and isinstance(subset_problems, (list, set)):
                    if prob not in subset_problems:
                        # Skips unrelated problem for a specific configuration.
                        continue
                elif subset_problems is not None:
                    raise RuntimeError(  # pragma: no cover
                        "subset_problems must be a set or a list not {}.".format(
                            subset_problems))

                scenario, extra = scenario_extra[:2]
                if optimisations is None:
                    optimisations = [None]
                if new_conv_options is None:
                    new_conv_options = [{}]

                if (filter_scenario is not None and
                        not filter_scenario(model, prob, scenario,
                                            extra, new_conv_options)):
                    continue  # pragma: no cover

                if verbose >= 3 and fLOG is not None:
                    fLOG("[create_asv_benchmark] model={} scenario={} optim={} extra={} dofit={} (problem={} method_name='{}')".format(
                        model.__name__, scenario, optimisations, extra, dofit, prob, method_name))
                created = _create_asv_benchmark_file(
                    location_model, opsets=opsets,
                    model=model, scenario=scenario, optimisations=optimisations,
                    extra=extra, dofit=dofit, problem=prob,
                    runtime=runtime, new_conv_options=new_conv_options,
                    X_train=X_train, X_test=X_test, y_train=y_train,
                    y_test=y_test, Xort_test=Xort_test,
                    init_types=init_types, conv_options=conv_options,
                    method_name=method_name, dims=dims, n_features=n_features,
                    output_index=output_index, predict_kwargs=predict_kwargs,
                    exc=exc, prefix_import=prefix_import,
                    execute=execute, location_pyspy=dest_pyspy_model,
                    patterns=patterns)
                for cr in created:
                    if cr in all_created:
                        raise RuntimeError(  # pragma: no cover
                            "File '{}' was already created.".format(cr))
                    all_created.add(cr)
                    if verbose > 1 and fLOG is not None:
                        fLOG("[create_asv_benchmark] add '{}'.".format(cr))
                    yield cr

444 

445 

def _create_asv_benchmark_file(  # pylint: disable=R0914
        location, model, scenario, optimisations, new_conv_options,
        extra, dofit, problem, runtime, X_train, X_test, y_train,
        y_test, Xort_test, init_types, conv_options,
        method_name, n_features, dims, opsets,
        output_index, predict_kwargs, prefix_import,
        exc, execute=False, location_pyspy=None, patterns=None):
    """
    Creates a benchmark file based in the information received
    through the argument. It uses one of the templates
    like @see cl TemplateBenchmarkClassifier or
    @see cl TemplateBenchmarkRegressor.

    The template is specialised by plain text substitution: parameter
    placeholders (runtime, N, nf, opset, dtype, optim) are replaced,
    the class is renamed after the configuration, the result is
    compiled (and optionally executed when *execute* is True) as a
    sanity check, then written into *location*. When *location_pyspy*
    is not None, one extra py-spy profiling script per configuration
    is generated as well. Returns the list of created file paths.

    Note: the data arguments (X_train, X_test, y_train, y_test,
    Xort_test, init_types, output_index, predict_kwargs) are accepted
    for interface compatibility with the caller but are not referenced
    in this function's body.
    """
    if patterns is None:
        raise ValueError("Patterns list is empty.")  # pragma: no cover

    def format_conv_options(d_options, class_name):
        # Normalizes converter options: a key which is the model class
        # itself (its repr contains ".ClassName'") is replaced by the
        # class name string; other keys are kept as-is.
        if d_options is None:
            return None
        res = {}
        for k, v in d_options.items():
            if isinstance(k, type):
                if "." + class_name + "'" in str(k):
                    res[class_name] = v
                    continue
                raise ValueError(  # pragma: no cover
                    "Class '{}', unable to format options {}".format(
                        class_name, d_options))
            res[k] = v
        return res

    def _nick_name_options(model, opts):
        # Shorten common onnx options, see _CommonAsvSklBenchmark._to_onnx.
        # Class keys are wrapped as '####Name####' so they can later be
        # unquoted inside the generated source code.
        if opts is None:
            return opts  # pragma: no cover
        short_opts = shorten_onnx_options(model, opts)
        if short_opts is not None:
            return short_opts
        res = {}
        for k, v in opts.items():
            if hasattr(k, '__name__'):
                res["####" + k.__name__ + "####"] = v
            else:
                res[k] = v  # pragma: no cover
        return res

    def _make_simple_name(name):
        # Builds a compact display name by stripping 'bench' markers,
        # underscores and shortening known solver names.
        simple_name = name.replace("bench_", "").replace("_bench", "")
        simple_name = simple_name.replace("bench.", "").replace(".bench", "")
        simple_name = simple_name.replace(".", "-")
        repl = {'_': '', 'solverliblinear': 'liblinear'}
        for k, v in repl.items():
            simple_name = simple_name.replace(k, v)
        return simple_name

    def _optdict2string(opt):
        # Flattens an option dictionary into a short deterministic string
        # (sorted keys, recursive on nested dicts) used in file names.
        if isinstance(opt, str):
            return opt
        if isinstance(opt, list):
            raise TypeError(
                "Unable to process type %r." % type(opt))
        reps = {True: 1, False: 0, 'zipmap': 'zm',
                'optim': 'opt'}
        info = []
        for k, v in sorted(opt.items()):
            if isinstance(v, dict):
                v = _optdict2string(v)
            if k.startswith('####'):
                k = ''
            i = '{}{}'.format(reps.get(k, k), reps.get(v, v))
            info.append(i)
        return "-".join(info)

    # Abbreviations inserted into the generated benchmark code and names.
    runtimes_abb = {
        'scikit-learn': 'skl',
        'onnxruntime1': 'ort',
        'onnxruntime2': 'ort2',
        'python': 'pyrt',
        'python_compiled': 'pyrtc',
    }
    runtime = [runtimes_abb[k] for k in runtime]

    # Looping over configuration.
    names = []
    for optimisation in optimisations:
        merged_options = [_merge_options(nconv_options, conv_options)
                          for nconv_options in new_conv_options]

        nck_opts = [_nick_name_options(model, opts)
                    for opts in merged_options]
        try:
            name = _asv_class_name(
                model, scenario, optimisation, extra,
                dofit, conv_options, problem,
                shorten=True)
        except ValueError as e:  # pragma: no cover
            if exc:
                raise e
            warnings.warn(str(e))
            continue
        filename = name.replace(".", "_") + ".py"
        try:
            class_content = _select_pattern_problem(problem, patterns)
        except ValueError as e:
            if exc:
                raise e  # pragma: no cover
            warnings.warn(str(e))
            continue
        full_class_name = _asv_class_name(
            model, scenario, optimisation, extra,
            dofit, conv_options, problem,
            shorten=False)
        class_name = name.replace(
            "bench.", "").replace(".", "_") + "_bench"

        # n_features, N, runtimes: the keys below are exact source lines
        # of the template; each must be present so the substitution is
        # guaranteed to happen (hence the check before replacing).
        rep = {
            "['skl', 'pyrtc', 'ort'], # values for runtime": str(runtime),
            "[1, 10, 100, 1000, 10000], # values for N": str(dims),
            "[4, 20], # values for nf": str(n_features),
            "[get_opset_number_from_onnx()], # values for opset": str(opsets),
            "['float', 'double'], # values for dtype":
                "['float']" if '-64' not in problem else "['double']",
            "[None], # values for optim": "%r" % nck_opts,
        }
        for k, v in rep.items():
            if k not in class_content:
                raise ValueError("Unable to find '{}'\n{}.".format(  # pragma: no cover
                    k, class_content))
            class_content = class_content.replace(k, v + ',')
        # Drops the template's _create_model part, then unquotes the
        # '####Name####' markers injected by _nick_name_options.
        class_content = class_content.split(
            "def _create_model(self):")[0].strip("\n ")
        if "####" in class_content:
            class_content = class_content.replace(
                "'####", "").replace("####'", "")
        if "####" in class_content:
            raise RuntimeError(  # pragma: no cover
                "Substring '####' should not be part of the script for '{}'\n{}".format(
                    model.__name__, class_content))

        # Model setup
        class_content, atts = add_model_import_init(
            class_content, model, optimisation,
            extra, merged_options)
        class_content = class_content.replace(
            "class TemplateBenchmark",
            "class {}".format(class_name))

        # dtype, dofit: class attributes describing the configuration,
        # injected below at the '# additional parameters' marker.
        atts.append("chk_method_name = %r" % method_name)
        atts.append("par_scenario = %r" % scenario)
        atts.append("par_problem = %r" % problem)
        atts.append("par_optimisation = %r" % optimisation)
        if not dofit:
            atts.append("par_dofit = False")
        if merged_options is not None and len(merged_options) > 0:
            atts.append("par_convopts = %r" % format_conv_options(
                conv_options, model.__name__))
        atts.append("par_full_test_name = %r" % full_class_name)

        simple_name = _make_simple_name(name)
        atts.append("benchmark_name = %r" % simple_name)
        atts.append("pretty_name = %r" % simple_name)

        if atts:
            class_content = class_content.replace(
                "# additional parameters",
                "\n    ".join(atts))
        if prefix_import != '.':
            # NOTE(review): the replacement string drops the leading space
            # present in the search pattern " from ." -- confirm this is
            # intended and does not shift the generated import lines.
            class_content = class_content.replace(
                " from .", "from .{}".format(prefix_import))

        # Check compilation
        try:
            compile(class_content, filename, 'exec')
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError("Unable to compile model '{}'\n{}".format(
                model.__name__, class_content)) from e

        # Verifies missing imports.
        to_import, _ = verify_code(class_content, exc=False)
        try:
            miss = find_missing_sklearn_imports(to_import)
        except ValueError as e:  # pragma: no cover
            raise ValueError(
                "Unable to check import in script\n{}".format(
                    class_content)) from e
        class_content = class_content.replace(
            "# __IMPORTS__", "\n".join(miss))
        verify_code(class_content, exc=True)
        class_content = class_content.replace(
            "par_extra = {", "par_extra = {\n")
        class_content = remove_extra_spaces_and_pep8(
            class_content, aggressive=True)

        # Check compilation again (pep8 formatting may have altered code).
        try:
            obj = compile(class_content, filename, 'exec')
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError("Unable to compile model '{}'\n{}".format(
                model.__name__,
                _display_code_lines(class_content))) from e

        # executes to check import
        if execute:
            try:
                exec(obj, globals(), locals())  # pylint: disable=W0122
            except Exception as e:  # pragma: no cover
                raise RuntimeError(
                    "Unable to process class '{}' ('{}') a script due to '{}'\n{}".format(
                        model.__name__, filename, str(e),
                        _display_code_lines(class_content))) from e

        # Saves
        fullname = os.path.join(location, filename)
        names.append(fullname)
        with open(fullname, "w", encoding='utf-8') as f:
            f.write(class_content)

        if location_pyspy is not None:
            # adding configuration for pyspy: one standalone script and
            # one shell script per (N, nf, opset, dtype, optim) combination.
            class_name = re.compile(
                'class ([A-Za-z_0-9]+)[(]').findall(class_content)[0]
            fullname_pyspy = os.path.splitext(
                os.path.join(location_pyspy, filename))[0]
            pyfold = os.path.splitext(os.path.split(fullname)[-1])[0]

            dtypes = ['float', 'double'] if '-64' in problem else ['float']
            for dim in dims:
                for nf in n_features:
                    for opset in opsets:
                        for dtype in dtypes:
                            for opt in nck_opts:
                                tmpl = pyspy_template.replace(
                                    '__PATH__', location)
                                tmpl = tmpl.replace(
                                    '__CLASSNAME__', class_name)
                                tmpl = tmpl.replace('__PYFOLD__', pyfold)
                                opt = "" if opt == {} else opt

                                # The first runtime also runs the setup
                                # profile (profile0) once.
                                first = True
                                for rt in runtime:
                                    if first:
                                        tmpl += textwrap.dedent("""

                                        def profile0_{rt}(iter, cl, N, nf, opset, dtype, optim):
                                            return setup_profile0(iter, cl, '{rt}', N, nf, opset, dtype, optim)
                                        iter = profile0_{rt}(iter, cl, {dim}, {nf}, {opset}, '{dtype}', {opt})
                                        print(datetime.now(), "iter", iter)

                                        """).format(rt=rt, dim=dim, nf=nf, opset=opset,
                                                    dtype=dtype, opt="%r" % opt)
                                        first = False

                                    tmpl += textwrap.dedent("""

                                    def profile_{rt}(iter, cl, N, nf, opset, dtype, optim):
                                        return setup_profile(iter, cl, '{rt}', N, nf, opset, dtype, optim)
                                    profile_{rt}(iter, cl, {dim}, {nf}, {opset}, '{dtype}', {opt})
                                    print(datetime.now(), "iter", iter)

                                    """).format(rt=rt, dim=dim, nf=nf, opset=opset,
                                                dtype=dtype, opt="%r" % opt)

                                thename = "{n}_{dim}_{nf}_{opset}_{dtype}_{opt}.py".format(
                                    n=fullname_pyspy, dim=dim, nf=nf,
                                    opset=opset, dtype=dtype, opt=_optdict2string(opt))
                                with open(thename, 'w', encoding='utf-8') as f:
                                    f.write(tmpl)
                                names.append(thename)

                                # Companion shell script invoking py-spy on
                                # the generated profiling script.
                                ext = '.bat' if sys.platform.startswith(
                                    'win') else '.sh'
                                script = os.path.splitext(thename)[0] + ext
                                short = os.path.splitext(
                                    os.path.split(thename)[-1])[0]
                                with open(script, 'w', encoding='utf-8') as f:
                                    f.write('py-spy record --native --function --rate=10 -o {n}_fct.svg -- {py} {n}.py\n'.format(
                                        py=sys.executable, n=short))
                                    f.write('py-spy record --native --rate=10 -o {n}_line.svg -- {py} {n}.py\n'.format(
                                        py=sys.executable, n=short))

    return names