Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file Functions to create a benchmark based on :epkg:`asv` 

3for many regressors and classifiers. 

4""" 

5import os 

6import textwrap 

7import hashlib 

8try: 

9 from ..onnx_tools.optim.sklearn_helper import set_n_jobs 

10except (ValueError, ImportError): # pragma: no cover 

11 from mlprodict.onnx_tools.optim.sklearn_helper import set_n_jobs 

12 

13# exec function does not import models but potentially 

14# requires all specific models used to define scenarios 

15try: 

16 from ..onnxrt.validate.validate_scenarios import * # pylint: disable=W0614,W0401 

17except (ValueError, ImportError): # pragma: no cover 

18 # Skips this step if used in a benchmark. 

19 pass 

20 

21 

# Default content of the asv configuration, stored as a plain dictionary.
# Key order is kept as is.
default_asv_conf = {
    "version": 1,
    "project": "mlprodict",
    "project_url": "http://www.xavierdupre.fr/app/mlprodict/helpsphinx/index.html",
    "repo": "https://github.com/sdpython/mlprodict.git",
    "repo_subdir": "",
    "install_command": ["python -mpip install {wheel_file}"],
    "uninstall_command": ["return-code=any python -mpip uninstall -y {project}"],
    "build_command": [
        "python setup.py build",
        "PIP_NO_BUILD_ISOLATION=false python -mpip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}"
    ],
    "branches": ["master"],
    "environment_type": "virtualenv",
    "install_timeout": 600,
    "show_commit_url": "https://github.com/sdpython/mlprodict/commit/",
    # "pythons": ["__PYVER__"],
    # One entry per dependency. An empty list adds no constraint here;
    # some packages point to a server on localhost:8067
    # (presumably a local package index -- TODO confirm).
    "matrix": {
        "cython": [],
        "jinja2": [],
        "joblib": [],
        "lightgbm": [],
        "mlinsights": [],
        "numpy": [],
        "onnx": ["http://localhost:8067/simple/"],
        "onnxruntime": ["http://localhost:8067/simple/"],
        "pandas": [],
        "Pillow": [],
        "pybind11": [],
        "pyquickhelper": [],
        "scipy": [],
        # alternative: "git+https://github.com/xadupre/onnxconverter-common.git@jenkins"
        "onnxconverter-common": ["http://localhost:8067/simple/"],
        # alternative: "git+https://github.com/xadupre/sklearn-onnx.git@jenkins"
        "skl2onnx": ["http://localhost:8067/simple/"],
        # alternative: "git+https://github.com/scikit-learn/scikit-learn.git"
        "scikit-learn": ["http://localhost:8067/simple/"],
        "xgboost": [],
    },
    "benchmark_dir": "benches",
    "env_dir": "env",
    "results_dir": "results",
    "html_dir": "html",
}

66 

# Source code of a small flask application, kept as a string.
# The script serves the files found in the sibling folder ``html``
# (see ``root_dir`` in the template) on port 8877.
flask_helper = """
'''
Local ASV files do no properly render in a browser,
it needs to be served through a server.
'''
import os.path
from flask import Flask, Response

app = Flask(__name__)
app.config.from_object(__name__)


def root_dir():
    return os.path.join(os.path.abspath(os.path.dirname(__file__)), "..", "html")


def get_file(filename):  # pragma: no cover
    try:
        src = os.path.join(root_dir(), filename)
        with open(src, "r", encoding="utf-8", errors="ignore") as f:
            return f.read()
    except IOError as exc:
        return str(exc)


@app.route('/', methods=['GET'])
def mainpage():
    content = get_file('index.html')
    return Response(content, mimetype="text/html")


@app.route('/', defaults={'path': ''})
@app.route('/<path:path>')
def get_resource(path):  # pragma: no cover
    mimetypes = {
        ".css": "text/css",
        ".html": "text/html",
        ".js": "application/javascript",
    }
    complete_path = os.path.join(root_dir(), path)
    ext = os.path.splitext(path)[1]
    mimetype = mimetypes.get(ext, "text/html")
    content = get_file(complete_path)
    return Response(content, mimetype=mimetype)


if __name__ == '__main__':  # pragma: no cover
    app.run(  # ssl_context=('cert.pem', 'key.pem'),
        port=8877,
        # host="",
    )
"""

119 

# Template of a profiling script, kept as a string. The placeholders
# ``__PATH__``, ``__PYFOLD__`` and ``__CLASSNAME__`` are meant to be
# replaced before the script is written or executed.
pyspy_template = """
import sys
sys.path.append(r"__PATH__")
from __PYFOLD__ import __CLASSNAME__
import time
from datetime import datetime


