Coverage for src/pymlbenchmark/benchmark/sklearn_helper.py: 91%
22 statements
« prev ^ index » next coverage.py v7.2.1, created at 2023-03-08 00:27 +0100
« prev ^ index » next coverage.py v7.2.1, created at 2023-03-08 00:27 +0100
1"""
2@file
3@brief Helpers about :epkg:`scikit-learn`.
4"""
5from sklearn.base import BaseEstimator
def get_nb_skl_base_estimators(obj, fitted=True):
    """
    Returns the number of :epkg:`scikit-learn` *BaseEstimator*
    included in a pipeline, the object itself included.

    The function walks through the attributes (``__dict__``) of every
    *BaseEstimator* it meets, through the items of lists and tuples,
    and through the values of dictionaries. Any other object counts
    for nothing and is not walked.

    An object which refers back to one of the objects currently being
    walked (a reference cycle) is skipped instead of being walked again,
    so the function terminates even if the pipeline is recursive.
    An object reachable through several distinct, non-cyclic paths
    is still counted once per path.

    @param      obj         object to walk through
    @param      fitted      if true, only walk through the attributes whose
                            name ends with ``'_'`` (fitted attributes),
                            otherwise only walk through the other attributes
    @return                 number of base estimators including this one
    """
    return _count_skl_base_estimators(obj, bool(fitted), set())


def _count_skl_base_estimators(obj, fitted, ancestors):
    """
    Recursive worker of *get_nb_skl_base_estimators*.

    @param      obj         object to walk through
    @param      fitted      boolean, selects which attributes are walked
                            (see *get_nb_skl_base_estimators*)
    @param      ancestors   set of ``id`` of the objects on the path from
                            the root object to *obj*, used to break cycles
    @return                 number of base estimators found in *obj*
    """
    if not isinstance(obj, (BaseEstimator, list, tuple, dict)):
        return 0
    key = id(obj)
    if key in ancestors:
        # Back-reference to an object already being walked: walking it
        # again would recurse forever, so it contributes nothing.
        return 0
    ancestors.add(key)
    try:
        if isinstance(obj, BaseEstimator):
            count = 1
            for name, value in obj.__dict__.items():
                # NOTE(review): 'base_estimator_' is deliberately not walked;
                # presumably to avoid counting an estimator template that is
                # also reachable elsewhere — confirm with the original author.
                if name == 'base_estimator_':
                    continue
                if name.endswith('_') == fitted:
                    count += _count_skl_base_estimators(
                        value, fitted, ancestors)
            return count
        if isinstance(obj, (list, tuple)):
            children = obj
        else:
            children = obj.values()
        return sum(_count_skl_base_estimators(child, fitted, ancestors)
                   for child in children)
    finally:
        # Only the current path is tracked, not every visited object, so
        # shared (non-cyclic) sub-objects keep being counted as before.
        ancestors.discard(key)