Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file Functions to create a benchmark based on :epkg:`asv`
3for many regressors and classifiers.
4"""
5import os
6import sys
7import json
8import textwrap
9import warnings
10import re
11from pyquickhelper.pycode.code_helper import remove_extra_spaces_and_pep8
12try:
13 from ._create_asv_helper import (
14 default_asv_conf,
15 flask_helper,
16 pyspy_template,
17 _handle_init_files,
18 _asv_class_name,
19 _read_patterns,
20 _select_pattern_problem,
21 _display_code_lines,
22 add_model_import_init,
23 find_missing_sklearn_imports)
24except ImportError: # pragma: no cover
25 from mlprodict.asv_benchmark._create_asv_helper import (
26 default_asv_conf,
27 flask_helper,
28 pyspy_template,
29 _handle_init_files,
30 _asv_class_name,
31 _read_patterns,
32 _select_pattern_problem,
33 _display_code_lines,
34 add_model_import_init,
35 find_missing_sklearn_imports)
37try:
38 from ..tools.asv_options_helper import (
39 get_opset_number_from_onnx, shorten_onnx_options)
40 from ..onnxrt.validate.validate_helper import sklearn_operators
41 from ..onnxrt.validate.validate import (
42 _retrieve_problems_extra, _get_problem_data, _merge_options)
43except (ValueError, ImportError): # pragma: no cover
44 from mlprodict.tools.asv_options_helper import get_opset_number_from_onnx
45 from mlprodict.onnxrt.validate.validate_helper import sklearn_operators
46 from mlprodict.onnxrt.validate.validate import (
47 _retrieve_problems_extra, _get_problem_data, _merge_options)
48 from mlprodict.tools.asv_options_helper import shorten_onnx_options
49try:
50 from ..testing.verify_code import verify_code
51except (ValueError, ImportError): # pragma: no cover
52 from mlprodict.testing.verify_code import verify_code
54# exec function does not import models but potentially
55# requires all specific models used to define scenarios
56try:
57 from ..onnxrt.validate.validate_scenarios import * # pylint: disable=W0614,W0401
58except (ValueError, ImportError): # pragma: no cover
59 # Skips this step if used in a benchmark.
60 pass
def create_asv_benchmark(
        location, opset_min=-1, opset_max=None,
        runtime=('scikit-learn', 'python_compiled'), models=None,
        skip_models=None, extended_list=True,
        dims=(1, 10, 100, 10000),
        n_features=(4, 20), dtype=None,
        verbose=0, fLOG=print, clean=True,
        conf_params=None, filter_exp=None,
        filter_scenario=None, flat=False,
        exc=False, build=None, execute=False,
        add_pyspy=False, env=None,
        matrix=None):
    """
    Creates an :epkg:`asv` benchmark in a folder
    but does not run it.

    :param location: folder to write the benchmark into
    :param dims: number of observations to try
    :param verbose: integer from 0 (None) to 2 (full verbose)
    :param opset_min: tries every conversion from this minimum opset,
        -1 to get the current opset defined by module :epkg:`onnx`
    :param opset_max: tries every conversion up to maximum opset,
        -1 to get the current opset defined by module :epkg:`onnx`
    :param runtime: runtime to check, *scikit-learn*, *python*,
        *python_compiled* compiles the graph structure
        and is more efficient when the number of observations is
        small, *onnxruntime1* to check :epkg:`onnxruntime`,
        *onnxruntime2* to check every ONNX node independently
        with onnxruntime, many runtime can be checked at the same time
        if the value is a comma separated list
    :param models: list of models to test or empty
        string to test them all
    :param skip_models: models to skip
    :param extended_list: extends the list of :epkg:`scikit-learn` converters
        with converters implemented in this module
    :param n_features: number of features to try, it changes the default
        number of features for a specific problem,
        it can also be a comma separated list
    :param dtype: '32' or '64' or None for both,
        limits the test to one specific number types
    :param fLOG: logging function
    :param clean: clean the folder first, otherwise overwrites the content
    :param conf_params: to overwrite some of the configuration parameters
    :param filter_exp: function which tells if the experiment must be run,
        None to run all, takes *model, problem* as an input
    :param filter_scenario: second function which tells if the experiment must be run,
        None to run all, takes *model, problem, scenario, extra*
        as an input
    :param flat: one folder for all files or subfolders
    :param exc: if False, raises warnings instead of exceptions
        whenever possible
    :param build: where to put the outputs
    :param execute: execute each script to make sure
        imports are correct
    :param add_pyspy: add an extra folder with code to profile
        each configuration
    :param env: None to use the default configuration or ``same`` to use
        the current one
    :param matrix: specifies versions for a module,
        example: ``{'onnxruntime': ['1.1.1', '1.1.2']}``,
        if a package name starts with `'~'`, the package is removed
    :return: created files

    The default configuration is the following:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import pprint
        from mlprodict.asv_benchmark.create_asv import default_asv_conf

        pprint.pprint(default_asv_conf)

    The benchmark does not seem to work well with setting
    ``-environment existing:same``. The publishing fails.
    """
    # Resolve -1 to the opset of the installed onnx package.
    if opset_min == -1:
        opset_min = get_opset_number_from_onnx()
    if opset_max == -1:
        opset_max = get_opset_number_from_onnx()  # pragma: no cover
    if verbose > 0 and fLOG is not None:  # pragma: no cover
        fLOG("[create_asv_benchmark] opset in [{}, {}].".format(
            opset_min, opset_max))

    # creates the folder if it does not exist.
    if not os.path.exists(location):
        if verbose > 0 and fLOG is not None:  # pragma: no cover
            fLOG("[create_asv_benchmark] create folder '{}'.".format(location))
        os.makedirs(location)  # pragma: no cover

    # Benchmark classes are written into the 'benches' subfolder.
    location_test = os.path.join(location, 'benches')
    if not os.path.exists(location_test):
        if verbose > 0 and fLOG is not None:
            fLOG("[create_asv_benchmark] create folder '{}'.".format(location_test))
        os.mkdir(location_test)

    # Cleans the content of the folder (files only, subfolders are kept).
    created = []
    if clean:
        for name in os.listdir(location_test):
            full_name = os.path.join(location_test, name)  # pragma: no cover
            if os.path.isfile(full_name):  # pragma: no cover
                os.remove(full_name)

    # configuration: start from the default asv configuration and
    # overwrite with user choices (conf_params, build, env, matrix).
    conf = default_asv_conf.copy()
    if conf_params is not None:
        for k, v in conf_params.items():
            conf[k] = v
    if build is not None:
        for fi in ['env_dir', 'results_dir', 'html_dir']:  # pragma: no cover
            conf[fi] = os.path.join(build, conf[fi])
    if env == 'same':
        # Reuse the current python environment, no version matrix allowed.
        if matrix is not None:
            raise ValueError(  # pragma: no cover
                "Parameter matrix must be None if env is 'same'.")
        conf['pythons'] = ['same']
        conf['matrix'] = {}
    elif matrix is not None:
        # A leading '~' in a package name removes it from the matrix.
        drop_keys = set(p for p in matrix if p.startswith('~'))
        matrix = {k: v for k, v in matrix.items() if k not in drop_keys}
        conf['matrix'] = {k: v for k,
                          v in conf['matrix'].items() if k not in drop_keys}
        conf['matrix'].update(matrix)
    elif env is not None:
        raise ValueError(  # pragma: no cover
            "Unable to handle env='{}'.".format(env))
    dest = os.path.join(location, "asv.conf.json")
    created.append(dest)
    with open(dest, "w", encoding='utf-8') as f:
        json.dump(conf, f, indent=4)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create 'asv.conf.json'.")

    # __init__.py: empty package markers so asv can import the benches.
    dest = os.path.join(location, "__init__.py")
    with open(dest, "w", encoding='utf-8') as f:
        pass
    created.append(dest)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create '__init__.py'.")
    dest = os.path.join(location_test, '__init__.py')
    with open(dest, "w", encoding='utf-8') as f:
        pass
    created.append(dest)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create 'benches/__init__.py'.")

    # flask_server: helper script to serve the published results.
    tool_dir = os.path.join(location, 'tools')
    if not os.path.exists(tool_dir):
        os.mkdir(tool_dir)
    fl = os.path.join(tool_dir, 'flask_serve.py')
    with open(fl, "w", encoding='utf-8') as f:
        f.write(flask_helper)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create 'flask_serve.py'.")

    # command line: small shell/batch script running the whole pipeline
    # (run benchmark, publish HTML, export CSV).
    if sys.platform.startswith("win"):
        run_bash = os.path.join(tool_dir, 'run_asv.bat')  # pragma: no cover
    else:
        run_bash = os.path.join(tool_dir, 'run_asv.sh')
    with open(run_bash, 'w') as f:
        f.write(textwrap.dedent("""
            echo --BENCHRUN--
            python -m asv run --show-stderr --config ./asv.conf.json
            echo --PUBLISH--
            python -m asv publish --config ./asv.conf.json -o ./html
            echo --CSV--
            python -m mlprodict asv2csv -f ./results -o ./data_bench.csv
            """))

    # pyspy: optional folder receiving py-spy profiling scripts.
    if add_pyspy:
        dest_pyspy = os.path.join(location, 'pyspy')
        if not os.path.exists(dest_pyspy):
            os.mkdir(dest_pyspy)
    else:
        dest_pyspy = None

    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create all tests.")

    # Generates one benchmark file per (model, problem, scenario).
    created.extend(list(_enumerate_asv_benchmark_all_models(
        location_test, opset_min=opset_min, opset_max=opset_max,
        runtime=runtime, models=models,
        skip_models=skip_models, extended_list=extended_list,
        n_features=n_features, dtype=dtype,
        verbose=verbose, filter_exp=filter_exp,
        filter_scenario=filter_scenario,
        dims=dims, exc=exc, flat=flat,
        fLOG=fLOG, execute=execute,
        dest_pyspy=dest_pyspy)))

    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] done.")
    return created
def _enumerate_asv_benchmark_all_models(  # pylint: disable=R0914
        location, opset_min=10, opset_max=None,
        runtime=('scikit-learn', 'python'), models=None,
        skip_models=None, extended_list=True,
        n_features=None, dtype=None,
        verbose=0, filter_exp=None,
        dims=None, filter_scenario=None,
        exc=True, flat=False, execute=False,
        dest_pyspy=None, fLOG=print):
    """
    Loops over all possible models and fills a folder
    with benchmarks following :epkg:`asv` concepts.

    :param location: folder the benchmark files are written into
    :param dims: number of observations to try
    :param verbose: integer from 0 (None) to 2 (full verbose)
    :param opset_min: tries every conversion from this minimum opset
    :param opset_max: tries every conversion up to maximum opset
    :param runtime: runtime to check, *scikit-learn*, *python*,
        *onnxruntime1* to check :epkg:`onnxruntime`,
        *onnxruntime2* to check every ONNX node independently
        with onnxruntime, many runtime can be checked at the same time
        if the value is a comma separated list
    :param models: list of models to test or empty
        string to test them all
    :param skip_models: models to skip
    :param extended_list: extends the list of :epkg:`scikit-learn` converters
        with converters implemented in this module
    :param n_features: number of features to try, it changes the default
        number of features for a specific problem,
        it can also be a comma separated list
    :param dtype: '32' or '64' or None for both,
        limits the test to one specific number types
    :param fLOG: logging function
    :param filter_exp: function which tells if the experiment must be run,
        None to run all, takes *model, problem* as an input
    :param filter_scenario: second function which tells if the experiment must be run,
        None to run all, takes *model, problem, scenario, extra*
        as an input
    :param exc: if False, raises warnings instead of exceptions
        whenever possible
    :param flat: one folder for all files or subfolders
    :param execute: execute each script to make sure
        imports are correct
    :param dest_pyspy: add a file to profile the prediction
        function with :epkg:`pyspy`
    :return: generator yielding every created file name
    """
    ops = [_ for _ in sklearn_operators(extended=extended_list)]
    patterns = _read_patterns()

    if models is not None:
        if not all(map(lambda m: isinstance(m, str), models)):
            raise ValueError(
                "models must be a set of strings.")  # pragma: no cover
        ops_ = [_ for _ in ops if _['name'] in models]
        # Check the FILTERED list: if no known operator matches *models*,
        # the parameter is wrong and the benchmark would silently be empty.
        if len(ops_) == 0:
            raise ValueError("Parameter models is wrong: {}\n{}".format(  # pragma: no cover
                models, ops[0]))
        ops = ops_
    if skip_models is not None:
        ops = [m for m in ops if m['name'] not in skip_models]

    if verbose > 0:

        def iterate():
            # Fallback progress display when tqdm is not available.
            for i, row in enumerate(ops):  # pragma: no cover
                fLOG("{}/{} - {}".format(i + 1, len(ops), row))
                yield row

        if verbose >= 11:
            verbose -= 10  # pragma: no cover
            loop = iterate()  # pragma: no cover
        else:
            try:
                from tqdm import trange

                def iterate_tqdm():
                    # Progress bar labelled with the current model name.
                    with trange(len(ops)) as t:
                        for i in t:
                            row = ops[i]
                            disp = row['name'] + " " * (28 - len(row['name']))
                            t.set_description("%s" % disp)
                            yield row

                loop = iterate_tqdm()

            except ImportError:  # pragma: no cover
                loop = iterate()
    else:
        loop = ops

    if opset_max is None:
        opset_max = get_opset_number_from_onnx()
    opsets = list(range(opset_min, opset_max + 1))
    all_created = set()

    # loop on all models
    for row in loop:

        model = row['cl']

        problems, extras = _retrieve_problems_extra(
            model, verbose, fLOG, extended_list)
        if extras is None or problems is None:
            # Not tested yet.
            continue  # pragma: no cover

        # flat or not flat
        created, location_model, prefix_import, dest_pyspy_model = _handle_init_files(
            model, flat, location, verbose, dest_pyspy, fLOG)
        for init in created:
            yield init

        # loops on problems
        for prob in problems:
            if filter_exp is not None and not filter_exp(model, prob):
                continue

            (X_train, X_test, y_train,
             y_test, Xort_test,
             init_types, conv_options, method_name,
             output_index, dofit, predict_kwargs) = _get_problem_data(prob, None)

            for scenario_extra in extras:
                subset_problems = None
                optimisations = None
                new_conv_options = None

                # A third element in the scenario tuple restricts the
                # problems or adds conversion options / optimisations.
                if len(scenario_extra) > 2:
                    options = scenario_extra[2]
                    if isinstance(options, dict):
                        subset_problems = options.get('subset_problems', None)
                        optimisations = options.get('optim', None)
                        new_conv_options = options.get('conv_options', None)
                    else:
                        subset_problems = options

                if subset_problems and isinstance(subset_problems, (list, set)):
                    if prob not in subset_problems:
                        # Skips unrelated problem for a specific configuration.
                        continue
                elif subset_problems is not None:
                    raise RuntimeError(  # pragma: no cover
                        "subset_problems must be a set or a list not {}.".format(
                            subset_problems))

                scenario, extra = scenario_extra[:2]
                if optimisations is None:
                    optimisations = [None]
                if new_conv_options is None:
                    new_conv_options = [{}]

                if (filter_scenario is not None and
                        not filter_scenario(model, prob, scenario,
                                            extra, new_conv_options)):
                    continue  # pragma: no cover

                if verbose >= 3 and fLOG is not None:
                    fLOG("[create_asv_benchmark] model={} scenario={} optim={} extra={} dofit={} (problem={} method_name='{}')".format(
                        model.__name__, scenario, optimisations, extra, dofit, prob, method_name))
                created = _create_asv_benchmark_file(
                    location_model, opsets=opsets,
                    model=model, scenario=scenario, optimisations=optimisations,
                    extra=extra, dofit=dofit, problem=prob,
                    runtime=runtime, new_conv_options=new_conv_options,
                    X_train=X_train, X_test=X_test, y_train=y_train,
                    y_test=y_test, Xort_test=Xort_test,
                    init_types=init_types, conv_options=conv_options,
                    method_name=method_name, dims=dims, n_features=n_features,
                    output_index=output_index, predict_kwargs=predict_kwargs,
                    exc=exc, prefix_import=prefix_import,
                    execute=execute, location_pyspy=dest_pyspy_model,
                    patterns=patterns)
                for cr in created:
                    # A duplicated file name means two configurations
                    # collapsed to the same class name: fail loudly.
                    if cr in all_created:
                        raise RuntimeError(  # pragma: no cover
                            "File '{}' was already created.".format(cr))
                    all_created.add(cr)
                    if verbose > 1 and fLOG is not None:
                        fLOG("[create_asv_benchmark] add '{}'.".format(cr))
                    yield cr
def _create_asv_benchmark_file(  # pylint: disable=R0914
        location, model, scenario, optimisations, new_conv_options,
        extra, dofit, problem, runtime, X_train, X_test, y_train,
        y_test, Xort_test, init_types, conv_options,
        method_name, n_features, dims, opsets,
        output_index, predict_kwargs, prefix_import,
        exc, execute=False, location_pyspy=None, patterns=None):
    """
    Creates a benchmark file based in the information received
    through the argument. It uses one of the templates
    like @see cl TemplateBenchmarkClassifier or
    @see cl TemplateBenchmarkRegressor.

    The template source is rewritten with plain string substitution:
    the keys of *rep* below must appear verbatim in the template.
    Returns the list of written file names.
    """
    if patterns is None:
        raise ValueError("Patterns list is empty.")  # pragma: no cover

    def format_conv_options(d_options, class_name):
        # Replaces class-type keys in the conversion options by the
        # class name so the dictionary can be written as source code.
        if d_options is None:
            return None
        res = {}
        for k, v in d_options.items():
            if isinstance(k, type):
                if "." + class_name + "'" in str(k):
                    res[class_name] = v
                    continue
                raise ValueError(  # pragma: no cover
                    "Class '{}', unable to format options {}".format(
                        class_name, d_options))
            res[k] = v
        return res

    def _nick_name_options(model, opts):
        # Shorten common onnx options, see _CommonAsvSklBenchmark._to_onnx.
        # Class keys are wrapped in '####' markers, stripped later when
        # the generated code is post-processed.
        if opts is None:
            return opts  # pragma: no cover
        short_opts = shorten_onnx_options(model, opts)
        if short_opts is not None:
            return short_opts
        res = {}
        for k, v in opts.items():
            if hasattr(k, '__name__'):
                res["####" + k.__name__ + "####"] = v
            else:
                res[k] = v  # pragma: no cover
        return res

    def _make_simple_name(name):
        # Builds a short display name by removing 'bench' markers,
        # underscores and a few known abbreviations.
        simple_name = name.replace("bench_", "").replace("_bench", "")
        simple_name = simple_name.replace("bench.", "").replace(".bench", "")
        simple_name = simple_name.replace(".", "-")
        repl = {'_': '', 'solverliblinear': 'liblinear'}
        for k, v in repl.items():
            simple_name = simple_name.replace(k, v)
        return simple_name

    def _optdict2string(opt):
        # Serializes an option dictionary into a short string usable
        # inside a file name (recursive for nested dictionaries).
        if isinstance(opt, str):
            return opt
        if isinstance(opt, list):
            raise TypeError(
                "Unable to process type %r." % type(opt))
        reps = {True: 1, False: 0, 'zipmap': 'zm',
                'optim': 'opt'}
        info = []
        for k, v in sorted(opt.items()):
            if isinstance(v, dict):
                v = _optdict2string(v)
            if k.startswith('####'):
                k = ''
            i = '{}{}'.format(reps.get(k, k), reps.get(v, v))
            info.append(i)
        return "-".join(info)

    # Abbreviations used for runtimes in the generated class names.
    runtimes_abb = {
        'scikit-learn': 'skl',
        'onnxruntime1': 'ort',
        'onnxruntime2': 'ort2',
        'python': 'pyrt',
        'python_compiled': 'pyrtc',
    }
    runtime = [runtimes_abb[k] for k in runtime]

    # Looping over configuration.
    names = []
    for optimisation in optimisations:
        merged_options = [_merge_options(nconv_options, conv_options)
                          for nconv_options in new_conv_options]

        nck_opts = [_nick_name_options(model, opts)
                    for opts in merged_options]
        try:
            name = _asv_class_name(
                model, scenario, optimisation, extra,
                dofit, conv_options, problem,
                shorten=True)
        except ValueError as e:  # pragma: no cover
            if exc:
                raise e
            warnings.warn(str(e))
            continue
        filename = name.replace(".", "_") + ".py"
        try:
            class_content = _select_pattern_problem(problem, patterns)
        except ValueError as e:
            if exc:
                raise e  # pragma: no cover
            warnings.warn(str(e))
            continue
        full_class_name = _asv_class_name(
            model, scenario, optimisation, extra,
            dofit, conv_options, problem,
            shorten=False)
        class_name = name.replace(
            "bench.", "").replace(".", "_") + "_bench"

        # n_features, N, runtimes: each key must occur verbatim in the
        # template, otherwise the substitution below raises.
        rep = {
            "['skl', 'pyrtc', 'ort'], # values for runtime": str(runtime),
            "[1, 10, 100, 1000, 10000], # values for N": str(dims),
            "[4, 20], # values for nf": str(n_features),
            "[get_opset_number_from_onnx()], # values for opset": str(opsets),
            "['float', 'double'], # values for dtype":
                "['float']" if '-64' not in problem else "['double']",
            "[None], # values for optim": "%r" % nck_opts,
        }
        for k, v in rep.items():
            if k not in class_content:
                raise ValueError("Unable to find '{}'\n{}.".format(  # pragma: no cover
                    k, class_content))
            class_content = class_content.replace(k, v + ',')
        class_content = class_content.split(
            "def _create_model(self):")[0].strip("\n ")
        # Strip the '####' markers added by _nick_name_options; none
        # may survive in the generated script.
        if "####" in class_content:
            class_content = class_content.replace(
                "'####", "").replace("####'", "")
        if "####" in class_content:
            raise RuntimeError(  # pragma: no cover
                "Substring '####' should not be part of the script for '{}'\n{}".format(
                    model.__name__, class_content))

        # Model setup
        class_content, atts = add_model_import_init(
            class_content, model, optimisation,
            extra, merged_options)
        class_content = class_content.replace(
            "class TemplateBenchmark",
            "class {}".format(class_name))

        # dtype, dofit
        atts.append("chk_method_name = %r" % method_name)
        atts.append("par_scenario = %r" % scenario)
        atts.append("par_problem = %r" % problem)
        atts.append("par_optimisation = %r" % optimisation)
        if not dofit:
            atts.append("par_dofit = False")
        if merged_options is not None and len(merged_options) > 0:
            atts.append("par_convopts = %r" % format_conv_options(
                conv_options, model.__name__))
        atts.append("par_full_test_name = %r" % full_class_name)

        simple_name = _make_simple_name(name)
        atts.append("benchmark_name = %r" % simple_name)
        atts.append("pretty_name = %r" % simple_name)

        if atts:
            # Joined with 4 spaces so every attribute lands at class
            # body indentation in the template.
            class_content = class_content.replace(
                "# additional parameters",
                "\n    ".join(atts))
        if prefix_import != '.':
            class_content = class_content.replace(
                " from .", "from .{}".format(prefix_import))

        # Check compilation
        try:
            compile(class_content, filename, 'exec')
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError("Unable to compile model '{}'\n{}".format(
                model.__name__, class_content)) from e

        # Verifies missing imports.
        to_import, _ = verify_code(class_content, exc=False)
        try:
            miss = find_missing_sklearn_imports(to_import)
        except ValueError as e:  # pragma: no cover
            raise ValueError(
                "Unable to check import in script\n{}".format(
                    class_content)) from e
        class_content = class_content.replace(
            "# __IMPORTS__", "\n".join(miss))
        verify_code(class_content, exc=True)
        class_content = class_content.replace(
            "par_extra = {", "par_extra = {\n")
        class_content = remove_extra_spaces_and_pep8(
            class_content, aggressive=True)

        # Check compilation again
        try:
            obj = compile(class_content, filename, 'exec')
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError("Unable to compile model '{}'\n{}".format(
                model.__name__,
                _display_code_lines(class_content))) from e

        # executes to check import
        if execute:
            try:
                exec(obj, globals(), locals())  # pylint: disable=W0122
            except Exception as e:  # pragma: no cover
                raise RuntimeError(
                    "Unable to process class '{}' ('{}') a script due to '{}'\n{}".format(
                        model.__name__, filename, str(e),
                        _display_code_lines(class_content))) from e

        # Saves
        fullname = os.path.join(location, filename)
        names.append(fullname)
        with open(fullname, "w", encoding='utf-8') as f:
            f.write(class_content)

        if location_pyspy is not None:
            # adding configuration for pyspy: one script and one shell
            # wrapper per (dim, nf, opset, dtype, optim) combination.
            class_name = re.compile(
                'class ([A-Za-z_0-9]+)[(]').findall(class_content)[0]
            fullname_pyspy = os.path.splitext(
                os.path.join(location_pyspy, filename))[0]
            pyfold = os.path.splitext(os.path.split(fullname)[-1])[0]

            dtypes = ['float', 'double'] if '-64' in problem else ['float']
            for dim in dims:
                for nf in n_features:
                    for opset in opsets:
                        for dtype in dtypes:
                            for opt in nck_opts:
                                tmpl = pyspy_template.replace(
                                    '__PATH__', location)
                                tmpl = tmpl.replace(
                                    '__CLASSNAME__', class_name)
                                tmpl = tmpl.replace('__PYFOLD__', pyfold)
                                opt = "" if opt == {} else opt

                                # The first runtime also gets a warm-up
                                # profile0_* call to compute *iter*.
                                first = True
                                for rt in runtime:
                                    if first:
                                        tmpl += textwrap.dedent("""

                                        def profile0_{rt}(iter, cl, N, nf, opset, dtype, optim):
                                            return setup_profile0(iter, cl, '{rt}', N, nf, opset, dtype, optim)
                                        iter = profile0_{rt}(iter, cl, {dim}, {nf}, {opset}, '{dtype}', {opt})
                                        print(datetime.now(), "iter", iter)

                                        """).format(rt=rt, dim=dim, nf=nf, opset=opset,
                                                    dtype=dtype, opt="%r" % opt)
                                        first = False

                                    tmpl += textwrap.dedent("""

                                    def profile_{rt}(iter, cl, N, nf, opset, dtype, optim):
                                        return setup_profile(iter, cl, '{rt}', N, nf, opset, dtype, optim)
                                    profile_{rt}(iter, cl, {dim}, {nf}, {opset}, '{dtype}', {opt})
                                    print(datetime.now(), "iter", iter)

                                    """).format(rt=rt, dim=dim, nf=nf, opset=opset,
                                                dtype=dtype, opt="%r" % opt)

                                thename = "{n}_{dim}_{nf}_{opset}_{dtype}_{opt}.py".format(
                                    n=fullname_pyspy, dim=dim, nf=nf,
                                    opset=opset, dtype=dtype, opt=_optdict2string(opt))
                                with open(thename, 'w', encoding='utf-8') as f:
                                    f.write(tmpl)
                                names.append(thename)

                                # Shell/batch wrapper launching py-spy
                                # twice: per-function and per-line SVGs.
                                ext = '.bat' if sys.platform.startswith(
                                    'win') else '.sh'
                                script = os.path.splitext(thename)[0] + ext
                                short = os.path.splitext(
                                    os.path.split(thename)[-1])[0]
                                with open(script, 'w', encoding='utf-8') as f:
                                    f.write('py-spy record --native --function --rate=10 -o {n}_fct.svg -- {py} {n}.py\n'.format(
                                        py=sys.executable, n=short))
                                    f.write('py-spy record --native --rate=10 -o {n}_line.svg -- {py} {n}.py\n'.format(
                                        py=sys.executable, n=short))

    return names