Coverage for mlinsights/mlmodel/sklearn_testing.py: 99%
166 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-28 08:46 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-28 08:46 +0100
1"""
2@file
3@brief Helpers to test a model which follows :epkg:`scikit-learn` API.
4"""
5import copy
6import pickle
7import pprint
8from unittest import TestCase
9from io import BytesIO
10from numpy import ndarray
11from numpy.testing import assert_almost_equal
12from pandas.testing import assert_frame_equal
13from sklearn.base import BaseEstimator
14from sklearn.model_selection import train_test_split
15from sklearn.base import clone
16from sklearn.pipeline import make_pipeline
17from sklearn.model_selection import GridSearchCV
def train_test_split_with_none(X, y=None, sample_weight=None, random_state=0):
    """
    Splits into train and test data even if some of the inputs are None.

    @param      X               X
    @param      y               y
    @param      sample_weight   sample weight
    @param      random_state    random state, forwarded to
                                :epkg:`scikit-learn:model_selection:train_test_split`
    @return                     ``X_train, y_train, w_train, X_test, y_test, w_test``,
                                an entry is None when the corresponding input is None
    """
    arrays = (X, y, sample_weight)
    # Only inputs which are not None can be given to train_test_split.
    kept = [pos for pos, arr in enumerate(arrays) if arr is not None]
    split = train_test_split(
        *(arrays[pos] for pos in kept), random_state=random_state)
    trains = [None] * len(arrays)
    tests = [None] * len(arrays)
    # train_test_split returns (train, test) pairs in the order of its inputs;
    # every pair goes back to the slot of the input it comes from so that
    # weights never end up in the slot of y when y is None.
    for rank, pos in enumerate(kept):
        trains[pos] = split[2 * rank]
        tests[pos] = split[2 * rank + 1]
    X_train, y_train, w_train = trains
    X_test, y_test, w_test = tests
    return X_train, y_train, w_train, X_test, y_test, w_test
def test_sklearn_pickle(fct_model, X, y=None, sample_weight=None, **kwargs):
    """
    Builds a model, trains it, and verifies that its outputs on the test
    set are unchanged once the model went through pickle.

    @param      fct_model       function which creates the model
    @param      X               X
    @param      y               y
    @param      sample_weight   sample weight
    @param      kwargs          additional parameters given to the comparison
                                function (:epkg:`numpy:testing:assert_almost_equal`
                                for arrays, ``assert_frame_equal`` otherwise)
    @return                     model, unpickled model

    :raises:
        AssertionError
    """
    X_train, y_train, w_train, X_test, _, __ = train_test_split_with_none(
        X, y, sample_weight)

    def _outputs(estimator):
        # Predictors are compared on predict, everything else on transform.
        if hasattr(estimator, 'predict'):
            return estimator.predict(X_test)
        return estimator.transform(X_test)

    model = fct_model()
    unsupervised = y_train is None and w_train is None
    if unsupervised:
        model.fit(X_train)
    else:
        try:
            model.fit(X_train, y_train, w_train)
        except TypeError:
            # The model probably does not accept weights, trains without them.
            model.fit(X_train, y_train)
    pred1 = _outputs(model)

    # Round trip through pickle, entirely in memory.
    model2 = pickle.loads(pickle.dumps(model))
    pred2 = _outputs(model2)

    if isinstance(pred1, ndarray):
        compare = assert_almost_equal
    else:
        compare = assert_frame_equal
    compare(pred1, pred2, **kwargs)
    return model, model2
def _get_test_instance():
    """
    Returns an instance of a unit test class exposing the ``assert*``
    methods used by the helpers of this module. ``ExtTestCase`` from
    ``pyquickhelper`` is used if it can be imported, otherwise a small
    class inheriting from ``unittest.TestCase`` is created.
    """
    try:
        from pyquickhelper.pycode import ExtTestCase as test_class  # pylint: disable=C0415
    except ImportError:  # pragma: no cover

        class _ExtTestCase(TestCase):
            "Fallback test class used when pyquickhelper is not available."

            def assertIsInstance(self, inst, cltype):
                "Raises AssertionError if *inst* is not an instance of *cltype*."
                if isinstance(inst, cltype):
                    return
                raise AssertionError(
                    f"Unexpected type {type(inst)} != {cltype}.")

        test_class = _ExtTestCase
    return test_class()