def start():
    cl = __CLASSNAME__()
    cl.setup_cache()
    return cl


def profile0(iter, cl, runtime, N, nf, opset, dtype, optim):
    begin = time.perf_counter()
    for i in range(0, 100):
        cl.time_predict(runtime, N, nf, opset, dtype, optim)
    duration = time.perf_counter() - begin
    iter = max(100, int(25 / duration * 100))  # 25 seconds
    return iter


def setup_profile0(iter, cl, runtime, N, nf, opset, dtype, optim):
    cl.setup(runtime, N, nf, opset, dtype, optim)
    return profile0(iter, cl, runtime, N, nf, opset, dtype, optim)


def profile(iter, cl, runtime, N, nf, opset, dtype, optim):
    for i in range(iter):
        cl.time_predict(runtime, N, nf, opset, dtype, optim)
    return iter


def setup_profile(iter, cl, runtime, N, nf, opset, dtype, optim):
    cl.setup(runtime, N, nf, opset, dtype, optim)
    return profile(iter, cl, runtime, N, nf, opset, dtype, optim)


cl = start()
iter = None
print(datetime.now(), "begin")
"""

163 

164 

def _sklearn_subfolder(model):
    """
    Returns the list of subfolders for a model.

    @param      model       model class, only ``__module__`` and
                            ``__name__`` are read
    @return                 list of folder names, the last one
                            is the model name

    The folders are the module components found between ``sklearn``
    and the last component of the module name. A model whose module
    starts with ``mlinsights`` is placed in ``['mlinsights', name]``.
    The function raises *ValueError* if the module is undefined,
    if it does not contain ``sklearn`` or if no subfolder can be
    guessed.
    """
    mod = model.__module__
    if mod is None:
        # A missing module name cannot be split,
        # it is reported the same way as a module without 'sklearn'.
        raise ValueError(
            "Unable to find 'sklearn' in '{}'.".format(mod))
    if mod.startswith('mlinsights'):
        return ['mlinsights', model.__name__]  # pragma: no cover
    spl = mod.split('.')
    try:
        pos = spl.index('sklearn')
    except ValueError as e:  # pragma: no cover
        raise ValueError(
            "Unable to find 'sklearn' in '{}'.".format(mod)) from e
    res = spl[pos + 1: -1]
    if len(res) == 0:
        if spl[-1] == 'sklearn':
            res = ['_externals']
        elif spl[0] == 'sklearn':
            res = spl[pos + 1:]
        else:
            # *model* is a class: model.__class__ would always be
            # ``type``, the message displays the model itself.
            raise ValueError(  # pragma: no cover
                "Unable to guess subfolder for '{}'.".format(model))
    res.append(model.__name__)
    return res

189 

190 

def _handle_init_files(model, flat, location, verbose, location_pyspy, fLOG):
    """
    Creates the folders and the files ``__init__.py`` needed
    to store the benchmark of a model.

    @param      model           model class
    @param      flat            if True, nothing is created and
                                *location* is returned unchanged
    @param      location        root folder of the benchmarks
    @param      verbose         created files are logged if ``verbose > 1``
    @param      location_pyspy  root folder for the profiling scripts
                                or None
    @param      fLOG            logging function or None
    @return                     tuple *(created, location_model,
                                prefix_import, location_pyspy_model)*
    """
    if flat:
        return [], location, ".", location_pyspy

    # Private subfolders are skipped except '_externals'.
    subf = [name for name in _sklearn_subfolder(model)
            if name[0] != '_' or name == '_externals']
    location_model = os.path.join(location, *subf)
    prefix_import = "." * (len(subf) + 1)
    if not os.path.exists(location_model):
        os.makedirs(location_model)

    # An empty __init__.py is added to the model folder
    # and to its two closest parents.
    created = []
    parent = os.path.dirname(location_model)
    for fold in (location_model, parent, os.path.dirname(parent)):
        init = os.path.join(fold, '__init__.py')
        if os.path.exists(init):
            continue
        with open(init, 'w') as _:
            pass
        created.append(init)
        if verbose > 1 and fLOG is not None:
            fLOG("[create_asv_benchmark] create '{}'.".format(init))

    if location_pyspy is None:
        location_pyspy_model = None
    else:
        location_pyspy_model = os.path.join(location_pyspy, *subf)
        if not os.path.exists(location_pyspy_model):
            os.makedirs(location_pyspy_model)

    return created, location_model, prefix_import, location_pyspy_model

221 

222 

def _asv_class_name(model, scenario, optimisation,
                    extra, dofit, conv_options, problem,
                    shorten=True):
    """
    Builds the name of a benchmark from the model and its options.

    @param      model           model class (``__name__`` is used)
    @param      scenario        scenario name
    @param      optimisation    optimisation (string, list or None)
    @param      extra           parameters given to the model constructor,
                                ``random_state=42`` is left out of the name
    @param      dofit           if False, ``nofit`` is added to the name
    @param      conv_options    conversion options (string, list, ... or None)
    @param      problem         problem name
    @param      shorten         replaces long words by abbreviations
    @return                     name, pieces are joined with dots

    A name longer than 70 characters is truncated and completed
    with the first six characters of its SHA-256 digest.
    """
    # ',', '-' and end of lines become '_', the other characters
    # of the third argument are removed.
    table = str.maketrans(",-\n", "___", ": =.+()[]{}\"'<>~")
    abbreviations = (('n_estimators', 'nest'), ('max_iter', 'mxit'))

    def clean_str(val):
        text = str(val).translate(table)
        for old, new in abbreviations:
            text = text.replace(old, new)
        return text

    def clean_str_list(val):
        if val is None:
            return ""  # pragma: no cover
        if isinstance(val, list):
            return ".".join(  # pragma: no cover
                clean_str_list(v) for v in val if v)
        return clean_str(val)

    els = ['bench', model.__name__, scenario, clean_str(problem)]
    if not dofit:
        els.append('nofit')
    if extra:
        if 'random_state' in extra and extra['random_state'] == 42:
            remaining = extra.copy()
            remaining.pop('random_state')
            if remaining:
                els.append(clean_str(remaining))
        else:
            els.append(clean_str(extra))
    if optimisation:
        els.append(clean_str_list(optimisation))
    if conv_options:
        els.append(clean_str_list(conv_options))
    res = ".".join(els).replace("-", "_")

    if shorten:
        # Applied one after the other, the order matters.
        replacements = (
            ('ConstantKernel', 'Cst'),
            ('DotProduct', 'Dot'),
            ('Exponentiation', 'Exp'),
            ('ExpSineSquared', 'ExpS2'),
            ('GaussianProcess', 'GaussProc'),
            ('GaussianMixture', 'GaussMixt'),
            ('HistGradientBoosting', 'HGB'),
            ('LinearRegression', 'LinReg'),
            ('LogisticRegression', 'LogReg'),
            ('MultiOutput', 'MultOut'),
            ('OrthogonalMatchingPursuit', 'OrthMatchPurs'),
            ('PairWiseKernel', 'PW'),
            ('Product', 'Prod'),
            ('RationalQuadratic', 'RQ'),
            ('WhiteKernel', 'WK'),
            ('length_scale', 'ls'),
            ('periodicity', 'pcy'),
            ('Classifier', 'Clas'),
            ('Regressor', 'Reg'),
            ('KNeighbors', 'KNN'),
            ('NearestNeighbors', 'kNN'),
            ('RadiusNeighbors', 'RadNN'),
        )
        for old, new in replacements:
            res = res.replace(old, new)

    if len(res) > 70:  # shorten filename
        digest = hashlib.sha256(res.encode('utf-8')).hexdigest()
        res = res[:70] + digest[:6]
    return res

308 

309 

def _read_patterns():
    """
    Loads the benchmark templates stored in subfolder ``template``
    next to this file.

    @return     dictionary ``{ suffix: template content }``

    Only the text following the second ``\"\"\"`` of every template
    is kept. The function raises *FileNotFoundError*
    if a template is missing.
    """
    folder = os.path.join(os.path.dirname(__file__), "template")
    suffixes = ('classifier', 'classifier_raw_scores', 'regressor',
                'clustering', 'outlier', 'trainable_transform',
                'transform', 'multi_classifier', 'transform_positive')
    patterns = {}
    for suffix in suffixes:
        template_name = os.path.join(folder, "skl_model_%s.py" % suffix)
        if not os.path.exists(template_name):
            raise FileNotFoundError(  # pragma: no cover
                "Template '{}' was not found.".format(template_name))
        with open(template_name, "r", encoding="utf-8") as f:
            content = f.read()
        # Everything up to the second triple quote is dropped.
        parts = content.split('"""', 2)
        patterns[suffix] = parts[2] if len(parts) > 2 else ''
    return patterns

