Coverage for mlinsights/mlmodel/sklearn_testing.py: 99%

166 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-28 08:46 +0100

1""" 

2@file 

3@brief Helpers to test a model which follows :epkg:`scikit-learn` API. 

4""" 

5import copy 

6import pickle 

7import pprint 

8from unittest import TestCase 

9from io import BytesIO 

10from numpy import ndarray 

11from numpy.testing import assert_almost_equal 

12from pandas.testing import assert_frame_equal 

13from sklearn.base import BaseEstimator 

14from sklearn.model_selection import train_test_split 

15from sklearn.base import clone 

16from sklearn.pipeline import make_pipeline 

17from sklearn.model_selection import GridSearchCV 

18 

19 

def train_test_split_with_none(X, y=None, sample_weight=None, random_state=0):
    """
    Splits into train and test data even if some inputs are None.

    Only the inputs which are not None are given to
    :epkg:`scikit-learn:model_selection:train_test_split`. Every split is
    then returned at the position of its input, an input equal to None
    gives None for both its train and test parts.

    @param X X
    @param y y, can be None
    @param sample_weight sample weight, can be None
    @param random_state random state given to *train_test_split*
    @return ``X_train, y_train, w_train, X_test, y_test, w_test``
    """
    inputs = [X, y, sample_weight]
    not_none = [value for value in inputs if value is not None]
    res = train_test_split(*not_none, random_state=random_state)
    # train_test_split returns one (train, test) pair per input, in the
    # order of its inputs. Each pair must go back to the position of its
    # input, otherwise the weights would become y when y is None but
    # sample_weight is not.
    pairs = iter(zip(res[0::2], res[1::2]))
    trains = []
    tests = []
    for value in inputs:
        train, test = (None, None) if value is None else next(pairs)
        trains.append(train)
        tests.append(test)
    X_train, y_train, w_train = trains
    X_test, y_test, w_test = tests
    return X_train, y_train, w_train, X_test, y_test, w_test

44 

45 

def test_sklearn_pickle(fct_model, X, y=None, sample_weight=None, **kwargs):
    """
    Trains a model, serializes it with :mod:`pickle`, restores it and
    checks both versions return the same outputs on the test data.

    @param fct_model function returning a new model
    @param X features
    @param y target, can be None
    @param sample_weight sample weights, can be None
    @param kwargs extra arguments for the comparison function,
        :epkg:`numpy:testing:assert_almost_equal` when the outputs are
        arrays, ``pandas.testing.assert_frame_equal`` otherwise
    @return trained model, restored model

    :raises:
        AssertionError
    """
    X_train, y_train, w_train, X_test, _, __ = train_test_split_with_none(
        X, y, sample_weight)

    def _outputs(estimator):
        # a model without predict is used as a transformer
        if hasattr(estimator, 'predict'):
            return estimator.predict(X_test)
        return estimator.transform(X_test)

    model = fct_model()
    if y_train is None and w_train is None:
        model.fit(X_train)
    else:
        try:
            model.fit(X_train, y_train, w_train)
        except TypeError:
            # the model probably does not accept sample weights
            model.fit(X_train, y_train)
    expected = _outputs(model)

    restored = pickle.loads(pickle.dumps(model))
    got = _outputs(restored)

    compare = (assert_almost_equal if isinstance(expected, ndarray)
               else assert_frame_equal)
    compare(expected, got, **kwargs)
    return model, restored

90 

91 

92def _get_test_instance(): 

93 try: 

94 from pyquickhelper.pycode import ExtTestCase # pylint: disable=C0415 

95 cls = ExtTestCase 

96 except ImportError: # pragma: no cover 

97 

98 class _ExtTestCase(TestCase): 

99 "simple test classe with a more methods" 

100 

101 def assertIsInstance(self, inst, cltype): 

102 "checks that one instance is from one type" 

103 if not isinstance(inst, cltype): 

104 raise AssertionError( 

105 f"Unexpected type {type(inst)} != {cltype}.") 

106 

107 cls = _ExtTestCase 

108 return cls() 

109 

110 

