Coverage for src/pymlbenchmark/benchmark/benchmark_perf.py: 99%

149 statements  


1""" 

2@file 

3@brief Implements a benchmark about performance. 

4""" 

5import os 

6import pickle 

7from time import perf_counter as time_perf 

8import textwrap 

9import numpy 

10from .bench_helper import enumerate_options 

11 

12 

class BenchPerfTest:
    """
    Defines a bench perf test.
    See example :ref:`l-bench-slk-poly`.

    .. faqref::
        :title: Conventions for N, dim

        Throughout the package, *N* refers to the number of observations,
        *dim* to the dimension or the number of features.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

    def data(self, **opts):
        """
        Generates one testing dataset.

        @return     dataset, usually a list of arrays
                    such as *X, y*
        """
        raise NotImplementedError()

    def fcts(self, **opts):
        """
        Returns the functions to benchmark as a list of
        dictionaries ``{'name': name, 'fct': fct}`` where
        *name* identifies the function and *fct* is the
        function to benchmark.
        """
        raise NotImplementedError()

    def validate(self, results, **kwargs):
        """
        Runs validations after the test is done
        to make sure it was valid.

        @param      results     results to validate, list of tuples
                                ``(parameters, results)``
        @param      kwargs      additional information in case
                                errors must be traced
        @return                 None, or a dictionary of values merged
                                into the benchmark results

        The function raises an exception if the results are not valid.
        """
        pass

    def dump_error(self, msg, **kwargs):
        """
        Dumps everything which is needed to investigate an error.
        Everything is pickled in the current folder, or in *dump_folder*
        if attribute *dump_folder* was defined. This folder is created
        if it does not exist.

        @param      msg         message
        @param      kwargs      needed data to investigate
        @return                 filename
        """
        dump_folder = getattr(self, "dump_folder", '.')
        if not os.path.exists(dump_folder):
            os.makedirs(dump_folder)  # pragma: no cover
        pattern = os.path.join(
            dump_folder, "BENCH-ERROR-{0}-%d.pkl".format(
                self.__class__.__name__))
        err = 0
        name = pattern % err
        while os.path.exists(name):
            err += 1
            name = pattern % err
        # We remove known objects which cannot be pickled.
        rem = []
        for k, v in kwargs.items():
            if "InferenceSession" in str(v):
                rem.append(k)
        for k in rem:
            kwargs[k] = str(kwargs[k])
        with open(name, "wb") as f:
            pickle.dump({'msg': msg, 'data': kwargs}, f)
        return name
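To make the expected interface concrete, here is a minimal, hypothetical sketch of a BenchPerfTest subclass (the class PolyDotTest and the helper dot_py are illustrative names, not part of the package): data builds one dataset per repetition, fcts returns the list of ``{'name': ..., 'fct': ...}`` dictionaries consumed by BenchPerf, and validate compares the outputs and calls dump_error on a mismatch.

# Hypothetical example, not part of benchmark_perf.py: a BenchPerfTest
# subclass comparing numpy.dot with a pure-Python dot product.
import numpy
from pymlbenchmark.benchmark.benchmark_perf import BenchPerfTest


class PolyDotTest(BenchPerfTest):
    "Benchmarks numpy.dot against a pure-Python dot product."

    def data(self, N=10, dim=4, **opts):
        # One dataset per repetition: two random matrices to multiply.
        return (numpy.random.rand(N, dim), numpy.random.rand(dim, N))

    def fcts(self, **opts):
        # List of dictionaries with keys 'name' and 'fct', as expected
        # by BenchPerf.enumerate_run_benchs.
        def dot_py(a, b):
            return [[sum(x * y for x, y in zip(row, col))
                     for col in zip(*b)] for row in a]
        return [{'name': 'numpy', 'fct': lambda a, b: numpy.dot(a, b)},
                {'name': 'python', 'fct': dot_py}]

    def validate(self, results, **kwargs):
        # results is a list of tuples (dataset_index, parameters, output).
        baseline = {idt: res for idt, fct, res in results
                    if fct['name'] == 'numpy'}
        for idt, fct, res in results:
            if not numpy.allclose(baseline[idt], res):
                self.dump_error("mismatch", results=results, **kwargs)
                raise AssertionError("Mismatch for %r." % fct)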

class BenchPerf:
    """
    Factorizes code to compare two implementations.
    See example :ref:`l-bench-slk-poly`.
    """

    def __init__(self, pbefore, pafter, btest, filter_test=None,
                 profilers=None):
        """
        @param      pbefore         parameters before calling *fct*,
                                    dictionary ``{name: [list of values]}``,
                                    these parameters are sent to the instance
                                    of @see cl BenchPerfTest to test
        @param      pafter          parameters after calling *fct*,
                                    dictionary ``{name: [list of values]}``,
                                    these parameters are sent to method
                                    :meth:`BenchPerfTest.fcts
                                    <pymlbenchmark.benchmark.benchmark_perf.BenchPerfTest.fcts>`
        @param      btest           class inheriting from @see cl BenchPerfTest,
                                    instantiated for every tested configuration
        @param      filter_test     function which tells if a configuration
                                    must be tested or not, None to test them
                                    all
        @param      profilers       list of profilers to run

        Every parameter which specifies a function is called through
        a method so that the user can overwrite it.
        """
        self.pbefore = pbefore
        self.pafter = pafter
        self.btest = btest
        self.filter_test = filter_test
        self.profilers = profilers

    def __repr__(self):
        "usual"
        return '\n'.join(textwrap.wrap(
            "%s(pbefore=%r, pafter=%r, btest=%r, filter_test=%r, profilers=%r)" % (
                self.__class__.__name__, self.pbefore, self.pafter, self.btest,
                self.filter_test, self.profilers),
            subsequent_indent=' '))

    def fct_filter_test(self, **conf):
        """
        Tells if the test defined by *conf* is valid or not.

        @param      conf    dictionary ``{name: value}``
        @return             boolean
        """
        if self.filter_test is None:
            return True
        return self.filter_test(**conf)

    def enumerate_tests(self, options):
        """
        Enumerates all possible options.

        @param      options     dictionary ``{name: list of values}``
        @return                 iterator on dictionaries ``{name: value}``

        The function applies the method *fct_filter_test*.
        """
        for row in enumerate_options(options, self.fct_filter_test):
            yield row

    def enumerate_run_benchs(self, repeat=10, verbose=False,
                             stop_if_error=True, validate=True,
                             number=1):
        """
        Runs the benchmark.

        @param      repeat          number of repetitions of the same call
                                    with different datasets
        @param      verbose         if True, use :epkg:`tqdm`
        @param      stop_if_error   by default, it stops when method *validate*
                                    fails, if False, the function stores the exception
        @param      validate        compare the outputs against the baseline
        @param      number          number of times to call the same function,
                                    the method then measures this number of calls
        @return                     yields dictionaries with all the metrics
        """
        all_opts = self.pbefore.copy()
        all_opts.update(self.pafter)
        all_tests = list(self.enumerate_tests(all_opts))

        if verbose:
            from tqdm import tqdm  # pylint: disable=C0415
            loop = iter(tqdm(range(len(all_tests))))
        else:
            loop = iter(all_tests)

        for a_opt in self.enumerate_tests(self.pbefore):
            if not self.fct_filter_test(**a_opt):
                continue

            inst = self.btest(**a_opt)

            for b_opt in self.enumerate_tests(self.pafter):
                obs = b_opt.copy()
                obs.update(a_opt)
                if not self.fct_filter_test(**obs):
                    continue

                fcts = inst.fcts(**obs)
                if not isinstance(fcts, list):
                    raise TypeError(  # pragma: no cover
                        "Method fcts must return a list of dictionaries (name, fct) "
                        "not {}".format(fcts))

                data = [inst.data(**obs) for r in range(repeat)]
                if not isinstance(data, (list, tuple)):
                    raise ValueError(  # pragma: no cover
                        "Method *data* must return a list or a tuple.")
                obs["repeat"] = len(data)
                obs["number"] = number
                results = []
                stores = []

                for fct in fcts:
                    if not isinstance(fct, dict) or 'fct' not in fct:
                        raise ValueError(  # pragma: no cover
                            "Method fcts must return a list of dictionaries with keys "
                            "('name', 'fct') not {}".format(fct))
                    f = fct['fct']
                    del fct['fct']
                    times = []
                    fct.update(obs)

                    if isinstance(f, tuple):
                        if len(f) != 2:
                            raise RuntimeError(  # pragma: no cover
                                "If *f* is a tuple, it must contain two functions f1, f2.")
                        f1, f2 = f
                        dt = data[0]
                        dt2 = f1(*dt)
                        self.profile(fct, lambda: f2(*dt2))
                        for idt, dt in enumerate(data):
                            dt2 = f1(*dt)
                            if number == 1:
                                st = time_perf()
                                r = f2(*dt2)
                                d = time_perf() - st
                            else:
                                st = time_perf()
                                for _ in range(number):
                                    r = f2(*dt2)
                                d = time_perf() - st
                            times.append(d)
                            results.append((idt, fct, r))
                    else:
                        dt = data[0]
                        self.profile(fct, lambda: f(*dt))
                        for idt, dt in enumerate(data):
                            if number == 1:
                                st = time_perf()
                                r = f(*dt)
                                d = time_perf() - st
                            else:
                                st = time_perf()
                                for _ in range(number):
                                    r = f(*dt)
                                d = time_perf() - st
                            times.append(d)
                            results.append((idt, fct, r))
                    times.sort()
                    fct['min'] = times[0]
                    fct['max'] = times[-1]
                    if len(times) > 5:
                        fct['min3'] = times[3]
                        fct['max3'] = times[-3]
                    times = numpy.array(times)
                    fct['mean'] = times.mean()
                    std = times.std()
                    if len(times) >= 4:
                        fct['lower'] = max(
                            fct['min'], fct['mean'] - std * 1.96)
                        fct['upper'] = min(
                            fct['max'], fct['mean'] + std * 1.96)
                    else:
                        fct['lower'] = fct['min']
                        fct['upper'] = fct['max']
                    fct['count'] = len(times)
                    fct['median'] = numpy.median(times)
                    stores.append(fct)

                if validate:
                    if stop_if_error:
                        up = inst.validate(results, data=data)
                    else:
                        try:
                            up = inst.validate(results, data=data)
                        except Exception as e:  # pylint: disable=W0703
                            msg = str(e).replace("\n", " ").replace(",", " ")
                            up = {'error': msg, 'error_c': 1}
                    if up is not None:
                        for fct in stores:
                            fct.update(up)
                else:
                    for fct in stores:
                        fct['error_c'] = 0
                for fct in stores:
                    yield fct
                next(loop)  # pylint: disable=R1708

    def profile(self, kwargs, fct):
        """
        Checks if a profiler applies to this set
        of parameters, then profiles function *fct*.

        @param      kwargs      dictionary of parameters
        @param      fct         function to measure
        """
        if self.profilers:
            for prof in self.profilers:
                if prof.match(**kwargs):
                    prof.profile(fct, **kwargs)
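As a complement, a hypothetical usage sketch of BenchPerf (it reuses the illustrative PolyDotTest class sketched above, which is not part of the package): values in pbefore are passed to the BenchPerfTest constructor, values in pafter to data and fcts, and every combination of both is measured over repeat datasets.

# Hypothetical usage, not part of benchmark_perf.py.
from pymlbenchmark.benchmark.benchmark_perf import BenchPerf

bp = BenchPerf(
    pbefore=dict(dim=[4, 16]),          # sent to PolyDotTest(**a_opt)
    pafter=dict(N=[10, 100, 1000]),     # sent to data(**obs) and fcts(**obs)
    btest=PolyDotTest)

# enumerate_run_benchs yields one dictionary of metrics per tested
# function and per combination of parameters (min, max, mean, median,
# lower, upper, count, repeat, number, ...).
for obs in bp.enumerate_run_benchs(repeat=10, verbose=False):
    print(obs['name'], obs['dim'], obs['N'], obs['mean'])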