329 

330 

def _select_pattern_problem(prob, patterns):
    """
    Picks the template matching a problem name.

    @param      prob        problem name
    @param      patterns    dictionary ``{ kind: template }``
    @return                 the selected template

    The function raises *ValueError* if no rule matches.
    """
    # The first rule whose markers are all part of *prob* wins,
    # the order matters ('num-tr-pos' must be checked before 'num-tr').
    rules = (
        (('-reg', ), 'regressor'),
        (('-cl', '-dec'), 'classifier_raw_scores'),
        (('-cl', ), 'classifier'),
        (('cluster', ), 'clustering'),
        (('outlier', ), 'outlier'),
        (('num+y-tr', ), 'trainable_transform'),
        (('num-tr-pos', ), 'transform_positive'),
        (('num-tr', ), 'transform'),
        (('m-label', ), 'multi_classifier'),
    )
    for markers, kind in rules:
        if all(marker in prob for marker in markers):
            return patterns[kind]
    raise ValueError(  # pragma: no cover
        "Unable to guess the right pattern for '{}'.".format(prob))

355 

356 

def _display_code_lines(code):
    """
    Adds a line number (three digits, starting at 1)
    in front of every line of *code*.

    @param      code    source code as a string
    @return             the same text with numbered lines
    """
    numbered = []
    for number, line in enumerate(code.split("\n"), start=1):
        numbered.append("{:03d} {}".format(number, line))
    return "\n".join(numbered)