def test_sklearn_clone(fct_model, ext=None, copy_fitted=False):
    """
    Tests that a cloned model is similar to the original one.

    @param      fct_model       function which creates the model
    @param      ext             unit test class instance, one is created
                                if None
    @param      copy_fitted     copy fitted parameters as well
    @return                     model, cloned model

    :raises:
        AssertionError
    """
    conv = fct_model()
    p1 = conv.get_params(deep=True)
    if copy_fitted:
        cloned = clone_with_fitted_parameters(conv)
    else:
        cloned = clone(conv)
    p2 = cloned.get_params(deep=True)
    if ext is None:
        ext = _get_test_instance()
    try:
        # Both models must expose the same parameter names.
        ext.assertEqual(set(p1), set(p2))
    except AssertionError as e:  # pragma no cover
        text1 = pprint.pformat(p1)
        text2 = pprint.pformat(p2)
        raise AssertionError(
            f"Differences between\n----\n{text1}\n----\n{text2}") from e

    for k in sorted(p1):
        v1, v2 = p1[k], p2[k]
        if isinstance(v1, BaseEstimator) and isinstance(v2, BaseEstimator):
            # Nested estimators are only compared when fitted attributes
            # were copied; the caller's test instance is reused.
            if copy_fitted:
                assert_estimator_equal(v1, v2, ext)
        elif isinstance(v1, list) and isinstance(v2, list):
            _assert_list_equal(v1, v2, ext)
        else:
            try:
                ext.assertEqual(v1, v2)
            except AssertionError as e:  # pragma no cover
                # Keeps the original failure as the cause of the new one.
                raise AssertionError(
                    f"Difference for key '{k}'\n==1 {v1}\n==2 {v2}") from e
    return conv, cloned
def _assert_list_equal(l1, l2, ext):
    """
    Checks that two lists have the same length and equal elements.
    Pairs of tuples are compared with *_assert_tuple_equal*, any other
    pair with ``ext.assertEqual``.

    @param      l1      first list
    @param      l2      second list
    @param      ext     unit test class instance
    """
    same_length = len(l1) == len(l2)
    if not same_length:
        raise AssertionError(  # pragma no cover
            f"Lists have different length {len(l1)} != {len(l2)}")
    for left, right in zip(l1, l2):
        both_tuples = isinstance(left, tuple) and isinstance(right, tuple)
        if not both_tuples:
            ext.assertEqual(left, right)
            continue
        _assert_tuple_equal(left, right, ext)
def _assert_dict_equal(a, b, ext):
    """
    Checks that two dictionaries hold the same keys and equal values.
    Values which are both estimators are compared with
    *assert_estimator_equal*, every other difference is collected and
    reported in a single exception.

    @param      a       first dictionary
    @param      b       second dictionary
    @param      ext     unit test class instance

    :raises:
        TypeError if *a* or *b* is not a dictionary,
        AssertionError if the dictionaries differ
    """
    if not isinstance(a, dict):  # pragma no cover
        raise TypeError(f'a is not dict but {type(a)}')
    if not isinstance(b, dict):  # pragma no cover
        raise TypeError(f'b is not dict but {type(b)}')

    differences = []
    # Keys of b first: added keys and different values.
    for key in sorted(b):
        if key not in a:
            differences.append(f"** Added key '{key}' in b")
            continue
        va, vb = a[key], b[key]
        if isinstance(va, BaseEstimator) and isinstance(vb, BaseEstimator):
            assert_estimator_equal(va, vb, ext)
        elif va != vb:
            differences.append(
                "** Value != for key '{0}': != id({1}) != id({2})\n==1 {3}\n==2 {4}".format(
                    key, id(va), id(vb), va, vb))
    # Then keys only present in a.
    differences.extend(
        f"** Removed key '{key}' in a" for key in sorted(a) if key not in b)
    if differences:
        raise AssertionError(  # pragma: no cover
            "Dictionaries are different\n{0}".format('\n'.join(differences)))
def _assert_tuple_equal(t1, t2, ext):
    """
    Checks that two tuples have the same length and equal elements.
    Pairs of estimators are compared with *assert_estimator_equal*,
    any other pair with ``ext.assertEqual``.

    @param      t1      first tuple
    @param      t2      second tuple
    @param      ext     unit test class instance

    :raises:
        AssertionError
    """
    if len(t1) != len(t2):  # pragma no cover
        # The message names tuples, not lists, as this function compares tuples.
        raise AssertionError(
            f"Tuples have different length {len(t1)} != {len(t2)}")
    for a, b in zip(t1, t2):
        if isinstance(a, BaseEstimator) and isinstance(b, BaseEstimator):
            assert_estimator_equal(a, b, ext)
        else:
            ext.assertEqual(a, b)
