Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Measures time processing for ONNX models.
4"""
5import pickle
6import os
7import sklearn
8from ...tools.ort_wrapper import InferenceSession, OrtFail
9from .. import OnnxInference
10from .validate_helper import default_time_kwargs, measure_time, _multiply_time_kwargs
11from .validate_benchmark import make_n_rows
class SimplifiedOnnxInference:
    """
    Minimal adapter exposing an *InferenceSession* through the two
    members this module relies on from *OnnxInference*:
    ``input_names`` and ``run``.
    """

    def __init__(self, ort):
        """
        @param      ort     value forwarded unchanged to *InferenceSession*
        """
        self.sess = InferenceSession(ort)

    @property
    def input_names(self):
        """
        Lists the ``name`` attribute of every input reported by
        the wrapped session, in the order the session returns them.
        """
        names = []
        for node in self.sess.get_inputs():
            names.append(node.name)
        return names

    def run(self, input):
        """
        Forwards *input* to ``InferenceSession.run``, passing *None*
        as the first argument, and returns its result unchanged.
        """
        feeds = input
        return self.sess.run(None, feeds)
def enumerate_benchmark_replay(folder, runtime='python', time_kwargs=None,
                               skip_long_test=True, time_kwargs_fact=None,
                               time_limit=4, verbose=1, fLOG=None):
    """
    Replays a benchmark stored with function
    :func:`enumerate_validated_operator_opsets
    <mlprodict.onnxrt.validate.validate.enumerate_validated_operator_opsets>`
    or command line :ref:`validate_runtime <l-cmd-validate_runtime>`.
    Enumerates the results.

    @param      folder              folder where to find pickled files, all files must have
                                    *pkl* or *pickle* extension
    @param      runtime             runtime or runtimes, a list or a
                                    comma separated string
    @param      time_kwargs         to define a more precise way to measure a model,
                                    *None* or an empty string selects
                                    ``default_time_kwargs()``
    @param      skip_long_test      skips tests for high values of N if they seem too long
                                    (currently not used by this function)
    @param      time_kwargs_fact    see :func:`_multiply_time_kwargs <mlprodict.onnxrt.validate.validate_helper._multiply_time_kwargs>`
    @param      time_limit          to skip the rest of the test after this limit (in second)
                                    (currently not used by this function)
    @param      verbose             if >= 1, uses :epkg:`tqdm`, higher values
                                    enable more logging through *fLOG*
    @param      fLOG                logging function
    @return                         iterator on results, one dictionary per
                                    replayed file

    Every yielded dictionary contains keys ``'folder'``, ``'filename'``,
    the metadata copied from ``obj['obs_op']`` and, for each number
    of rows *N* in the time settings, ``'N-skl'``, ``'N-skl-details'``,
    ``'N-<runtime>'``, ``'N-<runtime>-detail'``. Key ``'ERROR'`` is added
    when a runtime could not be created.

    The function is a generator: *FileNotFoundError* is raised at the
    first iteration if the folder contains no pickled file. Files whose
    name contains ``'ERROR'`` are skipped.
    """
    # str.endswith accepts a tuple: any file with extension .pkl or .pickle
    # is selected (the previous filter only matched names ending with
    # '_.pickle'). The list is sorted because os.listdir returns names
    # in arbitrary order and a replay should be reproducible.
    files = sorted(
        name for name in os.listdir(folder)
        if name.endswith((".pkl", ".pickle")))
    if not files:
        raise FileNotFoundError(
            "Unable to find any file in folder '{}'.".format(folder))

    if time_kwargs in (None, ''):
        time_kwargs = default_time_kwargs()

    if isinstance(runtime, str):
        runtime = runtime.split(",")

    loop = files
    if verbose >= 1:
        try:
            from tqdm import tqdm
            loop = tqdm(files)
        except ImportError:  # pragma: no cover
            pass

    for pkl in loop:
        if "ERROR" in pkl:
            # The file stores a failure, there is nothing to replay.
            if verbose >= 2 and fLOG is not None:  # pragma: no cover
                fLOG(  # pragma: no cover
                    "[enumerate_benchmark_replay] skip '{}'.".format(pkl))
            continue  # pragma: no cover
        if verbose >= 2 and fLOG is not None:
            fLOG("[enumerate_benchmark_replay] process '{}'.".format(pkl))

        row = {}
        # SECURITY: pickle.load executes arbitrary code stored in the file,
        # *folder* must only contain files coming from a trusted source.
        with open(os.path.join(folder, pkl), 'rb') as f:
            obj = pickle.load(f)
        X_test = obj['X_test']
        ort_test = obj['Xort_test']
        onx = obj['onnx_bytes']
        model = obj['skl_model']
        tkw = _multiply_time_kwargs(time_kwargs, time_kwargs_fact, model)
        row['folder'] = folder
        row['filename'] = pkl
        # Overwritten below by obj['obs_op']['n_features'].
        row['n_features'] = X_test.shape[1]

        for key in ['assume_finite', 'conv_options',
                    'init_types', 'idtype', 'method_name', 'n_features',
                    'name', 'optim', 'opset', 'predict_kwargs',
                    'output_index', 'problem', 'scenario']:
            row[key] = obj['obs_op'][key]

        # One inference object per runtime, None if it could not be created.
        oinfs = {}
        for rt in runtime:
            try:
                if rt == 'onnxruntime':
                    oinfs[rt] = SimplifiedOnnxInference(onx)
                else:
                    oinfs[rt] = OnnxInference(onx, runtime=rt)
            except (OrtFail, RuntimeError) as e:  # pragma: no cover
                row['ERROR'] = str(e)
                oinfs[rt] = None

        # k is the number of rows, v holds 'number' and 'repeat'.
        for k, v in sorted(tkw.items()):
            if verbose >= 3 and fLOG is not None:
                fLOG(  # pragma: no cover
                    "[enumerate_benchmark_replay] process n_rows={} - {}".format(k, v))
            xt = make_n_rows(X_test, k)
            number = v['number']
            repeat = v['repeat']

            # Baseline: the scikit-learn model itself.
            meth = getattr(model, row['method_name'])
            with sklearn.config_context(assume_finite=row['assume_finite']):
                skl = measure_time(lambda x: meth(x), xt,
                                   number=number, repeat=repeat,
                                   div_by_number=True)
            if verbose >= 4 and fLOG is not None:
                fLOG(  # pragma: no cover
                    "[enumerate_benchmark_replay] skl={}".format(skl))
            row['%d-skl-details' % k] = skl
            row['%d-skl' % k] = skl['average']

            # Same measure for every runtime which was successfully created.
            xto = make_n_rows(ort_test, k)
            for rt in runtime:
                oinf = oinfs[rt]
                if oinf is None:
                    continue  # pragma: no cover
                if len(oinf.input_names) != 1:
                    raise NotImplementedError(  # pragma: no cover
                        "This function only allows one input not {}".format(
                            len(oinf.input_names)))
                name = oinf.input_names[0]
                ort = measure_time(lambda x: oinf.run({name: x}), xto,
                                   number=number, repeat=repeat,
                                   div_by_number=True)
                if verbose >= 4 and fLOG is not None:
                    fLOG(  # pragma: no cover
                        "[enumerate_benchmark_replay] {}={}".format(rt, ort))
                row['%d-%s-detail' % (k, rt)] = ort
                row['%d-%s' % (k, rt)] = ort['average']
        yield row