1""" 

2@file 

3@brief Functions to help get more information about the models. 

4""" 

5import inspect 

6from collections import Counter 

7import numpy 

8 

9 

def _analyse_tree(tree):
    """
    Extracts information from a tree.
    """
    info = {}
    if hasattr(tree, 'node_count'):
        info['node_count'] = tree.node_count

    n_nodes = tree.node_count
    children_left = tree.children_left
    children_right = tree.children_right
    node_depth = numpy.zeros(shape=n_nodes, dtype=numpy.int64)
    is_leaves = numpy.zeros(shape=n_nodes, dtype=bool)
    # Walks through the tree with a stack of pairs
    # (node_id, parent_depth) to compute the depth of every node
    # and to detect leaves (nodes whose two children are identical).
    stack = [(0, -1)]
    while len(stack) > 0:
        node_id, parent_depth = stack.pop()
        node_depth[node_id] = parent_depth + 1
        if children_left[node_id] != children_right[node_id]:
            stack.append((children_left[node_id], parent_depth + 1))
            stack.append((children_right[node_id], parent_depth + 1))
        else:
            is_leaves[node_id] = True

    info['leave_count'] = sum(is_leaves)
    info['max_depth'] = max(node_depth)
    return info
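# Minimal usage sketch (illustrative, not part of the original module):
# ``_analyse_tree`` expects the low-level ``tree_`` attribute of a fitted
# scikit-learn tree model. The helper name ``_example_analyse_tree``
# is hypothetical.
def _example_analyse_tree():
    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier
    X, y = load_iris(return_X_y=True)
    model = DecisionTreeClassifier(max_depth=3).fit(X, y)
    # returns a dictionary such as
    # {'node_count': ..., 'leave_count': ..., 'max_depth': 3}
    return _analyse_tree(model.tree_)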

def _analyse_tree_h(tree):
    """
    Extracts information from a tree in a
    HistGradientBoosting.
    """
    info = {}
    info['leave_count'] = tree.get_n_leaf_nodes()
    info['node_count'] = len(tree.nodes)
    info['max_depth'] = tree.get_max_depth()
    return info
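# Illustrative sketch (hypothetical helper, relies on scikit-learn's
# private API): the trees of a fitted HistGradientBoosting* model are
# stored in ``model._predictors``, a list (one item per iteration) of
# lists of predictors exposing ``get_n_leaf_nodes``, ``nodes`` and
# ``get_max_depth``. Older scikit-learn versions also require
# ``from sklearn.experimental import enable_hist_gradient_boosting``.
def _example_analyse_tree_h():
    from sklearn.datasets import load_iris
    from sklearn.ensemble import HistGradientBoostingClassifier
    X, y = load_iris(return_X_y=True)
    model = HistGradientBoostingClassifier(max_iter=5).fit(X, y)
    first_tree = model._predictors[0][0]
    return _analyse_tree_h(first_tree)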

def _reduce_infos(infos):
    """
    Produces aggregated features.
    """
    def tof(obj):
        try:
            return obj[0]
        except TypeError:  # pragma: no cover
            return obj

    if not isinstance(infos, list):
        raise TypeError(  # pragma: no cover
            "infos must be a list not {}.".format(type(infos)))
    keys = set()
    for info in infos:
        if not isinstance(info, dict):
            raise TypeError(  # pragma: no cover
                "info must be a dictionary not {}.".format(type(info)))
        keys |= set(info)

    info = {}
    for k in keys:
        values = [d.get(k, None) for d in infos]
        values = [_ for _ in values if _ is not None]
        if k.endswith('.leave_count') or k.endswith('.node_count'):
            info['sum|%s' % k] = sum(values)
        elif k.endswith('.max_depth'):
            info['max|%s' % k] = max(values)
        elif k.endswith('.size'):
            info['sum|%s' % k] = sum(values)  # pragma: no cover
        else:
            try:
                un = set(values)
            except TypeError:  # pragma: no cover
                un = set()
            if len(un) == 1:
                info[k] = list(un)[0]
                continue
            if k.endswith('.shape'):
                row = [_[0] for _ in values]
                col = [_[1] for _ in values if len(_) > 1]
                if len(col) == 0:
                    info['max|%s' % k] = (max(row), )
                else:
                    info['max|%s' % k] = (max(row), max(col))
                continue
            if k == 'n_classes_':
                info['n_classes_'] = max(tof(_) for _ in values)
                continue
            raise NotImplementedError(  # pragma: no cover
                "Unable to reduce key '{}', values={}.".format(k, values))
    return info
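# Illustrative sketch (hypothetical helper, hypothetical values):
# keys ending in '.node_count' or '.leave_count' are summed, keys
# ending in '.max_depth' are maximized, and identical values collapse
# into a single entry.
def _example_reduce_infos():
    # -> {'sum|tree_.node_count': 12, 'max|tree_.max_depth': 3,
    #     'n_outputs': 1}
    return _reduce_infos([
        {'tree_.node_count': 5, 'tree_.max_depth': 2, 'n_outputs': 1},
        {'tree_.node_count': 7, 'tree_.max_depth': 3, 'n_outputs': 1}])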

def _get_info_lgb(model):
    """
    Gets information from :epkg:`lightgbm` trees.
    """
    from ..onnx_conv.operator_converters.conv_lightgbm import (
        _parse_tree_structure,
        get_default_tree_classifier_attribute_pairs
    )
    gbm_text = model.dump_model()

    info = {'objective': gbm_text['objective']}
    if gbm_text['objective'].startswith('binary'):
        info['n_classes'] = 1
    elif gbm_text['objective'].startswith('multiclass'):
        info['n_classes'] = gbm_text['num_class']
    elif gbm_text['objective'].startswith('regression'):
        info['n_targets'] = 1
    else:
        raise NotImplementedError(  # pragma: no cover
            "Unknown objective '{}'.".format(gbm_text['objective']))
    n_classes = info.get('n_classes', info.get('n_targets', -1))

    info['estimators_.size'] = len(gbm_text['tree_info'])
    attrs = get_default_tree_classifier_attribute_pairs()
    for i, tree in enumerate(gbm_text['tree_info']):
        tree_id = i
        class_id = tree_id % n_classes
        learning_rate = 1.
        _parse_tree_structure(
            tree_id, class_id, learning_rate, tree['tree_structure'], attrs)

    info['node_count'] = len(attrs['nodes_nodeids'])
    info['ntrees'] = len(set(attrs['nodes_treeids']))
    dist = Counter(attrs['nodes_modes'])
    info['leave_count'] = dist['LEAF']
    info['mode_count'] = len(dist)
    return info
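# Illustrative sketch (hypothetical helper, requires :epkg:`lightgbm`):
# the function takes the underlying ``Booster`` of a fitted lightgbm
# model, available through the ``booster_`` attribute.
def _example_get_info_lgb():
    from lightgbm import LGBMClassifier
    from sklearn.datasets import load_iris
    X, y = load_iris(return_X_y=True)
    model = LGBMClassifier(n_estimators=3).fit(X, y)
    # returns the objective and tree, node, leaf counts
    return _get_info_lgb(model.booster_)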

def _get_info_xgb(model):
    """
    Gets information from :epkg:`xgboost` trees.
    """
    from ..onnx_conv.operator_converters.conv_xgboost import (
        XGBConverter, XGBClassifierConverter)
    objective, _, js_trees = XGBConverter.common_members(model, None)
    attrs = XGBClassifierConverter._get_default_tree_attribute_pairs()
    XGBConverter.fill_tree_attributes(
        js_trees, attrs, [1 for _ in js_trees], True)
    info = {'objective': objective}
    info['estimators_.size'] = len(js_trees)
    info['node_count'] = len(attrs['nodes_nodeids'])
    info['ntrees'] = len(set(attrs['nodes_treeids']))
    dist = Counter(attrs['nodes_modes'])
    info['leave_count'] = dist['LEAF']
    info['mode_count'] = len(dist)
    return info
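# Illustrative sketch (hypothetical helper, requires :epkg:`xgboost`):
# unlike ``_get_info_lgb``, the function receives the fitted model
# itself, mirroring how ``analyze_model`` detects it through the
# ``get_booster`` attribute. Behaviour may vary with xgboost versions.
def _example_get_info_xgb():
    from xgboost import XGBClassifier
    from sklearn.datasets import load_iris
    X, y = load_iris(return_X_y=True)
    model = XGBClassifier(n_estimators=3).fit(X, y)
    return _get_info_xgb(model)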

def analyze_model(model, simplify=True):
    """
    Returns information and statistics about a model:
    its number of nodes, its size...

    @param      model       any model
    @param      simplify    replaces tuples of length 1 by their single value
    @return                 dictionary

    .. exref::
        :title: Extract information from a model

        The function @see fn analyze_model extracts global
        figures about a model, whatever it is.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning

            import pprint
            from sklearn.datasets import load_iris
            from sklearn.ensemble import RandomForestClassifier
            from mlprodict.tools.model_info import analyze_model

            data = load_iris()
            X, y = data.data, data.target
            model = RandomForestClassifier().fit(X, y)
            infos = analyze_model(model)
            pprint.pprint(infos)
    """
    if hasattr(model, 'SerializeToString'):
        # ONNX model
        from ..onnx_tools.optim.onnx_helper import onnx_statistics
        return onnx_statistics(model)

    if isinstance(model, numpy.ndarray):
        info = {'shape': model.shape}
        infos = []
        for v in model.ravel():
            if hasattr(v, 'fit'):
                ii = analyze_model(v, False)
                infos.append(ii)
        if len(infos) == 0:
            return info  # pragma: no cover
        for k, v in _reduce_infos(infos).items():
            info['.%s' % k] = v
        return info

    # linear model
    info = {}
    for k in model.__dict__:
        if k in ['tree_']:
            continue
        if k.endswith('_') and not k.startswith('_'):
            v = getattr(model, k)
            if isinstance(v, numpy.ndarray):
                info['%s.shape' % k] = v.shape
            elif isinstance(v, numpy.float64):
                info['%s.shape' % k] = 1
        elif k in ('_fit_X', ):
            v = getattr(model, k)
            info['%s.shape' % k] = v.shape

    # classification
    for f in ['n_classes_', 'n_outputs', 'n_features_']:
        if hasattr(model, f):
            info[f] = getattr(model, f)

    # tree
    if hasattr(model, 'tree_'):
        for k, v in _analyse_tree(model.tree_).items():
            info['tree_.%s' % k] = v

    # tree (HistGradientBoosting)
    if hasattr(model, 'get_n_leaf_nodes'):
        for k, v in _analyse_tree_h(model).items():
            info['tree_.%s' % k] = v

    # estimators
    if hasattr(model, 'estimators_'):
        info['estimators_.size'] = len(model.estimators_)
        infos = [analyze_model(est, False) for est in model.estimators_]
        for k, v in _reduce_infos(infos).items():
            info['estimators_.%s' % k] = v

    # predictors
    if hasattr(model, '_predictors'):
        info['_predictors.size'] = len(model._predictors)
        infos = []
        for est in model._predictors:
            ii = [analyze_model(e, False) for e in est]
            infos.extend(ii)
        for k, v in _reduce_infos(infos).items():
            info['_predictors.%s' % k] = v

    # LGBM
    if hasattr(model, 'booster_'):
        info.update(_get_info_lgb(model.booster_))

    # XGB
    if hasattr(model, 'get_booster'):
        info.update(_get_info_xgb(model))

    # end
    if simplify:
        up = {}
        for k, v in info.items():
            if isinstance(v, tuple) and len(v) == 1:
                up[k] = v[0]
        info.update(up)

    return info
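# Illustrative sketch (hypothetical helper): with the default
# ``simplify=True``, shapes stored as tuples of length one are replaced
# by their single value, e.g. ``coef_.shape`` becomes ``4`` instead of
# ``(4,)`` for a linear model trained on four features.
def _example_analyze_model_simplify():
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LinearRegression
    X, y = load_iris(return_X_y=True)
    model = LinearRegression().fit(X, y)
    # -> ((4,), 4)
    return (analyze_model(model, simplify=False)['coef_.shape'],
            analyze_model(model)['coef_.shape'])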

def enumerate_models(model):
    """
    Enumerates a model and the models nested inside it.

    @param      model       :epkg:`scikit-learn` model
    @return                 iterator on models
    """
    yield model
    sig = inspect.signature(model.__init__)
    for k in sig.parameters:
        sub = getattr(model, k, None)
        if sub is None:
            continue
        if not hasattr(sub, 'fit'):
            continue
        for m in enumerate_models(sub):
            yield m
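# Illustrative sketch (hypothetical helper): a model passed as a
# constructor parameter, such as the base estimator of a boosting
# model, is discovered through the signature of ``__init__``. The
# parameter name (``estimator`` or ``base_estimator``) depends on the
# scikit-learn version.
def _example_enumerate_models():
    from sklearn.ensemble import AdaBoostClassifier
    from sklearn.tree import DecisionTreeClassifier
    model = AdaBoostClassifier(estimator=DecisionTreeClassifier())
    # yields the AdaBoostClassifier itself, then the inner tree
    return list(enumerate_models(model))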

def set_random_state(model, value=0):
    """
    Sets every possible parameter *random_state* to *value* (0 by default).

    @param      model       :epkg:`scikit-learn` model
    @param      value       new value
    @return                 model (same one)
    """
    for m in enumerate_models(model):
        sig = inspect.signature(m.__init__)
        hasit = any(filter(lambda p: p == 'random_state',
                           sig.parameters))
        if hasit and hasattr(m, 'random_state'):
            m.random_state = value
    return model
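# Illustrative sketch (hypothetical helper): fixes the seed of a model
# and of every model nested inside it in one call.
def _example_set_random_state():
    from sklearn.ensemble import RandomForestClassifier
    model = set_random_state(RandomForestClassifier(), 42)
    # -> 42
    return model.random_state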