def test_sklearn_clone(fct_model, ext=None, copy_fitted=False):
    """
    Tests that a cloned model is similar to the original one.

    Both models must expose the same parameter names (``get_params(deep=True)``)
    and equal values. Lists and tuples are compared element by element.
    Estimators are only compared in depth when *copy_fitted* is True.

    @param fct_model function which creates the model
    @param ext unit test class instance, a default one is created if None
    @param copy_fitted copy fitted parameters as well
    @return model, cloned model

    :raises:
        AssertionError
    """
    conv = fct_model()
    p1 = conv.get_params(deep=True)
    if copy_fitted:
        cloned = clone_with_fitted_parameters(conv)
    else:
        cloned = clone(conv)
    p2 = cloned.get_params(deep=True)
    if ext is None:
        ext = _get_test_instance()
    try:
        ext.assertEqual(set(p1), set(p2))
    except AssertionError as e:  # pragma no cover
        s1 = pprint.pformat(p1)
        s2 = pprint.pformat(p2)
        raise AssertionError(
            f"Differences between\n----\n{s1}\n----\n{s2}") from e

    for k in sorted(p1):
        v1, v2 = p1[k], p2[k]
        if isinstance(v1, BaseEstimator) and isinstance(v2, BaseEstimator):
            # With deep=True, the parameters of sub-estimators are also
            # listed as separate keys and compared by this loop; fitted
            # attributes only need a check when they were copied.
            if copy_fitted:
                assert_estimator_equal(v1, v2, ext)
        elif isinstance(v1, list) and isinstance(v2, list):
            _assert_list_equal(v1, v2, ext)
        elif isinstance(v1, tuple) and isinstance(v2, tuple):
            # tuples may hold estimators which assertEqual compares by identity
            _assert_tuple_equal(v1, v2, ext)
        else:
            try:
                ext.assertEqual(v1, v2)
            except AssertionError as e:  # pragma no cover
                raise AssertionError(
                    f"Difference for key '{k}'\n==1 {v1}\n==2 {v2}") from e
    return conv, cloned

153 

154 

155def _assert_list_equal(l1, l2, ext): 

156 if len(l1) != len(l2): 

157 raise AssertionError( # pragma no cover 

158 f"Lists have different length {len(l1)} != {len(l2)}") 

159 for a, b in zip(l1, l2): 

160 if isinstance(a, tuple) and isinstance(b, tuple): 

161 _assert_tuple_equal(a, b, ext) 

162 else: 

163 ext.assertEqual(a, b) 

164 

165 

166def _assert_dict_equal(a, b, ext): 

167 if not isinstance(a, dict): # pragma no cover 

168 raise TypeError(f'a is not dict but {type(a)}') 

169 if not isinstance(b, dict): # pragma no cover 

170 raise TypeError(f'b is not dict but {type(b)}') 

171 rows = [] 

172 for key in sorted(b): 

173 if key not in a: 

174 rows.append(f"** Added key '{key}' in b") 

175 elif isinstance(a[key], BaseEstimator) and isinstance(b[key], BaseEstimator): 

176 assert_estimator_equal(a[key], b[key], ext) 

177 else: 

178 if a[key] != b[key]: 

179 rows.append( 

180 "** Value != for key '{0}': != id({1}) != id({2})\n==1 {3}\n==2 {4}".format( 

181 key, id(a[key]), id(b[key]), a[key], b[key])) 

182 for key in sorted(a): 

183 if key not in b: 

184 rows.append(f"** Removed key '{key}' in a") 

185 if len(rows) > 0: 

186 raise AssertionError( # pragma: no cover 

187 "Dictionaries are different\n{0}".format('\n'.join(rows))) 

188 

189 

190def _assert_tuple_equal(t1, t2, ext): 

191 if len(t1) != len(t2): # pragma no cover 

192 raise AssertionError( 

193 f"Lists have different length {len(t1)} != {len(t2)}") 

194 for a, b in zip(t1, t2): 

195 if isinstance(a, BaseEstimator) and isinstance(b, BaseEstimator): 

196 assert_estimator_equal(a, b, ext) 

197 else: 

198 ext.assertEqual(a, b) 

199 

200 

def assert_estimator_equal(esta, estb, ext=None):
    """
    Checks that two models are equal.

    Both models must be instances of the same class and have the same
    parameters (``get_params``). Every attribute of *esta* whose name ends
    with one ``'_'`` (fitted attribute) or starts with one ``'_'`` (private
    attribute) must exist in *estb* with an equal value. Estimators are
    compared recursively and numpy arrays with
    :func:`numpy.testing.assert_array_equal`. Every attribute of *estb*
    ending with one ``'_'`` must also exist in *esta*.

    @param esta first estimator
    @param estb second estimator
    @param ext unit test class instance, a default one is created if None

    The function raises an exception if the comparison fails.
    """
    from numpy.testing import assert_array_equal  # pylint: disable=C0415

    if ext is None:
        ext = _get_test_instance()
    ext.assertIsInstance(esta, estb.__class__)
    ext.assertIsInstance(estb, esta.__class__)
    _assert_dict_equal(esta.get_params(), estb.get_params(), ext)
    for att in esta.__dict__:
        if (att.endswith('_') and not att.endswith('__')) or \
                (att.startswith('_') and not att.startswith('__')):
            if not hasattr(estb, att):  # pragma no cover
                raise AssertionError(
                    "Missing fitted attribute '{}' class {}\n==1 {}\n==2 {}".format(
                        att, esta.__class__, list(sorted(esta.__dict__)),
                        list(sorted(estb.__dict__))))
            value_a = getattr(esta, att)
            value_b = getattr(estb, att)
            if isinstance(value_a, BaseEstimator):
                assert_estimator_equal(value_a, value_b, ext)
            elif isinstance(value_a, ndarray) or isinstance(value_b, ndarray):
                # unittest's assertEqual evaluates ``a == b`` as a boolean,
                # which raises ValueError for arrays with several elements.
                assert_array_equal(value_a, value_b)
            else:
                ext.assertEqual(value_a, value_b)
    for att in estb.__dict__:
        if att.endswith('_') and not att.endswith('__'):
            if not hasattr(esta, att):  # pragma no cover
                raise AssertionError(
                    "Missing fitted attribute '{}' class {}\n==1 {}\n==2 {}".format(
                        att, estb.__class__, list(sorted(esta.__dict__)),
                        list(sorted(estb.__dict__))))

