Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7from collections import OrderedDict 

8import numpy 

9from ._op_helper import _get_typed_class_attribute 

10from ._op import OpRunClassifierProb, RuntimeTypeError 

11from ._op_classifier_string import _ClassifierCommon 

12from ._new_ops import OperatorSchema 

13from .op_tree_ensemble_classifier_ import ( # pylint: disable=E0611,E0401 

14 RuntimeTreeEnsembleClassifierDouble, 

15 RuntimeTreeEnsembleClassifierFloat, 

16) 

17from .op_tree_ensemble_classifier_p_ import ( # pylint: disable=E0611,E0401 

18 RuntimeTreeEnsembleClassifierPFloat, 

19 RuntimeTreeEnsembleClassifierPDouble, 

20) 

21 

22 

class TreeEnsembleClassifierCommon(OpRunClassifierProb, _ClassifierCommon):
    """
    Shared base for the *TreeEnsembleClassifier* runtime operators.

    A subclass declares a class attribute ``atts`` (an ordered mapping of
    attribute name to default value) and picks the floating point type.
    This base builds the matching compiled runtime, initialises it with the
    typed attribute values and delegates every prediction to it.
    """

    def __init__(self, dtype, onnx_node, desc=None,
                 expected_attributes=None,
                 runtime_version=3, **options):
        """
        :param dtype: ``numpy.float32`` or ``numpy.float64``, selects the
            compiled runtime family (anything else raises
            :class:`RuntimeTypeError`)
        :param onnx_node: ONNX node, forwarded to
            :class:`OpRunClassifierProb`
        :param desc: forwarded to :class:`OpRunClassifierProb`
        :param expected_attributes: forwarded to
            :class:`OpRunClassifierProb`
        :param runtime_version: 0 selects the plain runtime class, 1, 2 or 3
            select a configuration of the ``...P...`` runtime class; any
            other value raises :class:`ValueError`
        :param options: forwarded to :class:`OpRunClassifierProb`
        """
        OpRunClassifierProb.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)
        self._init(dtype=dtype, version=runtime_version)

    def _get_typed_attributes(self, k):
        """
        Returns attribute *k* converted according to the default stored
        for it in the subclass's ``atts`` table.
        """
        return _get_typed_class_attribute(self, k, self.__class__.atts)

    def _find_custom_operator_schema(self, op_name):
        """
        Finds a custom operator defined by this runtime.

        :param op_name: operator type name
        :return: a :class:`TreeEnsembleClassifierDoubleSchema` when
            *op_name* is ``'TreeEnsembleClassifierDouble'``
        :raises RuntimeError: for any other operator name
        """
        if op_name != "TreeEnsembleClassifierDouble":
            raise RuntimeError(  # pragma: no cover
                "Unable to find a schema for operator '{}'.".format(op_name))
        return TreeEnsembleClassifierDoubleSchema()

    def _init(self, dtype, version):
        """
        Creates ``self.rt_`` and initialises it with the typed attributes.

        :param dtype: ``numpy.float32`` or ``numpy.float64``
        :param version: 0 for the plain runtime, 1-3 for the ``P`` runtime
        :raises RuntimeTypeError: unsupported *dtype*
        :raises ValueError: unknown *version*
        """
        self._post_process_label_attributes()

        # The dtype is validated before the version, as in any error path
        # callers may rely on.
        if dtype == numpy.float32:
            plain_cls = RuntimeTreeEnsembleClassifierFloat
            p_cls = RuntimeTreeEnsembleClassifierPFloat
        elif dtype == numpy.float64:
            plain_cls = RuntimeTreeEnsembleClassifierDouble
            p_cls = RuntimeTreeEnsembleClassifierPDouble
        else:
            raise RuntimeTypeError(  # pragma: no cover
                "Unsupported dtype={}.".format(dtype))

        if version == 0:
            self.rt_ = plain_cls()
        else:
            # Positional constructor arguments of the ``P`` runtime for each
            # version. Their meaning is defined by the compiled extension
            # and is not visible here.
            p_configs = (
                (1, (60, 20, False, False)),
                (2, (60, 20, True, False)),
                (3, (60, 20, True, True)),
            )
            for candidate, ctor_args in p_configs:
                if version == candidate:
                    self.rt_ = p_cls(*ctor_args)
                    break
            else:
                raise ValueError(  # pragma: no cover
                    "Unknown version '{}'.".format(version))

        # The runtime receives the attribute values positionally, in the
        # order of the subclass's ``atts`` table.
        typed_values = list(
            map(self._get_typed_attributes, self.__class__.atts))
        self.rt_.init(*typed_values)

    def _run(self, x):  # pylint: disable=W0221
        """
        This is a C++ implementation coming from
        :epkg:`onnxruntime`.
        `tree_ensemble_classifier.cc
        <https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc>`_.
        See class :class:`RuntimeTreeEnsembleClassifier
        <mlprodict.onnxrt.ops_cpu.op_tree_ensemble_classifier_.RuntimeTreeEnsembleClassifier>`.

        :param x: input features handed to the compiled runtime
        :return: whatever ``_post_process_predicted_label`` builds from the
            predicted labels and the scores
        """
        label, scores = self.rt_.compute(x)
        n_rows = label.shape[0]
        if scores.shape[0] != n_rows:
            # The scores do not have one row per label: reshape them into
            # (n_rows, scores.shape[0] // n_rows).
            scores = scores.reshape(n_rows, scores.shape[0] // n_rows)
        return self._post_process_predicted_label(label, scores)

97 

98 

class TreeEnsembleClassifier(TreeEnsembleClassifierCommon):
    """
    Runtime operator *TreeEnsembleClassifier* computing in single precision
    (``numpy.float32``).
    """

    # Attribute names with their default values. Insertion order matters:
    # the typed values are handed positionally to the compiled runtime's
    # ``init`` in this order.
    atts = OrderedDict()
    atts['base_values'] = numpy.empty(0, dtype=numpy.float32)
    atts['class_ids'] = numpy.empty(0, dtype=numpy.int64)
    atts['class_nodeids'] = numpy.empty(0, dtype=numpy.int64)
    atts['class_treeids'] = numpy.empty(0, dtype=numpy.int64)
    atts['class_weights'] = numpy.empty(0, dtype=numpy.float32)
    atts['classlabels_int64s'] = numpy.empty(0, dtype=numpy.int64)
    atts['classlabels_strings'] = []
    atts['nodes_falsenodeids'] = numpy.empty(0, dtype=numpy.int64)
    atts['nodes_featureids'] = numpy.empty(0, dtype=numpy.int64)
    atts['nodes_hitrates'] = numpy.empty(0, dtype=numpy.float32)
    atts['nodes_missing_value_tracks_true'] = numpy.empty(
        0, dtype=numpy.int64)
    atts['nodes_modes'] = []
    atts['nodes_nodeids'] = numpy.empty(0, dtype=numpy.int64)
    atts['nodes_treeids'] = numpy.empty(0, dtype=numpy.int64)
    atts['nodes_truenodeids'] = numpy.empty(0, dtype=numpy.int64)
    atts['nodes_values'] = numpy.empty(0, dtype=numpy.float32)
    atts['post_transform'] = b'NONE'

    def __init__(self, onnx_node, desc=None, **options):
        """
        :param onnx_node: ONNX node, forwarded to the base class
        :param desc: forwarded to the base class
        :param options: forwarded to the base class
        """
        TreeEnsembleClassifierCommon.__init__(
            self, dtype=numpy.float32, onnx_node=onnx_node, desc=desc,
            expected_attributes=TreeEnsembleClassifier.atts, **options)

125 

126 

class TreeEnsembleClassifierDouble(TreeEnsembleClassifierCommon):
    """
    Runtime operator *TreeEnsembleClassifierDouble*: same attributes as
    *TreeEnsembleClassifier* but real-valued attributes default to
    ``numpy.float64`` and the double precision runtime is used.
    """

    # Attribute names with their default values. Insertion order matters:
    # the typed values are handed positionally to the compiled runtime's
    # ``init`` in this order.
    atts = OrderedDict([
        ('base_values', numpy.empty(0, dtype=numpy.float64)),
        ('class_ids', numpy.empty(0, dtype=numpy.int64)),
        ('class_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('class_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('class_weights', numpy.empty(0, dtype=numpy.float64)),
        ('classlabels_int64s', numpy.empty(0, dtype=numpy.int64)),
        ('classlabels_strings', []),
        ('nodes_falsenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_featureids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_hitrates', numpy.empty(0, dtype=numpy.float64)),
        ('nodes_missing_value_tracks_true', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_modes', []),
        ('nodes_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_truenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_values', numpy.empty(0, dtype=numpy.float64)),
        ('post_transform', b'NONE')
    ])

    def __init__(self, onnx_node, desc=None, **options):
        """
        :param onnx_node: ONNX node, forwarded to the base class
        :param desc: forwarded to the base class
        :param options: forwarded to the base class
        """
        # Use this class's own float64 attribute table. It previously passed
        # the float32 table of TreeEnsembleClassifier, which disagreed with
        # the table ``_get_typed_attributes`` reads (``self.__class__.atts``).
        TreeEnsembleClassifierCommon.__init__(
            self, numpy.float64, onnx_node, desc=desc,
            expected_attributes=TreeEnsembleClassifierDouble.atts, **options)

153 

154 

class TreeEnsembleClassifierDoubleSchema(OperatorSchema):
    """
    Schema of the operator *TreeEnsembleClassifierDouble*, one of the
    operators added by this package (see
    @see cl TreeEnsembleClassifierDouble).
    """

    def __init__(self):
        op_name = 'TreeEnsembleClassifierDouble'
        OperatorSchema.__init__(self, op_name)
        # Share the attribute table (names and float64 defaults) declared
        # on the runtime operator class itself.
        self.attributes = TreeEnsembleClassifierDouble.atts