Coverage for src/mlstatpy/ml/ml_grid_benchmark.py: 82%

152 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-27 05:59 +0100

1# -*- coding: utf-8 -*- 

2""" 

3@file 

4@brief About Machine Learning Benchmark 

5""" 

6import os 

7import numpy 

8from sklearn.model_selection import train_test_split 

9from sklearn.base import ClusterMixin 

10from sklearn.metrics import silhouette_score 

11from pyquickhelper.loghelper import noLOG 

12from pyquickhelper.benchhelper import GridBenchMark 

13 

14 

class MlGridBenchMark(GridBenchMark):
    """
    The class tests a list of models over a list of datasets.

    Every dataset is optionally split into a train part and a test part
    (see @see me preprocess_dataset), every model is trained on the train
    part (see @see me fit) and scored on the test part (see @see me score).
    """

    def __init__(self, name, datasets, clog=None, fLOG=noLOG, path_to_images=".",
                 cache_file=None, progressbar=None, graphx=None, graphy=None,
                 **params):
        """
        @param      name            name of the test
        @param      datasets        list of dictionary of dataframes
        @param      clog            see @see cl CustomLog or string
        @param      fLOG            logging function
        @param      params          extra parameters
        @param      path_to_images  path to images and intermediate results
        @param      cache_file      cache file
        @param      progressbar     relies on *tqdm*, example *tnrange*
        @param      graphx          list of variables to use as X axis
        @param      graphy          list of variables to use as Y axis

        If *cache_file* is specified, the class will store the results of the
        method :meth:`bench <pyquickhelper.benchhelper.benchmark.GridBenchMark.bench>`.
        On a second run, the function load the cache
        and run modified or new run (in *param_list*).

        *datasets* should be a dictionary with dataframes as values
        with the following keys:

        * ``'X'``: features
        * ``'Y'``: labels (optional)
        """
        GridBenchMark.__init__(self, name=name, datasets=datasets, clog=clog, fLOG=fLOG,
                               path_to_images=path_to_images, cache_file=cache_file,
                               progressbar=progressbar, **params)
        # Every pair (vx, vy) from graphx x graphy produces one graph
        # in method graphs.
        self._xaxis = graphx
        self._yaxis = graphy

    def preprocess_dataset(self, dsi, **params):
        """
        Splits the dataset into train and test.

        @param      dsi     dataset index
        @param      params  additional parameters, *train_size* and
                            *test_size* are removed from it and given
                            to *train_test_split*
        @return             tuple *(train, test)*, dictionary for metrics,
                            remaining parameters

        If the dataset contains a true value for key ``'no_split'``,
        the same dataset is used for both train and test.
        Otherwise, keys ``'X'``, ``'Y'``, ``'weight'``, ``'group'``
        are split when they are present.
        """
        ds, appe, params = GridBenchMark.preprocess_dataset(
            self, dsi, **params)

        if ds.get("no_split", False):
            self.fLOG("[MlGridBenchMark.preprocess_dataset] no split")
            return (ds, ds), appe, params

        self.fLOG("[MlGridBenchMark.preprocess_dataset] split train test")
        names = [_ for _ in ("X", "Y", "weight", "group") if _ in ds]
        if not names:
            raise ValueError(  # pragma: no cover
                "No dataframe or matrix was found.")
        mats = [ds[_] for _ in names]

        # The split options are consumed here, they must not reach
        # the following steps of the benchmark.
        options = {k: params.pop(k)
                   for k in ("train_size", "test_size") if k in params}

        # train_test_split returns [train_0, test_0, train_1, test_1, ...]
        # in the same order as *mats*.
        res = train_test_split(*mats, **options)
        train = {n: res[i * 2] for i, n in enumerate(names)}
        test = {n: res[i * 2 + 1] for i, n in enumerate(names)}

        self.fLOG("[MlGridBenchMark.preprocess_dataset] done")
        return (train, test), appe, params

    def bench_experiment(self, ds, **params):  # pylint: disable=W0237
        """
        Calls meth *fit*.

        @param      ds      tuple *(train, test)*, only *train* is used
        @param      params  must contain key ``'model'``, a function
                            which creates the model to train, other
                            parameters are sent to @see me fit
        @return             trained model
        """
        # Fixed: the original test used *and*, which let through tuples
        # of a wrong length and non-tuples of length 2.
        if not isinstance(ds, tuple) or len(ds) != 2:
            raise TypeError(  # pragma: no cover
                "ds must a tuple with two dictionaries train, test")
        if "model" not in params:
            raise KeyError(  # pragma: no cover
                "params must contains key 'model'")
        # we assume model is a function which creates a model
        model = params.pop("model")()
        return self.fit(ds[0], model, **params)

    def predict_score_experiment(self, ds, model, **params):  # pylint: disable=W0237
        """
        Calls method *score*.

        @param      ds      tuple *(train, test)*, only *test* is used
        @param      model   trained model
        @param      params  additional parameters sent to @see me score,
                            it must not contain key ``'model'``
        @return             metrics, additional information
        """
        # Fixed: the original test used *and*, which let through tuples
        # of a wrong length and non-tuples of length 2.
        if not isinstance(ds, tuple) or len(ds) != 2:
            raise TypeError(  # pragma: no cover
                "ds must a tuple with two dictionaries train, test")
        if "model" in params:
            raise KeyError(  # pragma: no cover
                "params must not contains key 'model'")
        return self.score(ds[1], model, **params)

    def fit(self, ds, model, **params):
        """
        Trains a model.

        @param      ds      dictionary with the data to use for training,
                            key ``'X'`` is mandatory,
                            ``'Y'`` and ``'weight'`` are optional
        @param      model   model to train
        @param      params  additional parameters, only ``'train_params'``
                            (a dictionary) is used and sent to *model.fit*
        @return             the trained model (same instance as *model*)
        """
        if "X" not in ds:
            raise KeyError(  # pragma: no cover
                "ds must contain key 'X'")
        if "model" in params:
            raise KeyError(  # pragma: no cover
                "params must not contain key 'model', this is the model to train")
        X = ds["X"]
        Y = ds.get("Y", None)
        weight = ds.get("weight", None)
        self.fLOG("[MlGridBenchMark.fit] fit", params)

        train_params = params.get("train_params", {})

        if weight is not None:
            # NOTE(review): scikit-learn estimators usually expect
            # *sample_weight*, not *weight*; confirm which keyword
            # the benchmarked models accept.
            model.fit(X=X, y=Y, weight=weight, **train_params)
        else:
            model.fit(X=X, y=Y, **train_params)
        self.fLOG("[MlGridBenchMark.fit] Done.")
        return model

    def score(self, ds, model, **params):
        """
        Scores a model.

        @param      ds      dictionary with the data to use for scoring,
                            key ``'X'`` is mandatory, ``'Y'`` is optional,
                            ``'weight'`` is not supported
        @param      model   trained model
        @param      params  additional parameters (unused)
        @return             metrics (dictionary), additional information
                            (dictionary, always empty)

        The metrics contain ``'own_score'`` if the model has a method *score*
        and ``'silhouette'`` if the model is a clustering model.
        """
        X = ds["X"]
        Y = ds.get("Y", None)

        if "weight" in ds:
            raise NotImplementedError(  # pragma: no cover
                "weight are not used yet")

        metrics = {}
        appe = {}

        if hasattr(model, "score"):
            metrics["own_score"] = model.score(X, Y)

        if isinstance(model, ClusterMixin):
            # add silhouette
            if hasattr(model, "predict"):
                ypred = model.predict(X)
            elif hasattr(model, "transform"):
                # NOTE(review): the column with the highest value is used
                # as the cluster below; confirm this is meaningful for
                # models whose *transform* returns distances.
                ypred = model.transform(X)
            elif hasattr(model, "labels_"):
                # NOTE(review): presumably the labels of the training data;
                # confirm they match *X* when the dataset was split.
                ypred = model.labels_
            else:
                # The original code failed here with an unbound variable.
                raise AttributeError(  # pragma: no cover
                    f"Unable to get clusters from {type(model)}, it has "
                    f"none of 'predict', 'transform', 'labels_'.")
            if len(ypred.shape) > 1 and ypred.shape[1] > 1:
                # one column per cluster: keeps the index of the highest one
                ypred = numpy.argmax(ypred, axis=1)
            metrics["silhouette"] = silhouette_score(X, ypred)

        return metrics, appe

    def end(self):
        """
        nothing to do
        """

    def graphs(self, path_to_images):
        """
        Plots multiples graphs.

        @param      path_to_images      where to store images,
                                        None to create the graphs
                                        without saving any image
        @return                         list of *LocalGraph*, one for
                                        every pair *(vx, vy)* taken
                                        from *graphx* and *graphy*

        The list is empty if *graphx* or *graphy* was not specified.
        """
        import matplotlib.pyplot as plt  # pylint: disable=C0415
        import matplotlib.cm as mcm  # pylint: disable=C0415
        df = self.to_df()

        def local_graph(vx, vy, ax=None, text=True, figsize=(5, 5)):
            # One scatter plot and one color per value of column _btry.
            btrys = set(df["_btry"])
            if len(btrys) == 0:
                raise ValueError("The benchmark is empty.")  # pragma: no cover
            # vertical shift applied to the text to place it above the points
            decy = (df[vy].max() - df[vy].min()) / 50
            colors = mcm.rainbow(numpy.linspace(0, 1, len(btrys)))
            if ax is None:
                _, ax = plt.subplots(1, 1, figsize=figsize)  # pragma: no cover
                ax.grid(True)  # pragma: no cover
            for i, btry in enumerate(sorted(btrys)):
                subset = df[df["_btry"] == btry]
                if subset.shape[0] > 0:
                    tx = subset[vx].mean()
                    ty = subset[vy].mean()
                    if not numpy.isnan(tx) and not numpy.isnan(ty):
                        subset.plot(x=vx, y=vy, kind="scatter",
                                    label=btry, ax=ax, color=colors[i])
                        if text:
                            ax.text(tx, ty + decy, btry, size='small',
                                    color=colors[i], ha='center', va='bottom')
            ax.set_xlabel(vx)
            ax.set_ylabel(vy)
            return ax

        res = []
        if self._xaxis is not None and self._yaxis is not None:
            for vx in self._xaxis:
                for vy in self._yaxis:
                    self.fLOG(f"Plotting {vx} x {vy}")

                    # vx, vy are bound as default values to avoid
                    # the late binding of the loop variables.
                    def func_graph(ax=None, text=True, vx=vx, vy=vy, **kwargs):
                        return local_graph(vx, vy, ax=ax, text=text, **kwargs)

                    if path_to_images is not None:
                        img = os.path.join(
                            path_to_images, f"img-{self.Name}-{vx}x{vy}.png")
                        gr = self.LocalGraph(
                            func_graph, img, root=path_to_images)
                        self.fLOG(f"Saving '{img}'")
                        fig, ax = plt.subplots(1, 1, figsize=(8, 8))
                        gr.plot(ax=ax, text=True)
                        fig.savefig(img)
                        self.fLOG("Done")
                        res.append(gr)
                        plt.close('all')
                    else:
                        gr = self.LocalGraph(func_graph)
                        res.append(gr)
        return res

    def plot_graphs(self, grid=None, text=True, **kwargs):
        """
        Plots all graphs in the same graphs.

        @param      grid    grid of axes (numpy array), if None,
                            a grid with two columns is created,
                            parameter *figsize* is then used to create it
        @param      text    add legend title on the graph
        @param      kwargs  additional parameters sent to every graph
        @return             grid

        The method raises an exception if there is no graph to plot
        or if *grid* does not contain enough axes.
        """
        nb = len(self.Graphs)
        if nb == 0:
            raise ValueError("No graph to plot.")  # pragma: no cover

        # two columns, as many rows as needed
        size = (nb + 1) // 2, 2

        if grid is None:
            import matplotlib.pyplot as plt  # pylint: disable=C0415
            fg = kwargs.pop('figsize', (5 * size[0], 10))
            # squeeze=False: the grid remains a 2D array even with one row,
            # the original code failed with one or two graphs.
            _, grid = plt.subplots(size[0], size[1], figsize=fg, squeeze=False)
            axes = grid
        else:
            # A 1D array of axes is seen as a grid with a single row.
            axes = numpy.atleast_2d(grid)
            shape = axes.shape
            if shape[0] * shape[1] < nb:
                raise ValueError(  # pragma: no cover
                    f"The graph is not big enough {shape} < {nb}")

        ncols = axes.shape[1]
        for i, gr in enumerate(self.Graphs):
            self.fLOG(f"Plot graph {i + 1}/{nb}")
            # the grid is filled row by row
            y, x = divmod(i, ncols)
            gr.plot(ax=axes[y, x], text=text, **kwargs)
        return grid