234 

235 

def test_sklearn_grid_search_cv(fct_model, X, y=None, sample_weight=None, **grid_params):
    """
    Creates a model, checks that a grid search works with it.

    The model is wrapped into a pipeline and tuned with
    :epkg:`sklearn:model_selection:GridSearchCV`.

    @param fct_model function which creates the model
    @param X X
    @param y y, can be None
    @param sample_weight sample weight, can be None, given to the
        model's *fit* method as the pipeline fit parameter
        ``<step>__sample_weight``
    @param grid_params parameter to use to run the grid search.
    @return dictionary with results

    :raises:
        AssertionError,
        ValueError if *grid_params* is empty
    """
    X_train, y_train, w_train, X_test, y_test, w_test = (
        train_test_split_with_none(X, y, sample_weight))
    model = fct_model()
    pipe = make_pipeline(model)
    # make_pipeline names every step with its lower-cased class name
    name = model.__class__.__name__.lower()
    parameters = {name + "__" + k: v for k, v in grid_params.items()}
    if len(parameters) == 0:
        raise ValueError(
            "Some parameters must be tested when running grid search.")
    clf = GridSearchCV(pipe, parameters)
    if y_train is None and w_train is None:
        clf.fit(X_train)
    elif w_train is None:
        clf.fit(X_train, y_train)  # pylint: disable=E1121
    else:
        # The third positional argument of GridSearchCV.fit is not the
        # sample weight (``groups`` in old versions, keyword-only in recent
        # ones). Weights are given as a fit parameter routed by the pipeline
        # to the model; GridSearchCV splits it with X and y for every fold.
        clf.fit(X_train, y_train, **{name + "__sample_weight": w_train})
    score = clf.score(X_test, y_test)
    ext = _get_test_instance()
    ext.assertIsInstance(score, float)
    return dict(model=clf, X_train=X_train, y_train=y_train, w_train=w_train,
                X_test=X_test, y_test=y_test, w_test=w_test, score=score)

271 

272 

def clone_with_fitted_parameters(est):
    """
    Clones an estimator with the fitted results.

    An estimator is cloned with :epkg:`sklearn:base:clone`. Attributes of
    the original whose name ends with one ``'_'`` (fitted attributes) or
    starts with one ``'_'`` (private attributes) and which the clone does
    not have are then copied, recursively. Attributes holding estimators are
    replaced by clones with fitted parameters. Lists, tuples and
    dictionaries are processed element by element; any other object is
    deep copied.

    @param est estimator (or list, tuple, dictionary, any object)
    @return cloned object

    :raises:
        RuntimeError if an attribute cannot be migrated to the clone
    """
    def adjust(obj1, obj2):
        # Migrates into obj2 (a clone of obj1) the fitted state held by obj1.
        if ((isinstance(obj1, list) and isinstance(obj2, list)) or
                (isinstance(obj1, tuple) and isinstance(obj2, tuple))):
            for a, b in zip(obj1, obj2):
                adjust(a, b)
        elif isinstance(obj1, dict) and isinstance(obj2, dict):
            # Values are matched by key: zipping the keys of both
            # dictionaries pairs unrelated values if their orders differ.
            for key, value in obj1.items():
                if key in obj2:
                    adjust(value, obj2[key])
        elif isinstance(obj1, BaseEstimator) and isinstance(obj2, BaseEstimator):
            for k in obj1.__dict__:
                if hasattr(obj2, k):
                    v1 = getattr(obj1, k)
                    if callable(v1):
                        raise RuntimeError(  # pragma: no cover
                            f"Cannot migrate trained parameters for {obj1}.")
                    if isinstance(v1, BaseEstimator):
                        setattr(obj2, k, clone_with_fitted_parameters(v1))
                    else:
                        adjust(v1, getattr(obj2, k))
                elif (k.endswith('_') and not k.endswith('__')) or \
                        (k.startswith('_') and not k.startswith('__')):
                    setattr(obj2, k, clone_with_fitted_parameters(getattr(obj1, k)))
                else:
                    raise RuntimeError(  # pragma: no cover
                        f"Cloned object is missing '{k}' in {obj2}.")

    if isinstance(est, BaseEstimator):
        res = clone(est)
        adjust(est, res)
    elif isinstance(est, list):
        res = [clone_with_fitted_parameters(o) for o in est]
    elif isinstance(est, tuple):
        res = tuple(clone_with_fitted_parameters(o) for o in est)
    elif isinstance(est, dict):
        res = {k: clone_with_fitted_parameters(v) for k, v in est.items()}
    else:
        res = copy.deepcopy(est)
    return res