361 

362 

def _format_dict(opts, indent):
    """
    Formats a dictionary as code, a list of ``key=value``
    which can be inserted in a function call.

    @param      opts        dictionary, keys are sorted,
                            values are written with ``repr``
    @param      indent      number of spaces added in front of every line
    @return                 formatted string, empty for an empty dictionary

    Lines are filled up to 70 characters but a line break is only
    inserted between two items. Function ``textwrap.wrap`` cannot be
    used on the whole text: it may split a string value on a space
    or a hyphen and produce code which cannot be compiled.
    """
    width = 70  # same default width as textwrap.wrap
    items = ['%s=%r' % (k, v) for k, v in sorted(opts.items())]
    # Every item but the last one ends with the comma which separates
    # it from the next one.
    chunks = [item + ',' for item in items[:-1]] + items[-1:]

    rows = []
    current = ''
    for chunk in chunks:
        if not current:
            current = chunk
        elif len(current) + 1 + len(chunk) > width:
            rows.append(current)
            current = chunk
        else:
            current = current + ' ' + chunk
    if current:
        rows.append(current)
    return textwrap.indent("\n".join(rows), prefix=' ' * indent)

373 

374 

def _additional_imports(model_name):
    """
    Adds additional imports for experimental models.

    @param      model_name      name of the model class
    @return                     list of import instructions or None
                                if the model does not need any
    """
    if model_name == 'IterativeImputer':
        return ["from sklearn.experimental import enable_iterative_imputer  # pylint: disable=W0611"]
    # Both the classifier and the regressor require the experimental import.
    if model_name in ('HistGradientBoostingClassifier', 'HistGradientBoostingRegressor'):
        return ["from sklearn.experimental import enable_hist_gradient_boosting  # pylint: disable=W0611"]
    return None

384 

385 

