Coverage for src/pymlbenchmark/benchmark/sklearn_helper.py: 91%

22 statements  

« prev     ^ index     » next       coverage.py v7.2.1, created at 2023-03-08 00:27 +0100

1""" 

2@file 

3@brief Helpers about :epkg:`scikit-learn`. 

4""" 

5from sklearn.base import BaseEstimator 

6 

7 

def get_nb_skl_base_estimators(obj, fitted=True):
    """
    Returns the number of :epkg:`scikit-learn` *BaseEstimator*
    included in a pipeline. The function assumes the pipeline
    is not recursive: a cycle between objects makes it recurse
    until Python raises a :class:`RecursionError`.

    Besides estimators, the function walks through lists, tuples,
    dictionary values and :epkg:`numpy` arrays of dtype ``object``.
    Some ensembles (gradient boosting for example) store their
    fitted sub-estimators in such arrays.

    @param      obj         object to walk through
    @param      fitted      if True, only the attributes of an estimator
                            whose name ends with ``'_'`` are walked,
                            otherwise only the attributes whose name
                            does not end with ``'_'`` are walked
    @return                 number of base estimators including this one
    """
    if isinstance(obj, BaseEstimator):
        ct = 1
        for k, o in obj.__dict__.items():
            # NOTE(review): presumably skipped because it is the template
            # estimator whose fitted copies are stored elsewhere (e.g. in
            # estimators_), which would double count it — confirm.
            if k == 'base_estimator_':
                continue
            # A trailing underscore marks an attribute set during fit:
            # walk those when fitted is True, the others when it is False.
            if k.endswith('_') == bool(fitted):
                ct += get_nb_skl_base_estimators(o, fitted=fitted)
        return ct

    if isinstance(obj, (list, tuple)):
        children = obj
    elif isinstance(obj, dict):
        children = obj.values()
    else:
        # numpy is a hard dependency of scikit-learn; imported lazily so
        # the common container paths above do not need it.
        import numpy
        if not (isinstance(obj, numpy.ndarray) and obj.dtype == object):
            # Any other object (numbers, strings, numeric arrays...)
            # cannot contain estimators.
            return 0
        # Object arrays may be multi-dimensional, e.g. (n_estimators, K).
        children = obj.ravel()
    return sum(get_nb_skl_base_estimators(o, fitted=fitted)
               for o in children)