Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Modified converter from
4`XGBoost.py <https://github.com/onnx/onnxmltools/blob/master/onnxmltools/convert/
5xgboost/operator_converters/XGBoost.py>`_.
6"""
7import json
8from pprint import pformat
9import numpy
10from xgboost import XGBClassifier
11from skl2onnx.common.data_types import guess_numpy_type # pylint: disable=C0411
class XGBConverter:
    "common methods for converters"

    @staticmethod
    def get_xgb_params(xgb_node):
        """
        Retrieves parameters of a model.

        :param xgb_node: trained XGBoost model
        :return: parameter dictionary, always containing ``n_estimators``
        """
        pars = xgb_node.get_xgb_params()
        # xgboost >= 1.0 dropped n_estimators from get_xgb_params,
        # recover it from the estimator attribute.
        if 'n_estimators' not in pars:
            pars['n_estimators'] = xgb_node.n_estimators
        return pars

    @staticmethod
    def validate(xgb_node):
        "validates the model, fails if parameter *objective* is missing"
        params = XGBConverter.get_xgb_params(xgb_node)
        try:
            if "objective" not in params:
                # typo fixed: message used to read 'ojective'
                raise AttributeError('objective')
        except AttributeError as e:  # pragma: no cover
            raise RuntimeError('Missing attribute in XGBoost model.') from e

    @staticmethod
    def common_members(xgb_node, inputs):
        """
        Part of the conversion common to regressor and classifier.

        :param xgb_node: trained XGBoost model
        :param inputs: unused, kept for interface compatibility
        :return: tuple ``(objective, base_score, js_trees)`` where
            *js_trees* is the list of trees decoded from the JSON
            dump of the booster
        """
        params = XGBConverter.get_xgb_params(xgb_node)
        objective = params["objective"]
        base_score = params["base_score"]
        booster = xgb_node.get_booster()
        # The json format was available in October 2017.
        # XGBoost 0.7 was the first version released with it.
        js_tree_list = booster.get_dump(with_stats=True, dump_format='json')
        js_trees = [json.loads(s) for s in js_tree_list]
        return objective, base_score, js_trees

    @staticmethod
    def _get_default_tree_attribute_pairs(is_classifier):
        """
        Creates the dictionary of empty attribute lists expected by the
        ONNX *TreeEnsembleClassifier* / *TreeEnsembleRegressor* operators.

        :param is_classifier: selects ``class_*`` or ``target_*`` lists
        :return: dictionary mapping attribute names to empty lists
        """
        attrs = {}
        for k in {'nodes_treeids', 'nodes_nodeids',
                  'nodes_featureids', 'nodes_modes', 'nodes_values',
                  'nodes_truenodeids', 'nodes_falsenodeids',
                  'nodes_missing_value_tracks_true'}:
            attrs[k] = []
        if is_classifier:
            for k in {'class_treeids', 'class_nodeids',
                      'class_ids', 'class_weights'}:
                attrs[k] = []
        else:
            for k in {'target_treeids', 'target_nodeids',
                      'target_ids', 'target_weights'}:
                attrs[k] = []
        return attrs

    @staticmethod
    def _add_node(attr_pairs, is_classifier, tree_id, tree_weight, node_id,
                  feature_id, mode, value, true_child_id, false_child_id,
                  weights, weight_id_bias, missing, hitrate):
        """
        Appends one tree node to every attribute list in *attr_pairs*.

        :param attr_pairs: dictionary of lists, filled in place
        :param is_classifier: leaves go to ``class_*`` or ``target_*`` lists
        :param feature_id: integer index or XGBoost name such as ``'f0'``
        :param mode: node kind, e.g. ``'BRANCH_LT'`` or ``'LEAF'``
        :param weights: leaf values, only read when *mode* is ``'LEAF'``
        :param missing: True when missing values follow the true branch
        :raises RuntimeError: if *feature_id* cannot be converted
            into an integer index
        """
        if isinstance(feature_id, str):
            # Something like f0, f1...
            if feature_id[0] == "f":
                try:
                    feature_id = int(feature_id[1:])
                except ValueError as e:  # pragma: no cover
                    raise RuntimeError(
                        "Unable to interpret '{0}'".format(feature_id)) from e
            else:  # pragma: no cover
                try:
                    feature_id = int(feature_id)
                except ValueError as e:
                    # bug fix: the exception was not bound
                    # (`except ValueError:`) so `from e` raised a
                    # NameError instead of this RuntimeError
                    raise RuntimeError(
                        "Unable to interpret '{0}'".format(feature_id)) from e

        # Split condition for sklearn
        # * if X_ptr[X_sample_stride * i + X_fx_stride * node.feature] <= node.threshold:
        # * https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/_tree.pyx#L946
        # Split condition for xgboost
        # * if (fvalue < split_value)
        # * https://github.com/dmlc/xgboost/blob/master/include/xgboost/tree_model.h#L804

        attr_pairs['nodes_treeids'].append(tree_id)
        attr_pairs['nodes_nodeids'].append(node_id)
        attr_pairs['nodes_featureids'].append(feature_id)
        attr_pairs['nodes_modes'].append(mode)
        attr_pairs['nodes_values'].append(float(value))
        attr_pairs['nodes_truenodeids'].append(true_child_id)
        attr_pairs['nodes_falsenodeids'].append(false_child_id)
        attr_pairs['nodes_missing_value_tracks_true'].append(missing)
        if 'nodes_hitrates' in attr_pairs:
            attr_pairs['nodes_hitrates'].append(hitrate)  # pragma: no cover
        if mode == 'LEAF':
            # a leaf contributes one weight per output (class or target)
            if is_classifier:
                for i, w in enumerate(weights):
                    attr_pairs['class_treeids'].append(tree_id)
                    attr_pairs['class_nodeids'].append(node_id)
                    attr_pairs['class_ids'].append(i + weight_id_bias)
                    attr_pairs['class_weights'].append(float(tree_weight * w))
            else:
                for i, w in enumerate(weights):
                    attr_pairs['target_treeids'].append(tree_id)
                    attr_pairs['target_nodeids'].append(node_id)
                    attr_pairs['target_ids'].append(i + weight_id_bias)
                    attr_pairs['target_weights'].append(float(tree_weight * w))

    @staticmethod
    def _fill_node_attributes(treeid, tree_weight, jsnode, attr_pairs,
                              is_classifier, remap):
        """
        Recursively walks one JSON tree and adds every node to *attr_pairs*.

        :param jsnode: JSON dictionary describing the current node
        :param remap: mapping from XGBoost node ids to contiguous ids
        :raises RuntimeError: when a child is neither a branch nor a leaf
        """
        if 'children' in jsnode:
            XGBConverter._add_node(
                attr_pairs=attr_pairs, is_classifier=is_classifier,
                tree_id=treeid, tree_weight=tree_weight,
                value=jsnode['split_condition'],
                node_id=remap[jsnode['nodeid']],
                feature_id=jsnode['split'],
                mode='BRANCH_LT',  # 'BRANCH_LEQ' --> is for sklearn
                true_child_id=remap[jsnode['yes']],
                false_child_id=remap[jsnode['no']],
                weights=None, weight_id_bias=None,
                # missing values follow the 'yes' branch when XGBoost
                # stored the same node id in 'missing'
                missing=jsnode.get('missing', -1) == jsnode['yes'],
                hitrate=jsnode.get('cover', 0))

            for ch in jsnode['children']:
                if 'children' in ch or 'leaf' in ch:
                    XGBConverter._fill_node_attributes(
                        treeid, tree_weight, ch, attr_pairs,
                        is_classifier, remap)
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to convert this node {0}".format(ch))

        else:
            # leaf node: value is 0, the prediction goes into the weights
            weights = [jsnode['leaf']]
            weights_id_bias = 0
            XGBConverter._add_node(
                attr_pairs=attr_pairs, is_classifier=is_classifier,
                tree_id=treeid, tree_weight=tree_weight,
                value=0., node_id=remap[jsnode['nodeid']],
                feature_id=0, mode='LEAF',
                true_child_id=0, false_child_id=0,
                weights=weights, weight_id_bias=weights_id_bias,
                missing=False, hitrate=jsnode.get('cover', 0))

    @staticmethod
    def _remap_nodeid(jsnode, remap=None):
        """
        Maps XGBoost node ids to ids contiguous in depth-first order.

        :param jsnode: root JSON node
        :param remap: dictionary updated in place, created when None
        :return: the mapping dictionary
        """
        if remap is None:
            remap = {}
        nid = jsnode['nodeid']
        remap[nid] = len(remap)
        if 'children' in jsnode:
            for ch in jsnode['children']:
                XGBConverter._remap_nodeid(ch, remap)
        return remap

    @staticmethod
    def fill_tree_attributes(js_xgb_node, attr_pairs, tree_weights,
                             is_classifier):
        """
        Fills tree attributes for a list of JSON trees.

        :param js_xgb_node: list of JSON trees (one per booster tree)
        :param attr_pairs: dictionary of lists, filled in place
        :param tree_weights: one weight per tree
        :param is_classifier: classifier or regressor attributes
        :raises TypeError: if *js_xgb_node* is not a list
        """
        if not isinstance(js_xgb_node, list):
            raise TypeError(  # pragma: no cover
                "js_xgb_node must be a list")
        for treeid, (jstree, w) in enumerate(zip(js_xgb_node, tree_weights)):
            remap = XGBConverter._remap_nodeid(jstree)
            XGBConverter._fill_node_attributes(
                treeid, w, jstree, attr_pairs, is_classifier, remap)
class XGBRegressorConverter(XGBConverter):
    "converter class"

    @staticmethod
    def validate(xgb_node):
        # Same validation as the base class.
        return XGBConverter.validate(xgb_node)

    @staticmethod
    def _get_default_tree_attribute_pairs():  # pylint: disable=W0221
        # Regressor attributes: a single target, raw scores (no transform).
        attrs = XGBConverter._get_default_tree_attribute_pairs(False)
        attrs['post_transform'] = 'NONE'
        attrs['n_targets'] = 1
        return attrs

    @staticmethod
    def convert(scope, operator, container):
        "converter method"
        # Computation dtype: float64 is kept, everything else maps to float32.
        dtype = guess_numpy_type(operator.inputs[0].type)
        if dtype != numpy.float64:
            dtype = numpy.float32
        xgb_node = operator.raw_operator
        objective, base_score, js_trees = XGBConverter.common_members(
            xgb_node, operator.inputs)

        if objective in ("reg:gamma", "reg:tweedie"):
            raise RuntimeError(  # pragma: no cover
                "Objective '{}' not supported.".format(objective))

        booster = xgb_node.get_booster()
        if booster is None:
            raise RuntimeError(  # pragma: no cover
                "The model was probably not trained.")

        # Honour early stopping when the booster recorded a best iteration.
        limit = getattr(booster, 'best_ntree_limit', len(js_trees))
        if limit < len(js_trees):
            js_trees = js_trees[:limit]

        attr_pairs = XGBRegressorConverter._get_default_tree_attribute_pairs()
        attr_pairs['base_values'] = [base_score]
        XGBConverter.fill_tree_attributes(
            js_trees, attr_pairs, [1] * len(js_trees), False)

        # add nodes: the double variant lives in the custom domain.
        if dtype == numpy.float64:
            op_name, domain = 'TreeEnsembleRegressorDouble', 'mlprodict'
        else:
            op_name, domain = 'TreeEnsembleRegressor', 'ai.onnx.ml'
        container.add_node(
            op_name, operator.input_full_names, operator.output_full_names,
            name=scope.get_unique_operator_name(op_name),
            op_domain=domain, **attr_pairs)
class XGBClassifierConverter(XGBConverter):
    "converter for XGBClassifier"

    @staticmethod
    def validate(xgb_node):
        # Delegates to the shared validation (checks the objective exists).
        return XGBConverter.validate(xgb_node)

    @staticmethod
    def _get_default_tree_attribute_pairs():  # pylint: disable=W0221
        # Classifier variant: class_* attribute lists instead of target_*.
        attrs = XGBConverter._get_default_tree_attribute_pairs(True)
        # attrs['nodes_hitrates'] = []
        return attrs

    @staticmethod
    def convert(scope, operator, container):
        "convert method"
        # Computation dtype: keep float64, map anything else to float32.
        dtype = guess_numpy_type(operator.inputs[0].type)
        if dtype != numpy.float64:
            dtype = numpy.float32
        xgb_node = operator.raw_operator
        inputs = operator.inputs

        objective, base_score, js_trees = XGBConverter.common_members(
            xgb_node, inputs)
        if base_score is None:
            raise RuntimeError(  # pragma: no cover
                "base_score cannot be None")
        params = XGBConverter.get_xgb_params(xgb_node)

        # First fill is only used to deduce ncl below; it may be redone
        # after truncating the tree list.
        attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
        XGBConverter.fill_tree_attributes(
            js_trees, attr_pairs, [1 for _ in js_trees], True)

        # ncl: number of classes, deduced from the highest tree id and
        # the number of boosting rounds (one tree per class per round).
        ncl = (max(attr_pairs['class_treeids']) + 1) // params['n_estimators']

        bst = xgb_node.get_booster()
        # best_ntree_limit counts rounds, hence the * ncl to convert it
        # into a number of trees — presumably; TODO confirm against the
        # xgboost version in use.
        best_ntree_limit = getattr(
            bst, 'best_ntree_limit', len(js_trees)) * ncl
        if best_ntree_limit < len(js_trees):
            # Early stopping: drop the extra trees and rebuild the
            # attributes from the truncated list.
            js_trees = js_trees[:best_ntree_limit]
            attr_pairs = XGBClassifierConverter._get_default_tree_attribute_pairs()
            XGBConverter.fill_tree_attributes(
                js_trees, attr_pairs, [1 for _ in js_trees], True)

        if len(attr_pairs['class_treeids']) == 0:
            raise RuntimeError(  # pragma: no cover
                "XGBoost model is empty.")
        if 'n_estimators' not in params:
            raise RuntimeError(  # pragma: no cover
                "Parameters not found, existing:\n{}".format(
                    pformat(params)))
        if ncl <= 1:
            # Binary case: one score per tree, sigmoid post-transform.
            ncl = 2
            # See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L23.
            attr_pairs['post_transform'] = "LOGISTIC"
            attr_pairs['class_ids'] = [0 for v in attr_pairs['class_treeids']]
        else:
            # Multi-class: trees cycle over classes, softmax post-transform.
            # See https://github.com/dmlc/xgboost/blob/master/src/common/math.h#L35.
            attr_pairs['post_transform'] = "SOFTMAX"
            # attr_pairs['base_values'] = [base_score for n in range(ncl)]
            attr_pairs['class_ids'] = [v % ncl
                                       for v in attr_pairs['class_treeids']]

        # Class labels: numeric labels become int64, others are encoded
        # as UTF-8 strings.
        classes = xgb_node.classes_
        if (numpy.issubdtype(classes.dtype, numpy.floating) or
                numpy.issubdtype(classes.dtype, numpy.signedinteger)):
            attr_pairs['classlabels_int64s'] = classes.astype('int')
        else:
            classes = numpy.array([s.encode('utf-8') for s in classes])
            attr_pairs['classlabels_strings'] = classes

        if dtype == numpy.float64:
            op_name = "TreeEnsembleClassifierDouble"
        else:
            op_name = "TreeEnsembleClassifier"

        # add nodes
        # NOTE(review): ncl is reassigned in each branch below but not
        # used afterwards in this visible code.
        if objective == "binary:logistic":
            ncl = 2
            container.add_node(op_name, operator.input_full_names,
                               operator.output_full_names,
                               name=scope.get_unique_operator_name(
                                   op_name),
                               op_domain='ai.onnx.ml', **attr_pairs)
        elif objective == "multi:softprob":
            ncl = len(js_trees) // params['n_estimators']
            container.add_node(op_name, operator.input_full_names,
                               operator.output_full_names,
                               name=scope.get_unique_operator_name(
                                   op_name),
                               op_domain='ai.onnx.ml', **attr_pairs)
        elif objective == "reg:logistic":
            ncl = len(js_trees) // params['n_estimators']
            if ncl == 1:
                ncl = 2
            container.add_node(op_name, operator.input_full_names,
                               operator.output_full_names,
                               name=scope.get_unique_operator_name(
                                   op_name),
                               op_domain='ai.onnx.ml', **attr_pairs)
        else:
            raise RuntimeError(  # pragma: no cover
                "Unexpected objective: {0}".format(objective))
def convert_xgboost(scope, operator, container):
    """
    This converters reuses the code from
    `XGBoost.py <https://github.com/onnx/onnxmltools/blob/master/onnxmltools/convert/
    xgboost/operator_converters/XGBoost.py>`_ and makes
    some modifications. It implements converters
    for models in :epkg:`xgboost`.
    """
    # Dispatch on the raw model type: classifiers get the classifier
    # converter, everything else is treated as a regressor.
    model = operator.raw_operator
    converter = (XGBClassifierConverter
                 if isinstance(model, XGBClassifier)
                 else XGBRegressorConverter)
    converter.convert(scope, operator, container)