Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Modified converter from 

4`LightGbm.py <https://github.com/onnx/onnxmltools/blob/master/onnxmltools/convert/ 

5lightgbm/operator_converters/LightGbm.py>`_. 

6""" 

7from collections import Counter 

8import copy 

9import numbers 

10import numpy 

11from onnx import TensorProto 

12from skl2onnx.common._apply_operation import apply_div, apply_reshape, apply_sub # pylint: disable=E0611 

13from skl2onnx.common.tree_ensemble import get_default_tree_classifier_attribute_pairs 

14from skl2onnx.proto import onnx_proto 

15from skl2onnx.common.shape_calculator import ( 

16 calculate_linear_regressor_output_shapes, 

17 calculate_linear_classifier_output_shapes) 

18from skl2onnx.common.data_types import guess_numpy_type 

19from skl2onnx.common.tree_ensemble import sklearn_threshold 

20from ..helpers.lgbm_helper import ( 

21 dump_lgbm_booster, modify_tree_for_rule_in_set) 

22 

23 

def calculate_lightgbm_output_shapes(operator):
    """
    Shape calculator for LightGBM Booster
    (see :epkg:`lightgbm`).

    The objective is read from ``_model_dict['objective']`` when the
    wrapped model has an attribute ``_model_dict``, from ``objective_``
    otherwise. Objectives starting with ``binary`` or ``multiclass`` are
    handled by the linear classifier shape calculator, objectives
    starting with ``regression`` by the linear regressor one.

    @param      operator    operator whose ``raw_operator`` is the model
    @return                 whatever the selected shape calculator returns

    A ``RuntimeError`` is raised if the objective cannot be found, a
    ``NotImplementedError`` if the objective is not one of the above.
    """
    model = operator.raw_operator
    if hasattr(model, "_model_dict"):
        objective = model._model_dict['objective']
    elif hasattr(model, 'objective_'):
        objective = model.objective_
    else:
        template = (
            "Unable to find attributes '_model_dict' or 'objective_' in "
            "instance of type %r (list of attributes=%r).")
        raise RuntimeError(  # pragma: no cover
            template % (type(model), dir(model)))

    if objective.startswith(('binary', 'multiclass')):
        return calculate_linear_classifier_output_shapes(operator)
    if objective.startswith('regression'):  # pragma: no cover
        return calculate_linear_regressor_output_shapes(operator)
    raise NotImplementedError(  # pragma: no cover
        "Objective '{}' is not implemented yet.".format(objective))

45 

46 

47def _translate_split_criterion(criterion): 

48 # If the criterion is true, LightGBM use the left child. Otherwise, right child is selected. 

49 if criterion == '<=': 

50 return 'BRANCH_LEQ' 

51 if criterion == '<': # pragma: no cover 

52 return 'BRANCH_LT' 

53 if criterion == '>=': # pragma: no cover 

54 return 'BRANCH_GTE' 

55 if criterion == '>': # pragma: no cover 

56 return 'BRANCH_GT' 

57 if criterion == '==': # pragma: no cover 

58 return 'BRANCH_EQ' 

59 if criterion == '!=': # pragma: no cover 

60 return 'BRANCH_NEQ' 

61 raise ValueError( # pragma: no cover 

62 'Unsupported splitting criterion: %s. Only <=, ' 

63 '<, >=, and > are allowed.') 

64 

65 

66def _create_node_id(node_id_pool): 

67 i = 0 

68 while i in node_id_pool: 

69 i += 1 

70 node_id_pool.add(i) 

71 return i 

72 

73 

def _parse_tree_structure(tree_id, class_id, learning_rate,
                          tree_structure, attrs):
    """
    Parses one tree starting from its root and appends the description
    of its nodes to the lists stored in *attrs*.

    @param      tree_id         index of the tree
    @param      class_id        class index forwarded to the leaves
    @param      learning_rate   multiplier forwarded to the leaves
    @param      tree_structure  dictionary describing the root node
    @param      attrs           dictionary of lists, modified in place

    The pool of all nodes' indexes is created when parsing a single
    tree. Different trees use different pools.
    """
    # Indices already attributed in this tree and the mapping
    # id(node dictionary) -> node index.
    id_pool = set()
    pyid_to_id = dict()

    root_id = _create_node_id(id_pool)
    pyid_to_id[id(tree_structure)] = root_id

    has_two_children = ('left_child' in tree_structure and
                        'right_child' in tree_structure)
    if not has_two_children:
        # The root node is a leaf node.
        _parse_node(tree_id, class_id, root_id, id_pool, pyid_to_id,
                    learning_rate, tree_structure, attrs)
        return

    # A child gets a new index unless the very same dictionary was
    # already registered, in which case it is not parsed again.
    children = []
    for side in ('left_child', 'right_child'):
        child_pyid = id(tree_structure[side])
        if child_pyid in pyid_to_id:
            children.append((pyid_to_id[child_pyid], False))
        else:
            new_id = _create_node_id(id_pool)
            pyid_to_id[child_pyid] = new_id
            children.append((new_id, True))
    (left_id, parse_left), (right_id, parse_right) = children

    attrs['nodes_treeids'].append(tree_id)
    attrs['nodes_nodeids'].append(root_id)

    attrs['nodes_featureids'].append(tree_structure['split_feature'])
    mode = _translate_split_criterion(tree_structure['decision_type'])
    attrs['nodes_modes'].append(mode)

    raw_threshold = tree_structure['threshold']
    if isinstance(raw_threshold, str):
        try:  # pragma: no cover
            threshold = float(raw_threshold)  # pragma: no cover
        except ValueError as e:  # pragma: no cover
            import pprint
            text = pprint.pformat(tree_structure)
            if len(text) > 99999:
                text = text[:99999] + "\n..."
            raise TypeError("threshold must be a number not '{}'"
                            "\n{}".format(raw_threshold, text)) from e
    else:
        threshold = raw_threshold
    if mode == 'BRANCH_LEQ':
        # other decision criteria are not implemented,
        # their threshold is stored unchanged
        threshold = sklearn_threshold(threshold, numpy.float32, mode)
    attrs['nodes_values'].append(threshold)

    # Assume left is the true branch and right is the false branch
    attrs['nodes_truenodeids'].append(left_id)
    attrs['nodes_falsenodeids'].append(right_id)

    # Missing values follow the true branch only if the node says so,
    # except when the node declares no missing type and its threshold
    # is negative.
    tracks_true = 0
    if tree_structure['default_left']:
        no_missing_type = tree_structure["missing_type"] in ('None', None)
        if not (no_missing_type and float(raw_threshold) < 0.0):
            tracks_true = 1
    attrs['nodes_missing_value_tracks_true'].append(tracks_true)
    attrs['nodes_hitrates'].append(1.)

    if parse_left:
        _parse_node(
            tree_id, class_id, left_id, id_pool, pyid_to_id,
            learning_rate, tree_structure['left_child'], attrs)
    if parse_right:
        _parse_node(
            tree_id, class_id, right_id, id_pool, pyid_to_id,
            learning_rate, tree_structure['right_child'], attrs)

159 

160 

161def _parse_node(tree_id, class_id, node_id, node_id_pool, node_pyid_pool, 

162 learning_rate, node, attrs): 

163 """ 