def add_model_import_init(
        class_content, model, optimisation=None,
        extra=None, conv_options=None):
    """
    Modifies a template such as @see cl TemplateBenchmarkClassifier
    with code associated to the model *model*.

    @param      class_content       template (as a string)
    @param      model               model class
    @param      optimisation        model optimisation, None, ``'onnx'``
                                    or a dictionary
    @param      extra               addition parameter to the constructor
    @param      conv_options        options for the conversion to ONNX
                                    (not used by this function)
    @return                         tuple *(modified template,
                                    list of class attributes as code)*

    The template must contain a line with
    ``# Import specific to this model.``: the line which follows it
    is replaced by the imports. Everything from
    ``def _create_model(self):`` is dropped and written again.
    The function raises *ValueError* if the optimisation cannot be
    interpreted and *RuntimeError* if the import marker is missing.
    """
    add_imports = []
    add_methods = []
    add_params = ["par_modelname = '%s'" % model.__name__,
                  "par_extra = %r" % extra]

    # additional methods and imports
    if optimisation is not None:
        add_imports.append(
            'from mlprodict.onnx_tools.optim import onnx_optimisations')
        if optimisation == 'onnx':
            add_methods.append(textwrap.dedent('''
                def _optimize_onnx(self, onx):
                    return onnx_optimisations(onx)'''))
            add_params.append('par_optimonnx = True')
        elif isinstance(optimisation, dict):
            add_methods.append(textwrap.dedent('''
                def _optimize_onnx(self, onx):
                    return onnx_optimisations(onx, self.par_optims)'''))
            # The attribute must be a valid assignment: the dictionary
            # is written as a literal (sorted keys), the same way
            # as par_extra.
            add_params.append('par_optims = {!r}'.format(
                dict(sorted(optimisation.items()))))
        else:
            raise ValueError(  # pragma: no cover
                "Unable to interpret optimisation {}.".format(optimisation))

    # look for import place
    lines = class_content.split('\n')
    keep = None
    for pos, line in enumerate(lines):
        if "# Import specific to this model." in line:
            keep = pos
            break
    if keep is None:
        raise RuntimeError(  # pragma: no cover
            "Unable to locate where to insert import in\n{}\n".format(
                class_content))

    # imports: a private submodule of scikit-learn (last name starting
    # with '_') is not imported directly, its parent is used instead
    loc_class = model.__module__
    sub = loc_class.split('.')
    if 'sklearn' not in sub:
        mod = loc_class
    else:
        skl = sub.index('sklearn')
        if skl == 0:
            if sub[-1].startswith("_"):
                mod = '.'.join(sub[skl:-1])
            else:
                mod = '.'.join(sub[skl:])
        else:
            mod = '.'.join(sub[:-1])

    exp_imports = _additional_imports(model.__name__)
    if exp_imports:
        add_imports.extend(exp_imports)
    imp_inst = (
        "try:\n    from {0} import {1}\nexcept ImportError:\n    {1} = None"
        "".format(mod, model.__name__))
    add_imports.append(imp_inst)
    add_imports.append("# __IMPORTS__")
    # the line following the marker is replaced by the imports
    lines[keep + 1] = "\n".join(add_imports)
    content = "\n".join(lines)

    # _create_model
    content = content.split('def _create_model(self):',
                            maxsplit=1)[0].strip(' \n')
    lines = [content, "", "    def _create_model(self):"]
    if extra is not None and len(extra) > 0:
        lines.append("        return {}(".format(model.__name__))
        lines.append(_format_dict(set_n_jobs(model, extra), 12))
        lines.append("        )")
    else:
        lines.append("        return {}()".format(model.__name__))
    lines.append("")

    # methods
    for meth in add_methods:
        lines.append(textwrap.indent(meth, '    '))
        lines.append('')

    # end
    return "\n".join(lines), add_params

481 

482 

def find_missing_sklearn_imports(pieces):
    """
    Builds the import instructions needed to use
    some objects of :epkg:`scikit-learn`.

    @param      pieces      list of names in scikit-learn
    @return                 list of corresponding imports,
                            one per module, names are sorted
    """
    by_module = {}
    for piece in pieces:
        by_module.setdefault(find_sklearn_module(piece), []).append(piece)
    return ["from {} import {}".format(mod, ", ".join(sorted(names)))
            for mod, names in by_module.items()]

502 

503 

def find_sklearn_module(piece):
    """
    Finds the module of :epkg:`scikit-learn` which defines
    a given name.

    @param      piece       name to import
    @return                 module name

    The function relies on a white list, it should be improved.
    The object is also added to the global variables of this module
    under the same name. The function raises *ValueError*
    if the name is not part of the list.
    """
    whitelist = (
        ("sklearn.linear_model",
         {'LinearRegression', 'LogisticRegression', 'SGDClassifier'}),
        ("sklearn.tree",
         {'DecisionTreeRegressor', 'DecisionTreeClassifier'}),
        ("sklearn.gaussian_process.kernels",
         {'ExpSineSquared', 'DotProduct', 'RationalQuadratic', 'RBF'}),
        ("sklearn.svm",
         {'LinearSVC', 'LinearSVR', 'NuSVR', 'SVR', 'SVC', 'NuSVC'}),
        ("sklearn.cluster", {'KMeans'}),
        ("sklearn.multiclass",
         {'OneVsRestClassifier', 'OneVsOneClassifier'}),
    )
    for module_name, names in whitelist:
        if piece in names:
            # a non empty fromlist makes __import__ return the submodule
            module = __import__(module_name, fromlist=[piece])
            globals()[piece] = getattr(module, piece)
            return module_name
    raise ValueError(  # pragma: no cover
        "Unable to find module to import for '{}'.".format(piece))