Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Modified converter from
4`LightGbm.py <https://github.com/onnx/onnxmltools/blob/master/onnxmltools/convert/
5lightgbm/operator_converters/LightGbm.py>`_.
6"""
7from collections import Counter
8import copy
9import numbers
10import numpy
11from onnx import TensorProto
12from skl2onnx.common._apply_operation import apply_div, apply_reshape, apply_sub # pylint: disable=E0611
13from skl2onnx.common.tree_ensemble import get_default_tree_classifier_attribute_pairs
14from skl2onnx.proto import onnx_proto
15from skl2onnx.common.shape_calculator import (
16 calculate_linear_regressor_output_shapes,
17 calculate_linear_classifier_output_shapes)
18from skl2onnx.common.data_types import guess_numpy_type
19from skl2onnx.common.tree_ensemble import sklearn_threshold
20from ..helpers.lgbm_helper import (
21 dump_lgbm_booster, modify_tree_for_rule_in_set)
def calculate_lightgbm_output_shapes(operator):
    """
    Shape calculator for LightGBM Booster
    (see :epkg:`lightgbm`).

    Looks up the training objective on the raw model and delegates
    to the :epkg:`skl2onnx` classifier or regressor shape calculator.
    """
    model = operator.raw_operator
    # The objective is stored differently depending on how the model
    # was built (dumped dictionary vs. fitted estimator).
    if hasattr(model, "_model_dict"):
        objective = model._model_dict['objective']
    elif hasattr(model, 'objective_'):
        objective = model.objective_
    else:
        raise RuntimeError(  # pragma: no cover
            "Unable to find attributes '_model_dict' or 'objective_' in "
            "instance of type %r (list of attributes=%r)." % (
                type(model), dir(model)))
    dispatch = (
        (('binary', 'multiclass'), calculate_linear_classifier_output_shapes),
        (('regression', ), calculate_linear_regressor_output_shapes),
    )
    for prefixes, shape_calculator in dispatch:
        if objective.startswith(prefixes):
            return shape_calculator(operator)
    raise NotImplementedError(  # pragma: no cover
        "Objective '{}' is not implemented yet.".format(objective))
47def _translate_split_criterion(criterion):
48 # If the criterion is true, LightGBM use the left child. Otherwise, right child is selected.
49 if criterion == '<=':
50 return 'BRANCH_LEQ'
51 if criterion == '<': # pragma: no cover
52 return 'BRANCH_LT'
53 if criterion == '>=': # pragma: no cover
54 return 'BRANCH_GTE'
55 if criterion == '>': # pragma: no cover
56 return 'BRANCH_GT'
57 if criterion == '==': # pragma: no cover
58 return 'BRANCH_EQ'
59 if criterion == '!=': # pragma: no cover
60 return 'BRANCH_NEQ'
61 raise ValueError( # pragma: no cover
62 'Unsupported splitting criterion: %s. Only <=, '
63 '<, >=, and > are allowed.')
66def _create_node_id(node_id_pool):
67 i = 0
68 while i in node_id_pool:
69 i += 1
70 node_id_pool.add(i)
71 return i
def _parse_tree_structure(tree_id, class_id, learning_rate,
                          tree_structure, attrs):
    """
    Parses the root of one LightGBM tree (*tree_structure* is the nested
    dictionary coming from the model dump) and appends its description to
    the flat ``nodes_*`` / ``class_*`` lists stored in *attrs*, recursing
    into children through :func:`_parse_node`.

    The pool of all nodes' indexes created when parsing a single tree.
    Different tree use different pools.

    :param tree_id: index of the tree inside the ensemble
    :param class_id: class this tree contributes to
    :param learning_rate: multiplier applied to leaf values
    :param tree_structure: dictionary describing the tree
    :param attrs: ONNX TreeEnsemble* attributes (lists), filled in place
    """
    node_id_pool = set()
    # Maps id(sub-dictionary) -> allocated node id so a subtree object
    # reused at several places keeps a single node id.
    node_pyid_pool = dict()
    node_id = _create_node_id(node_id_pool)
    node_pyid_pool[id(tree_structure)] = node_id
    # The root node is a leaf node.
    if ('left_child' not in tree_structure or
            'right_child' not in tree_structure):
        _parse_node(tree_id, class_id, node_id, node_id_pool, node_pyid_pool,
                    learning_rate, tree_structure, attrs)
        return
    left_pyid = id(tree_structure['left_child'])
    right_pyid = id(tree_structure['right_child'])
    # Allocate ids for both children, only parsing a child that was not
    # already registered in the pool.
    if left_pyid in node_pyid_pool:
        left_id = node_pyid_pool[left_pyid]
        left_parse = False
    else:
        left_id = _create_node_id(node_id_pool)
        node_pyid_pool[left_pyid] = left_id
        left_parse = True
    if right_pyid in node_pyid_pool:
        right_id = node_pyid_pool[right_pyid]
        right_parse = False
    else:
        right_id = _create_node_id(node_id_pool)
        node_pyid_pool[right_pyid] = right_id
        right_parse = True
    attrs['nodes_treeids'].append(tree_id)
    attrs['nodes_nodeids'].append(node_id)
    attrs['nodes_featureids'].append(tree_structure['split_feature'])
    mode = _translate_split_criterion(tree_structure['decision_type'])
    attrs['nodes_modes'].append(mode)
    if isinstance(tree_structure['threshold'], str):
        # Thresholds are expected to be numbers; a string is accepted only
        # if it converts cleanly to float.
        try:  # pragma: no cover
            th = float(tree_structure['threshold'])  # pragma: no cover
        except ValueError as e:  # pragma: no cover
            import pprint
            text = pprint.pformat(tree_structure)
            if len(text) > 99999:
                text = text[:99999] + "\n..."
            raise TypeError("threshold must be a number not '{}'"
                            "\n{}".format(tree_structure['threshold'], text)) from e
    else:
        th = tree_structure['threshold']
    if mode == 'BRANCH_LEQ':
        # Adjusts the threshold the same way skl2onnx does to compensate
        # float32 rounding of the inputs.
        th2 = sklearn_threshold(th, numpy.float32, mode)
    else:
        # other decision criteria are not implemented
        th2 = th
    attrs['nodes_values'].append(th2)
    # Assume left is the true branch and right is the false branch
    attrs['nodes_truenodeids'].append(left_id)
    attrs['nodes_falsenodeids'].append(right_id)
    if tree_structure['default_left']:
        # attrs['nodes_missing_value_tracks_true'].append(1)
        if (tree_structure["missing_type"] in ('None', None) and
                float(tree_structure['threshold']) < 0.0):
            attrs['nodes_missing_value_tracks_true'].append(0)
        else:
            attrs['nodes_missing_value_tracks_true'].append(1)
    else:
        attrs['nodes_missing_value_tracks_true'].append(0)
    attrs['nodes_hitrates'].append(1.)
    if left_parse:
        _parse_node(
            tree_id, class_id, left_id, node_id_pool, node_pyid_pool,
            learning_rate, tree_structure['left_child'], attrs)
    if right_parse:
        _parse_node(
            tree_id, class_id, right_id, node_id_pool, node_pyid_pool,
            learning_rate, tree_structure['right_child'], attrs)
def _parse_node(tree_id, class_id, node_id, node_id_pool, node_pyid_pool,
                learning_rate, node, attrs):
    """
    Parses one node (internal split or leaf) of a LightGBM tree and
    appends its description to the flat ``nodes_*`` / ``class_*`` lists
    in *attrs*, recursing into unvisited children.

    :param tree_id: index of the tree inside the ensemble
    :param class_id: class this tree contributes to
    :param node_id: id already allocated for *node*
    :param node_id_pool: set of node ids used in the current tree
    :param node_pyid_pool: maps id(sub-dictionary) -> node id
    :param learning_rate: multiplier applied to leaf values
    :param node: dictionary describing the node
    :param attrs: ONNX TreeEnsemble* attributes (lists), filled in place
    """
    # NOTE(review): *node* is a dict here, so the hasattr() tests are
    # always False; the 'in' tests are the ones that actually decide.
    if ((hasattr(node, 'left_child') and hasattr(node, 'right_child')) or
            ('left_child' in node and 'right_child' in node)):
        left_pyid = id(node['left_child'])
        right_pyid = id(node['right_child'])
        # Allocate ids for both children; a child already registered in
        # node_pyid_pool keeps its id and is not parsed again.
        if left_pyid in node_pyid_pool:
            left_id = node_pyid_pool[left_pyid]
            left_parse = False
        else:
            left_id = _create_node_id(node_id_pool)
            node_pyid_pool[left_pyid] = left_id
            left_parse = True
        if right_pyid in node_pyid_pool:
            right_id = node_pyid_pool[right_pyid]
            right_parse = False
        else:
            right_id = _create_node_id(node_id_pool)
            node_pyid_pool[right_pyid] = right_id
            right_parse = True
        attrs['nodes_treeids'].append(tree_id)
        attrs['nodes_nodeids'].append(node_id)
        attrs['nodes_featureids'].append(node['split_feature'])
        attrs['nodes_modes'].append(
            _translate_split_criterion(node['decision_type']))
        if isinstance(node['threshold'], str):
            # Thresholds must be numeric; accept a string only if it
            # converts cleanly to float.
            try:  # pragma: no cover
                attrs['nodes_values'].append(  # pragma: no cover
                    float(node['threshold']))
            except ValueError as e:  # pragma: no cover
                import pprint
                text = pprint.pformat(node)
                if len(text) > 99999:
                    text = text[:99999] + "\n..."
                raise TypeError("threshold must be a number not '{}'"
                                "\n{}".format(node['threshold'], text)) from e
        else:
            attrs['nodes_values'].append(node['threshold'])
        # Assume left is the true branch and right is the false branch
        attrs['nodes_truenodeids'].append(left_id)
        attrs['nodes_falsenodeids'].append(right_id)
        if node['default_left']:
            # attrs['nodes_missing_value_tracks_true'].append(1)
            if (node['missing_type'] in ('None', None) and
                    float(node['threshold']) < 0.0):
                attrs['nodes_missing_value_tracks_true'].append(0)
            else:
                attrs['nodes_missing_value_tracks_true'].append(1)
        else:
            attrs['nodes_missing_value_tracks_true'].append(0)
        attrs['nodes_hitrates'].append(1.)
        # Recursively dive into the child nodes
        if left_parse:
            _parse_node(
                tree_id, class_id, left_id, node_id_pool, node_pyid_pool,
                learning_rate, node['left_child'], attrs)
        if right_parse:
            _parse_node(
                tree_id, class_id, right_id, node_id_pool, node_pyid_pool,
                learning_rate, node['right_child'], attrs)
    elif hasattr(node, 'left_child') or hasattr(node, 'right_child'):
        raise ValueError('Need two branches')  # pragma: no cover
    else:
        # Leaf node.
        # Node attributes
        attrs['nodes_treeids'].append(tree_id)
        attrs['nodes_nodeids'].append(node_id)
        attrs['nodes_featureids'].append(0)
        attrs['nodes_modes'].append('LEAF')
        # Leaf node has no threshold. A zero is appended but it will never be used.
        attrs['nodes_values'].append(0.)
        # Leaf node has no child. A zero is appended but it will never be used.
        attrs['nodes_truenodeids'].append(0)
        # Leaf node has no child. A zero is appended but it will never be used.
        attrs['nodes_falsenodeids'].append(0)
        # Leaf node has no split function. A zero is appended but it will never be used.
        attrs['nodes_missing_value_tracks_true'].append(0)
        attrs['nodes_hitrates'].append(1.)
        # Leaf attributes
        attrs['class_treeids'].append(tree_id)
        attrs['class_nodeids'].append(node_id)
        attrs['class_ids'].append(class_id)
        attrs['class_weights'].append(
            float(node['leaf_value']) * learning_rate)
257def _split_tree_ensemble_atts(attrs, split):
258 """
259 Splits the attributes of a TreeEnsembleRegressor into
260 multiple trees in order to do the summation in double instead of floats.
261 """
262 trees_id = list(sorted(set(attrs['nodes_treeids'])))
263 results = []
264 index = 0
265 while index < len(trees_id):
266 index2 = min(index + split, len(trees_id))
267 subset = set(trees_id[index: index2])
269 indices_node = []
270 indices_target = []
271 for j, v in enumerate(attrs['nodes_treeids']):
272 if v in subset:
273 indices_node.append(j)
274 for j, v in enumerate(attrs['target_treeids']):
275 if v in subset:
276 indices_target.append(j)
278 if (len(indices_node) >= len(attrs['nodes_treeids']) or
279 len(indices_target) >= len(attrs['target_treeids'])):
280 raise RuntimeError( # pragma: no cover
281 "Initial attributes are not consistant."
282 "\nindex=%r index2=%r subset=%r"
283 "\nnodes_treeids=%r\ntarget_treeids=%r"
284 "\nindices_node=%r\nindices_target=%r" % (
285 index, index2, subset,
286 attrs['nodes_treeids'], attrs['target_treeids'],
287 indices_node, indices_target))
289 ats = {}
290 for name, att in attrs.items():
291 if name == 'nodes_treeids':
292 new_att = [att[i] for i in indices_node]
293 new_att = [i - att[0] for i in new_att]
294 elif name == 'target_treeids':
295 new_att = [att[i] for i in indices_target]
296 new_att = [i - att[0] for i in new_att]
297 elif name.startswith("nodes_"):
298 new_att = [att[i] for i in indices_node]
299 assert len(new_att) == len(indices_node)
300 elif name.startswith("target_"):
301 new_att = [att[i] for i in indices_target]
302 assert len(new_att) == len(indices_target)
303 elif name == 'name':
304 new_att = "%s%d" % (att, len(results))
305 else:
306 new_att = att
307 ats[name] = new_att
309 results.append(ats)
310 index = index2
312 return results
def convert_lightgbm(scope, operator, container):  # pylint: disable=R0914
    """
    This converters reuses the code from
    `LightGbm.py <https://github.com/onnx/onnxmltools/blob/master/onnxmltools/convert/
    lightgbm/operator_converters/LightGbm.py>`_ and makes
    some modifications. It implements converters
    for models in :epkg:`lightgbm`.

    :param scope: scope providing unique variable and operator names
    :param operator: operator wrapping the LightGBM model
        (``operator.raw_operator``)
    :param container: container receiving the created ONNX nodes
        and initializers
    """
    verbose = getattr(container, 'verbose', 0)
    gbm_model = operator.raw_operator
    # Reuse a cached dump of the booster when available, otherwise dump it.
    if hasattr(gbm_model, '_model_dict_info'):
        gbm_text, info = gbm_model._model_dict_info
    else:
        if verbose >= 2:
            print("[convert_lightgbm] dump_model")
        gbm_text, info = dump_lgbm_booster(gbm_model.booster_, verbose=verbose)
    if verbose >= 2:
        print("[convert_lightgbm] modify_tree_for_rule_in_set")
    # Rewrites 'in-set' rules into binary splits supported by ONNX.
    modify_tree_for_rule_in_set(gbm_text, use_float=True, verbose=verbose,
                                info=info)
    attrs = get_default_tree_classifier_attribute_pairs()
    attrs['name'] = operator.full_name
    # Create different attributes for classifier and
    # regressor, respectively
    post_transform = None
    if gbm_text['objective'].startswith('binary'):
        n_classes = 1
        attrs['post_transform'] = 'LOGISTIC'
    elif gbm_text['objective'].startswith('multiclass'):
        n_classes = gbm_text['num_class']
        attrs['post_transform'] = 'SOFTMAX'
    elif gbm_text['objective'].startswith('regression'):
        n_classes = 1  # Regressor has only one output variable
        attrs['post_transform'] = 'NONE'
        attrs['n_targets'] = n_classes
    elif gbm_text['objective'].startswith(('poisson', 'gamma')):
        n_classes = 1  # Regressor has only one output variable
        attrs['n_targets'] = n_classes
        # 'Exp' is not a supported post_transform value in the ONNX spec yet,
        # so we need to add an 'Exp' post transform node to the model
        attrs['post_transform'] = 'NONE'
        post_transform = "Exp"
    else:
        raise RuntimeError(  # pragma: no cover
            "LightGBM objective should be cleaned already not '{}'.".format(
                gbm_text['objective']))
    # Use the same algorithm to parse the tree
    if verbose >= 2:
        from tqdm import tqdm
        loop = tqdm(gbm_text['tree_info'])
        loop.set_description("parse")
    else:
        loop = gbm_text['tree_info']
    for i, tree in enumerate(loop):
        tree_id = i
        class_id = tree_id % n_classes
        # tree['shrinkage'] --> LightGbm provides figures with it already.
        learning_rate = 1.
        _parse_tree_structure(
            tree_id, class_id, learning_rate, tree['tree_structure'], attrs)
    if verbose >= 2:
        print("[convert_lightgbm] onnx")
    # Sort nodes_* attributes. For one tree, its node indexes
    # should appear in an ascent order in nodes_nodeids. Nodes
    # from a tree with a smaller tree index should appear
    # before trees with larger indexes in nodes_nodeids.
    node_numbers_per_tree = Counter(attrs['nodes_treeids'])
    tree_number = len(node_numbers_per_tree.keys())
    accumulated_node_numbers = [0] * tree_number
    for i in range(1, tree_number):
        accumulated_node_numbers[i] = (
            accumulated_node_numbers[i - 1] + node_numbers_per_tree[i - 1])
    # Global position of every node = offset of its tree + its local id.
    global_node_indexes = []
    for i in range(len(attrs['nodes_nodeids'])):
        tree_id = attrs['nodes_treeids'][i]
        node_id = attrs['nodes_nodeids'][i]
        global_node_indexes.append(
            accumulated_node_numbers[tree_id] + node_id)
    # Reorder every nodes_* list consistently by that global position.
    for k, v in attrs.items():
        if k.startswith('nodes_'):
            merged_indexes = zip(
                copy.deepcopy(global_node_indexes), v)
            sorted_list = [pair[1]
                           for pair in sorted(merged_indexes,
                                              key=lambda x: x[0])]
            attrs[k] = sorted_list
    dtype = guess_numpy_type(operator.inputs[0].type)
    if dtype != numpy.float64:
        dtype = numpy.float32
    # Create ONNX object
    if (gbm_text['objective'].startswith('binary') or
            gbm_text['objective'].startswith('multiclass')):
        # Prepare label information for both of TreeEnsembleClassifier
        # and ZipMap
        class_type = onnx_proto.TensorProto.STRING  # pylint: disable=E1101
        zipmap_attrs = {'name': scope.get_unique_variable_name('ZipMap')}
        if all(isinstance(i, (numbers.Real, bool, numpy.bool_))
                for i in gbm_model.classes_):
            class_type = onnx_proto.TensorProto.INT64  # pylint: disable=E1101
            class_labels = [int(i) for i in gbm_model.classes_]
            attrs['classlabels_int64s'] = class_labels
            zipmap_attrs['classlabels_int64s'] = class_labels
        elif all(isinstance(i, str) for i in gbm_model.classes_):
            class_labels = [str(i) for i in gbm_model.classes_]
            attrs['classlabels_strings'] = class_labels
            zipmap_attrs['classlabels_strings'] = class_labels
        else:
            raise ValueError(  # pragma: no cover
                'Only string and integer class labels are allowed')
        # Create tree classifier
        probability_tensor_name = scope.get_unique_variable_name(
            'probability_tensor')
        label_tensor_name = scope.get_unique_variable_name('label_tensor')
        # Double variant lives in the custom 'mlprodict' domain.
        if dtype == numpy.float64:
            container.add_node('TreeEnsembleClassifierDouble', operator.input_full_names,
                               [label_tensor_name, probability_tensor_name],
                               op_domain='mlprodict', **attrs)
        else:
            container.add_node('TreeEnsembleClassifier', operator.input_full_names,
                               [label_tensor_name, probability_tensor_name],
                               op_domain='ai.onnx.ml', **attrs)
        prob_tensor = probability_tensor_name
        if gbm_model.boosting_type == 'rf':
            # Random forest: rescale raw scores by 100 and rebuild the
            # two-column probability tensor and the predicted label.
            col_index_name = scope.get_unique_variable_name('col_index')
            first_col_name = scope.get_unique_variable_name('first_col')
            zeroth_col_name = scope.get_unique_variable_name('zeroth_col')
            denominator_name = scope.get_unique_variable_name('denominator')
            modified_first_col_name = scope.get_unique_variable_name(
                'modified_first_col')
            unit_float_tensor_name = scope.get_unique_variable_name(
                'unit_float_tensor')
            merged_prob_name = scope.get_unique_variable_name('merged_prob')
            predicted_label_name = scope.get_unique_variable_name(
                'predicted_label')
            classes_name = scope.get_unique_variable_name('classes')
            final_label_name = scope.get_unique_variable_name('final_label')
            container.add_initializer(
                col_index_name, onnx_proto.TensorProto.INT64, [], [1])  # pylint: disable=E1101
            container.add_initializer(
                unit_float_tensor_name, onnx_proto.TensorProto.FLOAT, [], [1.0])  # pylint: disable=E1101
            container.add_initializer(
                denominator_name, onnx_proto.TensorProto.FLOAT, [], [100.0])  # pylint: disable=E1101
            container.add_initializer(classes_name, class_type,
                                      [len(class_labels)], class_labels)
            container.add_node(
                'ArrayFeatureExtractor',
                [probability_tensor_name, col_index_name],
                first_col_name,
                name=scope.get_unique_operator_name(
                    'ArrayFeatureExtractor'),
                op_domain='ai.onnx.ml')
            apply_div(scope, [first_col_name, denominator_name],
                      modified_first_col_name, container, broadcast=1)
            # zeroth column = 1 - p(first class)
            apply_sub(
                scope, [unit_float_tensor_name, modified_first_col_name],
                zeroth_col_name, container, broadcast=1)
            container.add_node(
                'Concat', [zeroth_col_name, modified_first_col_name],
                merged_prob_name,
                name=scope.get_unique_operator_name('Concat'), axis=1)
            container.add_node(
                'ArgMax', merged_prob_name,
                predicted_label_name,
                name=scope.get_unique_operator_name('ArgMax'), axis=1)
            container.add_node(
                'ArrayFeatureExtractor', [classes_name, predicted_label_name],
                final_label_name,
                name=scope.get_unique_operator_name('ArrayFeatureExtractor'),
                op_domain='ai.onnx.ml')
            apply_reshape(scope, final_label_name,
                          operator.outputs[0].full_name,
                          container, desired_shape=[-1, ])
            prob_tensor = merged_prob_name
        else:
            container.add_node('Identity', label_tensor_name,
                               operator.outputs[0].full_name,
                               name=scope.get_unique_operator_name('Identity'))
        # Convert probability tensor to probability map
        # (keys are labels while values are the associated probabilities)
        container.add_node('Identity', prob_tensor,
                           operator.outputs[1].full_name)
    else:
        # Create tree regressor
        output_name = scope.get_unique_variable_name('output')
        keys_to_be_renamed = list(
            k for k in attrs if k.startswith('class_'))
        for k in keys_to_be_renamed:
            # Rename class_* attribute to target_*
            # because TreeEnsebmleClassifier
            # and TreeEnsembleClassifier have different ONNX attributes
            attrs['target' + k[5:]] = copy.deepcopy(attrs[k])
            del attrs[k]
        # Option 'split': number of trees per sub-regressor; -1 keeps a
        # single TreeEnsembleRegressor node.
        options = container.get_options(gbm_model, dict(split=-1))
        split = options['split']
        if split == -1:
            if dtype == numpy.float64:
                container.add_node(
                    'TreeEnsembleRegressorDouble', operator.input_full_names,
                    output_name, op_domain='mlprodict', **attrs)
            else:
                container.add_node(
                    'TreeEnsembleRegressor', operator.input_full_names,
                    output_name, op_domain='ai.onnx.ml', **attrs)
        else:
            # Split the ensemble into chunks and sum the partial outputs
            # in double precision to reduce rounding errors.
            tree_attrs = _split_tree_ensemble_atts(attrs, split)
            tree_nodes = []
            for i, ats in enumerate(tree_attrs):
                tree_name = scope.get_unique_variable_name('tree%d' % i)
                if dtype == numpy.float64:
                    container.add_node(
                        'TreeEnsembleRegressorDouble', operator.input_full_names,
                        tree_name, op_domain='mlprodict', **ats)
                    tree_nodes.append(tree_name)
                else:
                    container.add_node(
                        'TreeEnsembleRegressor', operator.input_full_names,
                        tree_name, op_domain='ai.onnx.ml', **ats)
                    # Cast each float sub-tree output to double before Sum.
                    cast_name = scope.get_unique_variable_name('dtree%d' % i)
                    container.add_node(
                        'Cast', tree_name, cast_name, to=TensorProto.DOUBLE,  # pylint: disable=E1101
                        name=scope.get_unique_operator_name("dtree%d" % i))
                    tree_nodes.append(cast_name)
            if dtype == numpy.float64:
                container.add_node(
                    'Sum', tree_nodes, output_name,
                    name=scope.get_unique_operator_name("sumtree%d" % len(tree_nodes)))
            else:
                # Sum in double then cast back to float.
                cast_name = scope.get_unique_variable_name('ftrees')
                container.add_node(
                    'Sum', tree_nodes, cast_name,
                    name=scope.get_unique_operator_name("sumtree%d" % len(tree_nodes)))
                container.add_node(
                    'Cast', cast_name, output_name, to=TensorProto.FLOAT,  # pylint: disable=E1101
                    name=scope.get_unique_operator_name("dtree%d" % i))
        if gbm_model.boosting_type == 'rf':
            # Random forest: rescale the summed output by 100.
            denominator_name = scope.get_unique_variable_name('denominator')
            container.add_initializer(
                denominator_name, onnx_proto.TensorProto.FLOAT,  # pylint: disable=E1101
                [], [100.0])
            apply_div(scope, [output_name, denominator_name],
                      operator.output_full_names, container, broadcast=1)
        elif post_transform:
            # Poisson/gamma objectives: apply the 'Exp' post transform.
            container.add_node(
                post_transform, output_name,
                operator.output_full_names,
                name=scope.get_unique_operator_name(
                    post_transform))
        else:
            container.add_node('Identity', output_name,
                               operator.output_full_names,
                               name=scope.get_unique_operator_name('Identity'))
    if verbose >= 2:
        print("[convert_lightgbm] end")