164 Parses nodes. 

165 """ 

166 if ((hasattr(node, 'left_child') and hasattr(node, 'right_child')) or 

167 ('left_child' in node and 'right_child' in node)): 

168 

169 left_pyid = id(node['left_child']) 

170 right_pyid = id(node['right_child']) 

171 

172 if left_pyid in node_pyid_pool: 

173 left_id = node_pyid_pool[left_pyid] 

174 left_parse = False 

175 else: 

176 left_id = _create_node_id(node_id_pool) 

177 node_pyid_pool[left_pyid] = left_id 

178 left_parse = True 

179 

180 if right_pyid in node_pyid_pool: 

181 right_id = node_pyid_pool[right_pyid] 

182 right_parse = False 

183 else: 

184 right_id = _create_node_id(node_id_pool) 

185 node_pyid_pool[right_pyid] = right_id 

186 right_parse = True 

187 

188 attrs['nodes_treeids'].append(tree_id) 

189 attrs['nodes_nodeids'].append(node_id) 

190 

191 attrs['nodes_featureids'].append(node['split_feature']) 

192 attrs['nodes_modes'].append( 

193 _translate_split_criterion(node['decision_type'])) 

194 if isinstance(node['threshold'], str): 

195 try: # pragma: no cover 

196 attrs['nodes_values'].append( # pragma: no cover 

197 float(node['threshold'])) 

198 except ValueError as e: # pragma: no cover 

199 import pprint 

200 text = pprint.pformat(node) 

201 if len(text) > 99999: 

202 text = text[:99999] + "\n..." 

203 raise TypeError("threshold must be a number not '{}'" 

204 "\n{}".format(node['threshold'], text)) from e 

205 else: 

206 attrs['nodes_values'].append(node['threshold']) 

207 

208 # Assume left is the true branch and right is the false branch 

209 attrs['nodes_truenodeids'].append(left_id) 

210 attrs['nodes_falsenodeids'].append(right_id) 

211 if node['default_left']: 

212 # attrs['nodes_missing_value_tracks_true'].append(1) 

213 if (node['missing_type'] in ('None', None) and 

214 float(node['threshold']) < 0.0): 

215 attrs['nodes_missing_value_tracks_true'].append(0) 

216 else: 

217 attrs['nodes_missing_value_tracks_true'].append(1) 

218 else: 

219 attrs['nodes_missing_value_tracks_true'].append(0) 

220 attrs['nodes_hitrates'].append(1.) 

221 

222 # Recursively dive into the child nodes 

223 if left_parse: 

224 _parse_node( 

225 tree_id, class_id, left_id, node_id_pool, node_pyid_pool, 

226 learning_rate, node['left_child'], attrs) 

227 if right_parse: 

228 _parse_node( 

229 tree_id, class_id, right_id, node_id_pool, node_pyid_pool, 

230 learning_rate, node['right_child'], attrs) 

231 elif hasattr(node, 'left_child') or hasattr(node, 'right_child'): 

232 raise ValueError('Need two branches') # pragma: no cover 

233 else: 

234 # Node attributes 

235 attrs['nodes_treeids'].append(tree_id) 

236 attrs['nodes_nodeids'].append(node_id) 

237 attrs['nodes_featureids'].append(0) 

238 attrs['nodes_modes'].append('LEAF') 

239 # Leaf node has no threshold. A zero is appended but it will never be used. 

240 attrs['nodes_values'].append(0.) 

241 # Leaf node has no child. A zero is appended but it will never be used. 

242 attrs['nodes_truenodeids'].append(0) 

243 # Leaf node has no child. A zero is appended but it will never be used. 

244 attrs['nodes_falsenodeids'].append(0) 

245 # Leaf node has no split function. A zero is appended but it will never be used. 

246 attrs['nodes_missing_value_tracks_true'].append(0) 

247 attrs['nodes_hitrates'].append(1.) 

248 

249 # Leaf attributes 

250 attrs['class_treeids'].append(tree_id) 

251 attrs['class_nodeids'].append(node_id) 

252 attrs['class_ids'].append(class_id) 

253 attrs['class_weights'].append( 

254 float(node['leaf_value']) * learning_rate) 

255 

256 

257def _split_tree_ensemble_atts(attrs, split): 

258 """ 