def assert_estimator_equal(esta, estb, ext=None):
    """
    Checks that two models are equal: same class, same parameters
    and same fitted attributes.

    @param      esta        first estimator
    @param      estb        second estimator
    @param      ext         unit test class, one is created if None

    The function raises an exception if the comparison fails.
    """
    def _single_trailing(name):
        return name.endswith('_') and not name.endswith('__')

    def _single_leading(name):
        return name.startswith('_') and not name.startswith('__')

    if ext is None:
        ext = _get_test_instance()
    # Each estimator must be an instance of the class of the other one.
    ext.assertIsInstance(esta, estb.__class__)
    ext.assertIsInstance(estb, esta.__class__)
    _assert_dict_equal(esta.get_params(), estb.get_params(), ext)

    # Attributes of esta with a single trailing or a single leading
    # underscore must exist in estb and be equal.
    for att in esta.__dict__:
        if not (_single_trailing(att) or _single_leading(att)):
            continue
        if not hasattr(estb, att):  # pragma no cover
            raise AssertionError(
                "Missing fitted attribute '{}' class {}\n==1 {}\n==2 {}".format(
                    att, esta.__class__, sorted(esta.__dict__), sorted(estb.__dict__)))
        va = getattr(esta, att)
        vb = getattr(estb, att)
        if isinstance(va, BaseEstimator):
            assert_estimator_equal(va, vb, ext)
        else:
            ext.assertEqual(va, vb)

    # Attributes of estb with a single trailing underscore must exist in esta.
    for att in estb.__dict__:
        if _single_trailing(att) and not hasattr(esta, att):  # pragma no cover
            raise AssertionError(
                "Missing fitted attribute\n==1 {}\n==2 {}".format(
                    sorted(esta.__dict__), sorted(estb.__dict__)))
def test_sklearn_grid_search_cv(fct_model, X, y=None, sample_weight=None, **grid_params):
    """
    Creates a model, checks that a grid search works with it.

    @param      fct_model       function which creates the model
    @param      X               X
    @param      y               y
    @param      sample_weight   sample weight
    @param      grid_params     parameter to use to run the grid search.
    @return                     dictionary with results

    :raises:
        AssertionError,
        ValueError if *grid_params* is empty
    """
    X_train, y_train, w_train, X_test, y_test, w_test = (
        train_test_split_with_none(X, y, sample_weight))
    model = fct_model()
    pipe = make_pipeline(model)
    # make_pipeline names the step after the lowercased class name.
    name = model.__class__.__name__.lower()
    parameters = {name + "__" + k: v for k, v in grid_params.items()}
    if not parameters:
        raise ValueError(
            "Some parameters must be tested when running grid search.")
    clf = GridSearchCV(pipe, parameters)
    if y_train is None and w_train is None:
        clf.fit(X_train)
    elif w_train is None:
        clf.fit(X_train, y_train)
    else:
        # The third positional argument of GridSearchCV.fit is not the
        # sample weights: weights are fit parameters of the pipeline step
        # and must be given as ``<step>__sample_weight``.
        clf.fit(X_train, y_train, **{name + "__sample_weight": w_train})
    score = clf.score(X_test, y_test)
    ext = _get_test_instance()
    ext.assertIsInstance(score, float)
    return dict(model=clf, X_train=X_train, y_train=y_train, w_train=w_train,
                X_test=X_test, y_test=y_test, w_test=w_test, score=score)
def clone_with_fitted_parameters(est):
    """
    Clones an estimator and brings the fitted results along.
    Lists, tuples and dictionaries are rebuilt with every item cloned
    the same way, any other object is deep copied.

    @param      est     estimator
    @return             cloned object
    """
    def _is_fitted_name(name):
        # A single trailing or a single leading underscore.
        return ((name.endswith('_') and not name.endswith('__')) or
                (name.startswith('_') and not name.startswith('__')))

    def adjust(obj1, obj2):
        # Walks obj1 (original) and obj2 (clone) in parallel.
        for seq_type in (list, tuple):
            if isinstance(obj1, seq_type) and isinstance(obj2, seq_type):
                for a, b in zip(obj1, obj2):
                    adjust(a, b)
                return
        if isinstance(obj1, dict) and isinstance(obj2, dict):
            # Keys are paired following the iteration order of both dictionaries.
            for a, b in zip(obj1, obj2):
                adjust(obj1[a], obj2[b])
            return
        if not (isinstance(obj1, BaseEstimator) and isinstance(obj2, BaseEstimator)):
            return
        for k in obj1.__dict__:
            if hasattr(obj2, k):
                v1 = getattr(obj1, k)
                if callable(v1):
                    raise RuntimeError(  # pragma: no cover
                        f"Cannot migrate trained parameters for {obj1}.")
                if isinstance(v1, BaseEstimator):
                    setattr(obj2, k, clone_with_fitted_parameters(v1))
                else:
                    adjust(v1, getattr(obj2, k))
            elif _is_fitted_name(k):
                # Attribute missing in the clone: it is copied over.
                setattr(obj2, k, clone_with_fitted_parameters(getattr(obj1, k)))
            else:
                raise RuntimeError(  # pragma: no cover
                    f"Cloned object is missing '{k}' in {obj2}.")

    if isinstance(est, BaseEstimator):
        cloned = clone(est)
        adjust(est, cloned)
        return cloned
    if isinstance(est, list):
        return [clone_with_fitted_parameters(o) for o in est]
    if isinstance(est, tuple):
        return tuple(clone_with_fitted_parameters(o) for o in est)
    if isinstance(est, dict):
        return {k: clone_with_fitted_parameters(v) for k, v in est.items()}
    return copy.deepcopy(est)