Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7from collections import OrderedDict 

8import numpy 

9from ._op_helper import _get_typed_class_attribute 

10from ._op import OpRunUnaryNum, RuntimeTypeError 

11from ._new_ops import OperatorSchema 

12from .op_tree_ensemble_regressor_ import ( # pylint: disable=E0611,E0401 

13 RuntimeTreeEnsembleRegressorFloat, RuntimeTreeEnsembleRegressorDouble) 

14from .op_tree_ensemble_regressor_p_ import ( # pylint: disable=E0611,E0401 

15 RuntimeTreeEnsembleRegressorPFloat, RuntimeTreeEnsembleRegressorPDouble) 

16 

17 

class TreeEnsembleRegressorCommon(OpRunUnaryNum):
    """
    Shared implementation of the tree ensemble regressor runtime operators.

    The class reads a class attribute ``atts`` (an ordered mapping from
    attribute name to default value) through ``self.__class__.atts``;
    it does not define it itself, so concrete subclasses must provide it.
    """

    def __init__(self, dtype, onnx_node, desc=None,
                 expected_attributes=None, runtime_version=3, **options):
        """
        Initializes the base operator, then creates the compiled runtime.

        @param      dtype               ``numpy.float32`` or
                                        ``numpy.float64``, selects the
                                        float or double runtime
        @param      onnx_node           forwarded to
                                        ``OpRunUnaryNum.__init__``
        @param      desc                forwarded to
                                        ``OpRunUnaryNum.__init__``
        @param      expected_attributes forwarded to
                                        ``OpRunUnaryNum.__init__``
        @param      runtime_version     implementation selector
                                        (0, 1, 2 or 3), see ``_init``
        @param      options             extra keyword arguments forwarded
                                        to ``OpRunUnaryNum.__init__``
        """
        OpRunUnaryNum.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)
        self._init(dtype=dtype, version=runtime_version)

    def _get_typed_attributes(self, k):
        """
        Returns attribute *k* of this operator, typed according to the
        defaults declared in ``self.__class__.atts``
        (delegates to ``_get_typed_class_attribute``).
        """
        declared = self.__class__.atts
        return _get_typed_class_attribute(self, k, declared)

    def _find_custom_operator_schema(self, op_name):
        """
        Finds a custom operator defined by this runtime.

        @param      op_name     operator name
        @return                 a new ``TreeEnsembleRegressorDoubleSchema``
                                when *op_name* is
                                ``'TreeEnsembleRegressorDouble'``

        Raises ``RuntimeError`` for any other name.
        """
        if op_name != "TreeEnsembleRegressorDouble":
            raise RuntimeError(  # pragma: no cover
                "Unable to find a schema for operator '{}'.".format(op_name))
        return TreeEnsembleRegressorDoubleSchema()

    def _init(self, dtype, version):
        """
        Creates the runtime ``self.rt_`` and initializes it with the
        typed attributes of the operator.

        @param      dtype       ``numpy.float32`` or ``numpy.float64``
        @param      version     0 uses the ``RuntimeTreeEnsembleRegressor*``
                                class without constructor arguments,
                                1, 2 and 3 use the
                                ``RuntimeTreeEnsembleRegressorP*`` class with
                                different boolean flags

        Raises ``RuntimeTypeError`` for an unsupported *dtype* and
        ``ValueError`` for an unknown *version*. The dtype is validated
        first.
        """
        # Step 1: choose the pair of compiled classes matching the dtype.
        if dtype == numpy.float32:
            plain_cls = RuntimeTreeEnsembleRegressorFloat
            param_cls = RuntimeTreeEnsembleRegressorPFloat
        elif dtype == numpy.float64:
            plain_cls = RuntimeTreeEnsembleRegressorDouble
            param_cls = RuntimeTreeEnsembleRegressorPDouble
        else:
            raise RuntimeTypeError(  # pragma: no cover
                "Unsupported dtype={}.".format(dtype))

        # Step 2: instantiate the requested implementation.
        # The meaning of (60, 20) and of the two booleans is defined by
        # the compiled extension and is not visible from this file.
        if version == 0:
            self.rt_ = plain_cls()
        elif version == 1:
            self.rt_ = param_cls(60, 20, False, False)
        elif version == 2:
            self.rt_ = param_cls(60, 20, True, False)
        elif version == 3:
            self.rt_ = param_cls(60, 20, True, True)
        else:
            raise ValueError("Unknown version '{}'.".format(version))

        # Step 3: attributes are passed positionally, in the iteration
        # order of the class-level ``atts`` mapping.
        typed_values = [self._get_typed_attributes(name)
                        for name in self.__class__.atts]
        self.rt_.init(*typed_values)

    # NOTE(review): the docstring below links the *classifier* source file
    # of onnxruntime although this operator is a regressor - confirm
    # whether the regressor file was intended.
    def _run(self, x):  # pylint: disable=W0221
        """
        This is a C++ implementation coming from
        :epkg:`onnxruntime`.
        `tree_ensemble_classifier.cc
        <https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc>`_.
        See class :class:`RuntimeTreeEnsembleRegressorFloat
        <mlprodict.onnxrt.ops_cpu.op_tree_ensemble_regressor_.RuntimeTreeEnsembleRegressorFloat>` or
        class :class:`RuntimeTreeEnsembleRegressorDouble
        <mlprodict.onnxrt.ops_cpu.op_tree_ensemble_regressor_.RuntimeTreeEnsembleRegressorDouble>`.

        @param      x       input; anything exposing ``todense`` is
                            converted with it before the computation
        @return             a tuple with one element, the predictions
        """
        if hasattr(x, 'todense'):
            x = x.todense()
        pred = self.rt_.compute(x)
        n_rows = x.shape[0]
        if pred.shape[0] != n_rows:
            # The runtime returned a first dimension different from the
            # number of input rows: fold it into (rows, values per row).
            pred = pred.reshape(n_rows, pred.shape[0] // n_rows)
        return (pred, )

92 

93 

class TreeEnsembleRegressor(TreeEnsembleRegressorCommon):
    """
    Tree ensemble regressor operator bound to the ``numpy.float32``
    runtime: floating point attributes default to empty float32 arrays.
    """

    # Default value of every attribute. The iteration order matters:
    # ``TreeEnsembleRegressorCommon._init`` walks this mapping and passes
    # the typed values positionally to the runtime, so entries must not
    # be reordered.
    atts = OrderedDict((
        ('aggregate_function', b'SUM'),
        ('base_values', numpy.empty((0,), dtype=numpy.float32)),
        ('n_targets', 1),
        ('nodes_falsenodeids', numpy.empty((0,), dtype=numpy.int64)),
        ('nodes_featureids', numpy.empty((0,), dtype=numpy.int64)),
        ('nodes_hitrates', numpy.empty((0,), dtype=numpy.float32)),
        ('nodes_missing_value_tracks_true',
         numpy.empty((0,), dtype=numpy.int64)),
        ('nodes_modes', []),
        ('nodes_nodeids', numpy.empty((0,), dtype=numpy.int64)),
        ('nodes_treeids', numpy.empty((0,), dtype=numpy.int64)),
        ('nodes_truenodeids', numpy.empty((0,), dtype=numpy.int64)),
        ('nodes_values', numpy.empty((0,), dtype=numpy.float32)),
        ('post_transform', b'NONE'),
        ('target_ids', numpy.empty((0,), dtype=numpy.int64)),
        ('target_nodeids', numpy.empty((0,), dtype=numpy.int64)),
        ('target_treeids', numpy.empty((0,), dtype=numpy.int64)),
        ('target_weights', numpy.empty((0,), dtype=numpy.float32)),
    ))

    def __init__(self, onnx_node, desc=None, runtime_version=1, **options):
        """
        @param      onnx_node           forwarded to
                                        ``TreeEnsembleRegressorCommon``
        @param      desc                forwarded to
                                        ``TreeEnsembleRegressorCommon``
        @param      runtime_version     implementation selector, forwarded
                                        to ``TreeEnsembleRegressorCommon``
                                        (default is 1 here)
        @param      options             extra keyword arguments forwarded
                                        to ``TreeEnsembleRegressorCommon``
        """
        TreeEnsembleRegressorCommon.__init__(
            self, dtype=numpy.float32, onnx_node=onnx_node, desc=desc,
            expected_attributes=TreeEnsembleRegressor.atts,
            runtime_version=runtime_version, **options)

121 

122 

class TreeEnsembleRegressorDouble(TreeEnsembleRegressorCommon):
    """
    Tree ensemble regressor operator bound to the ``numpy.float64``
    runtime: floating point attributes default to empty float64 arrays.
    """

    # Default value of every attribute. The iteration order matters:
    # ``TreeEnsembleRegressorCommon._init`` walks this mapping and passes
    # the typed values positionally to the runtime, so entries must not
    # be reordered.
    atts = OrderedDict((
        ('aggregate_function', b'SUM'),
        ('base_values', numpy.empty((0,), dtype=numpy.float64)),
        ('n_targets', 1),
        ('nodes_falsenodeids', numpy.empty((0,), dtype=numpy.int64)),
        ('nodes_featureids', numpy.empty((0,), dtype=numpy.int64)),
        ('nodes_hitrates', numpy.empty((0,), dtype=numpy.float64)),
        ('nodes_missing_value_tracks_true',
         numpy.empty((0,), dtype=numpy.int64)),
        ('nodes_modes', []),
        ('nodes_nodeids', numpy.empty((0,), dtype=numpy.int64)),
        ('nodes_treeids', numpy.empty((0,), dtype=numpy.int64)),
        ('nodes_truenodeids', numpy.empty((0,), dtype=numpy.int64)),
        ('nodes_values', numpy.empty((0,), dtype=numpy.float64)),
        ('post_transform', b'NONE'),
        ('target_ids', numpy.empty((0,), dtype=numpy.int64)),
        ('target_nodeids', numpy.empty((0,), dtype=numpy.int64)),
        ('target_treeids', numpy.empty((0,), dtype=numpy.int64)),
        ('target_weights', numpy.empty((0,), dtype=numpy.float64)),
    ))

    def __init__(self, onnx_node, desc=None, runtime_version=1, **options):
        """
        @param      onnx_node           forwarded to
                                        ``TreeEnsembleRegressorCommon``
        @param      desc                forwarded to
                                        ``TreeEnsembleRegressorCommon``
        @param      runtime_version     implementation selector, forwarded
                                        to ``TreeEnsembleRegressorCommon``
                                        (default is 1 here)
        @param      options             extra keyword arguments forwarded
                                        to ``TreeEnsembleRegressorCommon``
        """
        TreeEnsembleRegressorCommon.__init__(
            self, dtype=numpy.float64, onnx_node=onnx_node, desc=desc,
            expected_attributes=TreeEnsembleRegressorDouble.atts,
            runtime_version=runtime_version, **options)

150 

151 

class TreeEnsembleRegressorDoubleSchema(OperatorSchema):
    """
    Defines a schema for operators added in this package
    such as @see cl TreeEnsembleRegressorDouble.
    """

    def __init__(self):
        """
        Registers the schema under the name
        ``'TreeEnsembleRegressorDouble'`` and exposes the operator's
        default attributes through ``self.attributes``.
        """
        schema_name = 'TreeEnsembleRegressorDouble'
        OperatorSchema.__init__(self, schema_name)
        # Same OrderedDict object as the class attribute, not a copy.
        self.attributes = TreeEnsembleRegressorDouble.atts
160 self.attributes = TreeEnsembleRegressorDouble.atts