259 Splits the attributes of a TreeEnsembleRegressor into 

260 multiple trees in order to do the summation in double instead of floats. 

261 """ 

262 trees_id = list(sorted(set(attrs['nodes_treeids']))) 

263 results = [] 

264 index = 0 

265 while index < len(trees_id): 

266 index2 = min(index + split, len(trees_id)) 

267 subset = set(trees_id[index: index2]) 

268 

269 indices_node = [] 

270 indices_target = [] 

271 for j, v in enumerate(attrs['nodes_treeids']): 

272 if v in subset: 

273 indices_node.append(j) 

274 for j, v in enumerate(attrs['target_treeids']): 

275 if v in subset: 

276 indices_target.append(j) 

277 

278 if (len(indices_node) >= len(attrs['nodes_treeids']) or 

279 len(indices_target) >= len(attrs['target_treeids'])): 

280 raise RuntimeError( # pragma: no cover 

281 "Initial attributes are not consistant." 

282 "\nindex=%r index2=%r subset=%r" 

283 "\nnodes_treeids=%r\ntarget_treeids=%r" 

284 "\nindices_node=%r\nindices_target=%r" % ( 

285 index, index2, subset, 

286 attrs['nodes_treeids'], attrs['target_treeids'], 

287 indices_node, indices_target)) 

288 

289 ats = {} 

290 for name, att in attrs.items(): 

291 if name == 'nodes_treeids': 

292 new_att = [att[i] for i in indices_node] 

293 new_att = [i - att[0] for i in new_att] 

294 elif name == 'target_treeids': 

295 new_att = [att[i] for i in indices_target] 

296 new_att = [i - att[0] for i in new_att] 

297 elif name.startswith("nodes_"): 

298 new_att = [att[i] for i in indices_node] 

299 assert len(new_att) == len(indices_node) 

300 elif name.startswith("target_"): 

301 new_att = [att[i] for i in indices_target] 

302 assert len(new_att) == len(indices_target) 

303 elif name == 'name': 

304 new_att = "%s%d" % (att, len(results)) 

305 else: 

306 new_att = att 

307 ats[name] = new_att 

308 

309 results.append(ats) 

310 index = index2 

311 

312 return results 

313 

314 

def convert_lightgbm(scope, operator, container):  # pylint: disable=R0914
    """
    This converters reuses the code from
    `LightGbm.py <https://github.com/onnx/onnxmltools/blob/master/onnxmltools/convert/
    lightgbm/operator_converters/LightGbm.py>`_ and makes
    some modifications. It implements converters
    for models in :epkg:`lightgbm`.

    @param      scope       used to create unique variable and operator
                            names (``get_unique_variable_name``,
                            ``get_unique_operator_name``)
    @param      operator    operator to convert, ``raw_operator`` is the
                            model, the function also reads ``full_name``,
                            ``inputs``, ``outputs``, ``input_full_names``,
                            ``output_full_names``
    @param      container   receives the nodes and the initializers
                            (``add_node``, ``add_initializer``), gives
                            the options (``get_options``), its optional
                            attribute ``verbose`` enables the logging
                            (``verbose >= 2``)

    The function goes through the following steps:

    * the model is dumped into a dictionary (or taken from
      ``_model_dict_info`` if the model has it) and modified by
      ``modify_tree_for_rule_in_set``,
    * every tree is parsed into the attributes of a tree ensemble,
    * the ``nodes_*`` attributes are sorted by tree index, then node index,
    * a TreeEnsembleClassifier is added for the objectives ``binary``
      and ``multiclass``, a TreeEnsembleRegressor for ``regression``,
      ``poisson``, ``gamma``. With option ``split`` different from -1,
      the regressor is split into several ensembles whose outputs
      are summed in double precision.

    A ``RuntimeError`` is raised for any other objective, a ``ValueError``
    if the class labels are neither all numbers nor all strings.
    """
    verbose = getattr(container, 'verbose', 0)
    gbm_model = operator.raw_operator
    if hasattr(gbm_model, '_model_dict_info'):
        gbm_text, info = gbm_model._model_dict_info
    else:
        if verbose >= 2:
            print("[convert_lightgbm] dump_model")
        gbm_text, info = dump_lgbm_booster(gbm_model.booster_, verbose=verbose)
    if verbose >= 2:
        print("[convert_lightgbm] modify_tree_for_rule_in_set")
    modify_tree_for_rule_in_set(gbm_text, use_float=True, verbose=verbose,
                                info=info)

    attrs = get_default_tree_classifier_attribute_pairs()
    attrs['name'] = operator.full_name

    # Create different attributes for classifier and
    # regressor, respectively
    objective = gbm_text['objective']
    post_transform = None
    if objective.startswith('binary'):
        n_classes = 1
        attrs['post_transform'] = 'LOGISTIC'
    elif objective.startswith('multiclass'):
        n_classes = gbm_text['num_class']
        attrs['post_transform'] = 'SOFTMAX'
    elif objective.startswith('regression'):
        n_classes = 1  # Regressor has only one output variable
        attrs['post_transform'] = 'NONE'
        attrs['n_targets'] = n_classes
    elif objective.startswith(('poisson', 'gamma')):
        n_classes = 1  # Regressor has only one output variable
        attrs['n_targets'] = n_classes
        # 'Exp' is not a supported post_transform value in the ONNX spec yet,
        # so we need to add an 'Exp' post transform node to the model
        attrs['post_transform'] = 'NONE'
        post_transform = "Exp"
    else:
        raise RuntimeError(  # pragma: no cover
            "LightGBM objective should be cleaned already not '{}'.".format(
                objective))

    # Use the same algorithm to parse the tree
    if verbose >= 2:
        from tqdm import tqdm
        loop = tqdm(gbm_text['tree_info'])
        loop.set_description("parse")
    else:
        loop = gbm_text['tree_info']
    for tree_id, tree in enumerate(loop):
        class_id = tree_id % n_classes
        # tree['shrinkage'] --> LightGbm provides figures with it already.
        learning_rate = 1.
        _parse_tree_structure(
            tree_id, class_id, learning_rate, tree['tree_structure'], attrs)

    if verbose >= 2:
        print("[convert_lightgbm] onnx")
    # Sort nodes_* attributes. For one tree, its node indexes
    # should appear in an ascent order in nodes_nodeids. Nodes
    # from a tree with a smaller tree index should appear
    # before trees with larger indexes in nodes_nodeids.
    node_numbers_per_tree = Counter(attrs['nodes_treeids'])
    tree_number = len(node_numbers_per_tree)
    # accumulated_node_numbers[t] = number of nodes in trees 0..t-1
    accumulated_node_numbers = [0] * tree_number
    for t in range(1, tree_number):
        accumulated_node_numbers[t] = (
            accumulated_node_numbers[t - 1] + node_numbers_per_tree[t - 1])
    global_node_indexes = [
        accumulated_node_numbers[tid] + nid
        for tid, nid in zip(attrs['nodes_treeids'], attrs['nodes_nodeids'])]
    # zip and sorted only read the index list, it does not need to be
    # copied for every attribute.
    for key in [k for k in attrs if k.startswith('nodes_')]:
        ordered = sorted(zip(global_node_indexes, attrs[key]),
                         key=lambda pair: pair[0])
        attrs[key] = [value for _, value in ordered]

    dtype = guess_numpy_type(operator.inputs[0].type)
    if dtype != numpy.float64:
        dtype = numpy.float32
    use_double = dtype == numpy.float64

    # Create ONNX object
    if objective.startswith(('binary', 'multiclass')):
        # Prepare label information for both of TreeEnsembleClassifier
        # and ZipMap
        # NOTE(review): zipmap_attrs is filled but no node created below
        # consumes it; it is kept so that the scope reserves the same names.
        class_type = onnx_proto.TensorProto.STRING  # pylint: disable=E1101
        zipmap_attrs = {'name': scope.get_unique_variable_name('ZipMap')}
        if all(isinstance(i, (numbers.Real, bool, numpy.bool_))
               for i in gbm_model.classes_):
            class_type = onnx_proto.TensorProto.INT64  # pylint: disable=E1101
            class_labels = [int(i) for i in gbm_model.classes_]
            attrs['classlabels_int64s'] = class_labels
            zipmap_attrs['classlabels_int64s'] = class_labels
        elif all(isinstance(i, str) for i in gbm_model.classes_):
            class_labels = [str(i) for i in gbm_model.classes_]
            attrs['classlabels_strings'] = class_labels
            zipmap_attrs['classlabels_strings'] = class_labels
        else:
            raise ValueError(  # pragma: no cover
                'Only string and integer class labels are allowed')

        # Create tree classifier
        probability_tensor_name = scope.get_unique_variable_name(
            'probability_tensor')
        label_tensor_name = scope.get_unique_variable_name('label_tensor')

        if use_double:
            container.add_node(
                'TreeEnsembleClassifierDouble', operator.input_full_names,
                [label_tensor_name, probability_tensor_name],
                op_domain='mlprodict', **attrs)
        else:
            container.add_node(
                'TreeEnsembleClassifier', operator.input_full_names,
                [label_tensor_name, probability_tensor_name],
                op_domain='ai.onnx.ml', **attrs)

        prob_tensor = probability_tensor_name

        if gbm_model.boosting_type == 'rf':
            # The second column of the probabilities is divided by 100,
            # the first one is recomputed as 1 minus the result and the
            # label is the argmax of both columns.
            col_index_name = scope.get_unique_variable_name('col_index')
            first_col_name = scope.get_unique_variable_name('first_col')
            zeroth_col_name = scope.get_unique_variable_name('zeroth_col')
            denominator_name = scope.get_unique_variable_name('denominator')
            modified_first_col_name = scope.get_unique_variable_name(
                'modified_first_col')
            unit_float_tensor_name = scope.get_unique_variable_name(
                'unit_float_tensor')
            merged_prob_name = scope.get_unique_variable_name('merged_prob')
            predicted_label_name = scope.get_unique_variable_name(
                'predicted_label')
            classes_name = scope.get_unique_variable_name('classes')
            final_label_name = scope.get_unique_variable_name('final_label')

            container.add_initializer(
                col_index_name, onnx_proto.TensorProto.INT64, [], [1])  # pylint: disable=E1101
            container.add_initializer(
                unit_float_tensor_name, onnx_proto.TensorProto.FLOAT, [], [1.0])  # pylint: disable=E1101
            container.add_initializer(
                denominator_name, onnx_proto.TensorProto.FLOAT, [], [100.0])  # pylint: disable=E1101
            container.add_initializer(classes_name, class_type,
                                      [len(class_labels)], class_labels)

            container.add_node(
                'ArrayFeatureExtractor',
                [probability_tensor_name, col_index_name],
                first_col_name,
                name=scope.get_unique_operator_name(
                    'ArrayFeatureExtractor'),
                op_domain='ai.onnx.ml')
            apply_div(scope, [first_col_name, denominator_name],
                      modified_first_col_name, container, broadcast=1)
            apply_sub(
                scope, [unit_float_tensor_name, modified_first_col_name],
                zeroth_col_name, container, broadcast=1)
            container.add_node(
                'Concat', [zeroth_col_name, modified_first_col_name],
                merged_prob_name,
                name=scope.get_unique_operator_name('Concat'), axis=1)
            container.add_node(
                'ArgMax', merged_prob_name,
                predicted_label_name,
                name=scope.get_unique_operator_name('ArgMax'), axis=1)
            container.add_node(
                'ArrayFeatureExtractor', [classes_name, predicted_label_name],
                final_label_name,
                name=scope.get_unique_operator_name('ArrayFeatureExtractor'),
                op_domain='ai.onnx.ml')
            apply_reshape(scope, final_label_name,
                          operator.outputs[0].full_name,
                          container, desired_shape=[-1, ])
            prob_tensor = merged_prob_name
        else:
            container.add_node('Identity', label_tensor_name,
                               operator.outputs[0].full_name,
                               name=scope.get_unique_operator_name('Identity'))

        # The probability tensor is copied into the second output.
        container.add_node('Identity', prob_tensor,
                           operator.outputs[1].full_name)
    else:
        # Create tree regressor
        output_name = scope.get_unique_variable_name('output')

        # Rename class_* attribute to target_*
        # because TreeEnsembleClassifier
        # and TreeEnsembleRegressor have different ONNX attributes
        for key in [k for k in attrs if k.startswith('class_')]:
            attrs['target' + key[5:]] = attrs.pop(key)

        if use_double:
            regressor_type, regressor_domain = (
                'TreeEnsembleRegressorDouble', 'mlprodict')
        else:
            regressor_type, regressor_domain = (
                'TreeEnsembleRegressor', 'ai.onnx.ml')

        options = container.get_options(gbm_model, dict(split=-1))
        split = options['split']
        if split == -1:
            container.add_node(
                regressor_type, operator.input_full_names,
                output_name, op_domain=regressor_domain, **attrs)
        else:
            # One tree ensemble per group of trees, the outputs of the
            # groups are summed as doubles.
            tree_attrs = _split_tree_ensemble_atts(attrs, split)
            tree_nodes = []
            for chunk_index, ats in enumerate(tree_attrs):
                tree_name = scope.get_unique_variable_name(
                    'tree%d' % chunk_index)
                container.add_node(
                    regressor_type, operator.input_full_names,
                    tree_name, op_domain=regressor_domain, **ats)
                if use_double:
                    tree_nodes.append(tree_name)
                else:
                    cast_name = scope.get_unique_variable_name(
                        'dtree%d' % chunk_index)
                    container.add_node(
                        'Cast', tree_name, cast_name, to=TensorProto.DOUBLE,  # pylint: disable=E1101
                        name=scope.get_unique_operator_name(
                            "dtree%d" % chunk_index))
                    tree_nodes.append(cast_name)
            if use_double:
                container.add_node(
                    'Sum', tree_nodes, output_name,
                    name=scope.get_unique_operator_name(
                        "sumtree%d" % len(tree_nodes)))
            else:
                cast_name = scope.get_unique_variable_name('ftrees')
                container.add_node(
                    'Sum', tree_nodes, cast_name,
                    name=scope.get_unique_operator_name(
                        "sumtree%d" % len(tree_nodes)))
                # The operator name is built from the index of the last
                # group so that the generated names do not change and do
                # not depend on a variable leaked by the loop above.
                last_chunk = len(tree_attrs) - 1
                container.add_node(
                    'Cast', cast_name, output_name, to=TensorProto.FLOAT,  # pylint: disable=E1101
                    name=scope.get_unique_operator_name(
                        "dtree%d" % last_chunk))

        if gbm_model.boosting_type == 'rf':
            denominator_name = scope.get_unique_variable_name('denominator')

            container.add_initializer(
                denominator_name, onnx_proto.TensorProto.FLOAT,  # pylint: disable=E1101
                [], [100.0])

            apply_div(scope, [output_name, denominator_name],
                      operator.output_full_names, container, broadcast=1)
        elif post_transform:
            container.add_node(
                post_transform, output_name,
                operator.output_full_names,
                name=scope.get_unique_operator_name(
                    post_transform))
        else:
            container.add_node('Identity', output_name,
                               operator.output_full_names,
                               name=scope.get_unique_operator_name('Identity'))

    if verbose >= 2:
        print("[convert_lightgbm] end")