Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Implements a benchmark about performance.
4"""
5import os
6import pickle
7from time import perf_counter as time_perf
8import textwrap
9import numpy
10from .bench_helper import enumerate_options
class BenchPerfTest:
    """
    Defines a bench perf test.
    See example :ref:`l-bench-slk-poly`.

    .. faqref::
        :title: Conventions for N, dim

        In all the package, *N* refers to the number of observations,
        *dim* the dimension or the number of features.
    """

    def __init__(self, **kwargs):
        # Every keyword argument becomes an attribute of the instance
        # (e.g. *dump_folder* used by method *dump_error*).
        for k, v in kwargs.items():
            setattr(self, k, v)

    def data(self, **opts):
        """
        Generates one testing dataset.

        @return dataset, usually a list of arrays
            such as *X, y*
        """
        raise NotImplementedError()

    def fcts(self, **opts):
        """
        Returns the function call to test,
        it produces a dictionary ``{name: fct}``
        where *name* is the name of the function
        and *fct* the function to benchmark.
        """
        raise NotImplementedError()

    def validate(self, results, **kwargs):
        """
        Runs validations after the test was done
        to make sure it was valid.

        @param results results to validate, list of tuple
            ``(parameters, results)``
        @param kwargs additional information in case
            errors must be traced

        The function raises an exception or not.
        """
        pass

    def dump_error(self, msg, **kwargs):
        """
        Dumps everything which is needed to investigate an error.
        Everything is pickled in the current folder or *dump_folder*
        if attribute *dump_folder* was defined. This folder is created
        if it does not exist.

        @param msg message
        @param kwargs needed data to investigate
        @return filename of the pickle file which was written
        """
        dump_folder = getattr(self, "dump_folder", '.')
        if not os.path.exists(dump_folder):
            os.makedirs(dump_folder)  # pragma: no cover
        pattern = os.path.join(
            dump_folder, "BENCH-ERROR-{0}-%d.pkl".format(
                self.__class__.__name__))
        # Picks the first filename which does not exist yet so previous
        # dumps are never overwritten.
        err = 0
        name = pattern % err
        while os.path.exists(name):
            err += 1
            name = pattern % err
        # We remove known objects which cannot be pickled
        # (an onnxruntime InferenceSession is replaced by its
        # string representation).
        rem = []
        for k, v in kwargs.items():
            if "InferenceSession" in str(v):
                rem.append(k)
        for k in rem:
            kwargs[k] = str(kwargs[k])
        with open(name, "wb") as f:
            pickle.dump({'msg': msg, 'data': kwargs}, f)
        # FIX: the docstring always promised the filename but the
        # function previously returned None.
        return name
class BenchPerf:
    """
    Factorizes code to compare two implementations.
    See example :ref:`l-bench-slk-poly`.
    """

    def __init__(self, pbefore, pafter, btest, filter_test=None,
                 profilers=None):
        """
        @param pbefore parameters before calling *fct*,
            dictionary ``{name: [list of values]}``,
            these parameters are sent to the instance
            of @see cl BenchPerfTest to test
        @param pafter parameters after calling *fct*,
            dictionary ``{name: [list of values]}``,
            these parameters are sent to method
            :meth:`BenchPerfTest.fcts
            <pymlbenchmark.benchmark.benchmark_perf.BenchPerfTest.fcts>`
        @param btest a @see cl BenchPerfTest class or factory,
            it is called with the *pbefore* parameters in
            *enumerate_run_benchs*
        @param filter_test function which tells if a configuration
            must be tested or not, None to test them
            all
        @param profilers list of profilers to run

        Every parameter which specifies a function is called through
        a method. The user can only overwrite it.
        """
        self.pbefore = pbefore
        self.pafter = pafter
        self.btest = btest
        self.filter_test = filter_test
        self.profilers = profilers

    def __repr__(self):
        "usual"
        # NOTE(review): the extraction collapsed whitespace inside the
        # string literal below; the subsequent indent is assumed to be
        # four spaces — confirm against the repository.
        return '\n'.join(textwrap.wrap(
            "%s(pbefore=%r, pafter=%r, btest=%r, filter_test=%r, profilers=%r)" % (
                self.__class__.__name__, self.pbefore, self.pafter, self.btest,
                self.filter_test, self.profilers),
            subsequent_indent='    '))

    def fct_filter_test(self, **conf):
        """
        Tells if the test defined by *conf* is valid or not.

        @param conf dictionary ``{name: value}``
        @return boolean, always True when no *filter_test*
            function was given
        """
        if self.filter_test is None:
            return True
        return self.filter_test(**conf)

    def enumerate_tests(self, options):
        """
        Enumerates all possible options.

        @param options dictionary ``{name: list of values}``
        @return iterator on dictionaries ``{name: value}``

        The function applies the method *fct_filter_test*.
        """
        for row in enumerate_options(options, self.fct_filter_test):
            yield row

    def enumerate_run_benchs(self, repeat=10, verbose=False,
                             stop_if_error=True, validate=True,
                             number=1):
        """
        Runs the benchmark.

        @param repeat number of repetitions of the same call
            with different datasets
        @param verbose if True, use :epkg:`tqdm`
        @param stop_if_error by default, it stops when method *validate*
            fails, if False, the function stores the exception
        @param validate compare the outputs against the baseline
        @param number number of times to call the same function,
            the method then measures this number of calls
        @return yields dictionaries with all the metrics
        """
        # The combined grid is materialized only to know how many
        # configurations exist so the progress bar gets a length.
        all_opts = self.pbefore.copy()
        all_opts.update(self.pafter)
        all_tests = list(self.enumerate_tests(all_opts))

        if verbose:
            from tqdm import tqdm  # pylint: disable=C0415
            loop = iter(tqdm(range(len(all_tests))))
        else:
            loop = iter(all_tests)

        # Outer loop: parameters sent to the test constructor.
        for a_opt in self.enumerate_tests(self.pbefore):
            if not self.fct_filter_test(**a_opt):
                continue

            # btest is called here, hence it must be a class or a
            # factory, not an already built instance.
            inst = self.btest(**a_opt)

            # Inner loop: parameters sent to method *fcts*.
            for b_opt in self.enumerate_tests(self.pafter):
                obs = b_opt.copy()
                obs.update(a_opt)
                if not self.fct_filter_test(**obs):
                    continue

                fcts = inst.fcts(**obs)
                if not isinstance(fcts, list):
                    raise TypeError(  # pragma: no cover
                        "Method fcts must return a list of dictionaries (name, fct) "
                        "not {}".format(fcts))

                # One freshly generated dataset per repetition.
                data = [inst.data(**obs) for r in range(repeat)]
                # NOTE(review): *data* is built by the comprehension just
                # above so this check can never fail; it presumably meant
                # to validate the output of *inst.data* — confirm.
                if not isinstance(data, (list, tuple)):
                    raise ValueError(  # pragma: no cover
                        "Method *data* must return a list or a tuple.")
                obs["repeat"] = len(data)
                obs["number"] = number
                results = []
                stores = []

                for fct in fcts:
                    if not isinstance(fct, dict) or 'fct' not in fct:
                        raise ValueError(  # pragma: no cover
                            "Method fcts must return a list of dictionaries with keys "
                            "('name', 'fct') not {}".format(fct))
                    f = fct['fct']
                    # The callable is removed so *fct* can be reused as a
                    # plain dictionary of metrics below.
                    del fct['fct']
                    times = []
                    fct.update(obs)

                    if isinstance(f, tuple):
                        # A pair (f1, f2): f1 is a conversion step run
                        # outside the timed section, f2 is the function
                        # actually benchmarked.
                        if len(f) != 2:
                            raise RuntimeError(  # pragma: no cover
                                "If *f* is a tuple, it must return two function f1, f2.")
                        f1, f2 = f
                        dt = data[0]
                        dt2 = f1(*dt)
                        # Profilers (if any) run once on the first dataset.
                        self.profile(fct, lambda: f2(*dt2))
                        for idt, dt in enumerate(data):
                            dt2 = f1(*dt)
                            if number == 1:
                                st = time_perf()
                                r = f2(*dt2)
                                d = time_perf() - st
                            else:
                                # Measures *number* consecutive calls in
                                # one timing window.
                                st = time_perf()
                                for _ in range(number):
                                    r = f2(*dt2)
                                d = time_perf() - st
                            times.append(d)
                            results.append((idt, fct, r))
                    else:
                        dt = data[0]
                        self.profile(fct, lambda: f(*dt))
                        for idt, dt in enumerate(data):
                            if number == 1:
                                st = time_perf()
                                r = f(*dt)
                                d = time_perf() - st
                            else:
                                st = time_perf()
                                for _ in range(number):
                                    r = f(*dt)
                                d = time_perf() - st
                            times.append(d)
                            results.append((idt, fct, r))

                    # Aggregated metrics on the sorted raw timings.
                    times.sort()
                    fct['min'] = times[0]
                    fct['max'] = times[-1]
                    if len(times) > 5:
                        # NOTE(review): asymmetric trimming — times[3] is
                        # the 4th smallest but times[-3] the 3rd largest;
                        # confirm this is intended.
                        fct['min3'] = times[3]
                        fct['max3'] = times[-3]
                    times = numpy.array(times)
                    fct['mean'] = times.mean()
                    std = times.std()
                    if len(times) >= 4:
                        # 95% confidence interval (1.96 sigma) clipped to
                        # the observed extremes.
                        fct['lower'] = max(
                            fct['min'], fct['mean'] - std * 1.96)
                        fct['upper'] = min(
                            fct['max'], fct['mean'] + std * 1.96)
                    else:
                        fct['lower'] = fct['min']
                        fct['upper'] = fct['max']
                    fct['count'] = len(times)
                    fct['median'] = numpy.median(times)
                    stores.append(fct)

                if validate:
                    if stop_if_error:
                        # Let any validation exception propagate.
                        up = inst.validate(results, data=data)
                    else:
                        try:
                            up = inst.validate(results, data=data)
                        except Exception as e:  # pylint: disable=W0703
                            # Newlines and commas are stripped so the
                            # message stays on one CSV-friendly line.
                            msg = str(e).replace("\n", " ").replace(",", " ")
                            up = {'error': msg, 'error_c': 1}
                    if up is not None:
                        # Validation returned extra metrics (or an error
                        # record): merge them into every row.
                        for fct in stores:
                            fct.update(up)
                    else:
                        for fct in stores:
                            fct['error_c'] = 0
                for fct in stores:
                    yield fct
                # Advances the progress bar (or the plain iterator).
                next(loop)  # pylint: disable=R1708

    def profile(self, kwargs, fct):
        """
        Checks if a profiler applies on this set
        of parameters, then profiles function *fct*.

        @param kwargs dictionary of parameters
        @param fct function to measure
        """
        if self.profilers:
            for prof in self.profilers:
                if prof.match(**kwargs):
                    prof.profile(fct, **kwargs)