1""" 

2@file 

3@brief Helpers to manipulate :epkg:`scikit-learn` models. 

4""" 

5import inspect 

6import multiprocessing 

7import numpy 

8from sklearn.base import ( 

9 TransformerMixin, ClassifierMixin, RegressorMixin, BaseEstimator) 

10from sklearn.pipeline import Pipeline, FeatureUnion 

11from sklearn.compose import ColumnTransformer, TransformedTargetRegressor 

12 

13 

def enumerate_pipeline_models(pipe, coor=None, vs=None):
    """
    Enumerates all the models within a pipeline.

    @param      pipe    *scikit-learn* pipeline
    @param      coor    current coordinate
    @param      vs      subset of variables for the model, None for all
    @return             iterator on models
                        ``tuple(coordinate, model, subset of variables)``

    Example:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from sklearn.datasets import load_iris
        from sklearn.decomposition import PCA
        from sklearn.linear_model import LogisticRegression
        from sklearn.pipeline import make_pipeline
        from sklearn.model_selection import train_test_split
        from mlprodict.onnxrt.optim.sklearn_helper import enumerate_pipeline_models

        iris = load_iris()
        X, y = iris.data, iris.target
        X_train, __, y_train, _ = train_test_split(X, y, random_state=11)
        clr = make_pipeline(PCA(n_components=2),
                            LogisticRegression(solver="liblinear"))
        clr.fit(X_train, y_train)

        for a in enumerate_pipeline_models(clr):
            print(a)
    """

    if coor is None:
        coor = (0,)
    yield coor, pipe, vs
    if hasattr(pipe, 'transformer_and_mapper_list') and len(pipe.transformer_and_mapper_list):
        # azureml DataTransformer
        raise NotImplementedError(  # pragma: no cover
            "Unable to handle this specific case.")
    elif hasattr(pipe, 'mapper') and pipe.mapper:
        # azureml DataTransformer
        for couple in enumerate_pipeline_models(  # pragma: no cover
                pipe.mapper, coor + (0,)):
            yield couple
    elif hasattr(pipe, 'built_features'):  # pragma: no cover
        # sklearn_pandas.dataframe_mapper.DataFrameMapper
        for i, (columns, transformers, _) in enumerate(
                pipe.built_features):
            if isinstance(columns, str):
                columns = (columns,)
            if transformers is None:
                yield (coor + (i,)), None, columns
            else:
                for couple in enumerate_pipeline_models(transformers, coor + (i,), columns):
                    yield couple
    elif isinstance(pipe, Pipeline):
        for i, (_, model) in enumerate(pipe.steps):
            for couple in enumerate_pipeline_models(model, coor + (i,)):
                yield couple
    elif isinstance(pipe, ColumnTransformer):
        for i, (_, fitted_transformer, column) in enumerate(pipe.transformers):
            for couple in enumerate_pipeline_models(
                    fitted_transformer, coor + (i,), column):
                yield couple
    elif isinstance(pipe, FeatureUnion):
        for i, (_, model) in enumerate(pipe.transformer_list):
            for couple in enumerate_pipeline_models(model, coor + (i,)):
                yield couple
    elif isinstance(pipe, TransformedTargetRegressor):
        raise NotImplementedError(
            "Not yet implemented for TransformedTargetRegressor.")
    elif isinstance(pipe, (TransformerMixin, ClassifierMixin, RegressorMixin)):
        pass
    elif isinstance(pipe, BaseEstimator):
        pass
    elif isinstance(pipe, (list, numpy.ndarray)):
        for i, m in enumerate(pipe):
            for couple in enumerate_pipeline_models(m, coor + (i,)):
                yield couple
    else:
        raise TypeError(  # pragma: no cover
            "pipe is not a scikit-learn object: {}\n{}".format(type(pipe), pipe))


def enumerate_fitted_arrays(model):
    """
    Enumerates all fitted arrays included in a
    :epkg:`scikit-learn` object.

    @param      model   :epkg:`scikit-learn` object
    @return             enumerator

    One example:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from sklearn.datasets import load_iris
        from sklearn.decomposition import PCA
        from sklearn.linear_model import LogisticRegression
        from sklearn.pipeline import make_pipeline
        from sklearn.model_selection import train_test_split
        from mlprodict.onnxrt.optim.sklearn_helper import enumerate_fitted_arrays

        iris = load_iris()
        X, y = iris.data, iris.target
        X_train, __, y_train, _ = train_test_split(X, y, random_state=11)
        clr = make_pipeline(PCA(n_components=2),
                            LogisticRegression(solver="liblinear"))
        clr.fit(X_train, y_train)

        for a in enumerate_fitted_arrays(clr):
            print(a)
    """

    def enumerate__(obj):
        if isinstance(obj, (tuple, list)):
            for el in obj:
                for o in enumerate__(el):
                    yield (obj, el, o)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                for o in enumerate__(v):
                    yield (obj, k, v, o)
        elif hasattr(obj, '__dict__'):
            for k, v in obj.__dict__.items():
                # keeps only fitted attributes (trailing '_')
                # and private ones (leading '_')
                if k[-1] != '_' and k[0] != '_':
                    continue
                if isinstance(v, numpy.ndarray):
                    yield (obj, k, v)
                else:
                    for row in enumerate__(v):
                        yield row

    for row in enumerate_pipeline_models(model):
        coord = row[:-1]
        sub = row[1]
        last = row[2:]
        for sub_row in enumerate__(sub):
            yield coord + (sub, sub_row) + last


def pairwise_array_distances(l1, l2, metric='l1med'):
    """
    Computes pairwise distances between two lists of arrays
    *l1* and *l2*. The distance is 1e9 if shapes are not equal.

    @param      l1      first list of arrays
    @param      l2      second list of arrays
    @param      metric  metric to use, ``'l1med'`` computes
                        the average absolute error divided
                        by the absolute median
    @return             matrix
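
    A short example (the arrays below are made up for illustration):

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from mlprodict.onnxrt.optim.sklearn_helper import (
            pairwise_array_distances)

        l1 = [numpy.array([1., 2., 3.]), numpy.array([1., 2.])]
        l2 = [numpy.array([1., 2., 3.5]), numpy.array([1., 2., 3.])]
        print(pairwise_array_distances(l1, l2))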

167 """ 

    dist = numpy.full((len(l1), len(l2)), 1e9)
    for i, a1 in enumerate(l1):
        if not isinstance(a1, numpy.ndarray):
            continue  # pragma: no cover
        for j, a2 in enumerate(l2):
            if not isinstance(a2, numpy.ndarray):
                continue  # pragma: no cover
            if a1.shape != a2.shape:
                continue
            a = numpy.median(numpy.abs(a1))
            if a == 0:
                a = 1
            diff = numpy.sum(numpy.abs(a1 - a2)) / a
            # averages over the number of elements,
            # *diff* itself is a scalar
            dist[i, j] = diff / a1.size
    return dist


def max_depth(estimator):
    """
    Retrieves the max depth assuming the estimator
    is a decision tree.
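
    A minimal usage sketch (the tree is fitted here only for illustration):

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from sklearn.datasets import load_iris
        from sklearn.tree import DecisionTreeClassifier
        from mlprodict.onnxrt.optim.sklearn_helper import max_depth

        iris = load_iris()
        dt = DecisionTreeClassifier(max_depth=3)
        dt.fit(iris.data, iris.target)
        print(max_depth(dt))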

189 """ 

    n_nodes = estimator.tree_.node_count
    children_left = estimator.tree_.children_left
    children_right = estimator.tree_.children_right

    node_depth = numpy.zeros(shape=n_nodes, dtype=numpy.int64)
    is_leaves = numpy.zeros(shape=n_nodes, dtype=bool)
    stack = [(0, -1)]  # seed is the root node id and its parent depth
    while len(stack) > 0:
        node_id, parent_depth = stack.pop()
        node_depth[node_id] = parent_depth + 1

        # If we have a test node
        if children_left[node_id] != children_right[node_id]:
            stack.append((children_left[node_id], parent_depth + 1))
            stack.append((children_right[node_id], parent_depth + 1))
        else:
            is_leaves[node_id] = True
    return max(node_depth)


def inspect_sklearn_model(model, recursive=True):
    """
    Inspects a :epkg:`scikit-learn` model and produces
    some figures which try to represent its complexity.

    @param      model       model
    @param      recursive   recursive look
    @return                 dictionary

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import pprint
        from sklearn.ensemble import RandomForestClassifier
        from sklearn.linear_model import LogisticRegression
        from sklearn.datasets import load_iris
        from mlprodict.onnxrt.optim.sklearn_helper import inspect_sklearn_model

        iris = load_iris()
        X = iris.data
        y = iris.target
        lr = LogisticRegression()
        lr.fit(X, y)
        pprint.pprint((lr, inspect_sklearn_model(lr)))

        rf = RandomForestClassifier()
        rf.fit(X, y)
        pprint.pprint((rf, inspect_sklearn_model(rf)))
    """

    def update(sts, st):
        # Merges statistics *st* into *sts*: 'max_*' figures are
        # maximised, all the others are summed.
        for k, v in st.items():
            if k in sts:
                if 'max_' in k:
                    sts[k] = max(v, sts[k])
                else:
                    sts[k] += v
            else:
                sts[k] = v

    def insmodel(m):
        st = {'nop': 1}
        if hasattr(m, 'tree_') and hasattr(m.tree_, 'node_count'):
            st['nnodes'] = m.tree_.node_count
            st['ntrees'] = 1
            st['max_depth'] = max_depth(m)
        try:
            if hasattr(m, 'coef_'):
                st['ncoef'] = len(m.coef_)
                st['nlin'] = 1
        except KeyError:  # pragma: no cover
            # added to deal with xgboost 1.0 (KeyError: 'weight')
            pass
        if hasattr(m, 'estimators_'):
            for est in m.estimators_:
                st_ = inspect_sklearn_model(est, recursive=recursive)
                update(st, st_)
        return st

    if recursive:
        sts = {}
        for __, m, _ in enumerate_pipeline_models(model):
            # aggregates the figures over every model in the pipeline,
            # insmodel already walks attribute *estimators_*
            st = inspect_sklearn_model(m, recursive=False)
            update(sts, st)
        return sts
    return insmodel(model)


def set_n_jobs(model, params, n_jobs=None):
    """
    Looks into the model signature and adds parameter *n_jobs*
    if available. The function does not overwrite an existing value.

    @param      model   model class
    @param      params  current set of parameters
    @param      n_jobs  number of jobs; 0 or None means
                        ``multiprocessing.cpu_count()``
    @return             new set of parameters

    On this machine, the default value is the following.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import multiprocessing
        print(multiprocessing.cpu_count())
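
    A short usage sketch, assuming a model class which exposes *n_jobs*
    (*RandomForestClassifier* does); the printed value depends on the machine:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from sklearn.ensemble import RandomForestClassifier
        from mlprodict.onnxrt.optim.sklearn_helper import set_n_jobs

        print(set_n_jobs(RandomForestClassifier, {'n_estimators': 10}))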

302 """ 

    if params is not None and 'n_jobs' in params:
        return params
    sig = inspect.signature(model.__init__)
    if 'n_jobs' not in sig.parameters:
        return params
    if n_jobs == 0:
        n_jobs = None
    params = params.copy() if params else {}
    params['n_jobs'] = n_jobs or multiprocessing.cpu_count()
    return params