Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from collections import OrderedDict
8import numpy
9from ._op_helper import _get_typed_class_attribute
10from ._op import OpRunUnaryNum, RuntimeTypeError
11from ._new_ops import OperatorSchema
12from .op_tree_ensemble_regressor_ import ( # pylint: disable=E0611,E0401
13 RuntimeTreeEnsembleRegressorFloat, RuntimeTreeEnsembleRegressorDouble)
14from .op_tree_ensemble_regressor_p_ import ( # pylint: disable=E0611,E0401
15 RuntimeTreeEnsembleRegressorPFloat, RuntimeTreeEnsembleRegressorPDouble)
class TreeEnsembleRegressorCommon(OpRunUnaryNum):
    """
    Common part of the runtimes for operator *TreeEnsembleRegressor*
    in single (``numpy.float32``) or double (``numpy.float64``)
    precision. Subclasses must define the class attribute ``atts``,
    an :class:`OrderedDict` mapping every expected attribute name to
    its default value. Its order matters: the attribute values are
    given positionally to the C++ runtime's ``init`` (see ``_init``).
    """

    # runtime_version -> positional arguments given to the parallel C++
    # runtime (RuntimeTreeEnsembleRegressorPFloat / PDouble).
    # Their meaning is defined by the C++ binding (presumably two
    # parallelization thresholds and two feature switches -- TODO confirm).
    # Version 0 selects the sequential runtime, which takes no argument.
    _parallel_versions = {
        1: (60, 20, False, False),
        2: (60, 20, True, False),
        3: (60, 20, True, True),
    }

    def __init__(self, dtype, onnx_node, desc=None,
                 expected_attributes=None, runtime_version=3, **options):
        """
        :param dtype: ``numpy.float32`` or ``numpy.float64``, selects the
            C++ runtime
        :param onnx_node: ONNX node, forwarded to :class:`OpRunUnaryNum`
        :param desc: forwarded to :class:`OpRunUnaryNum`
        :param expected_attributes: forwarded to :class:`OpRunUnaryNum`
        :param runtime_version: 0 for the sequential runtime, 1, 2 or 3
            for the parallel ones
        :param options: forwarded to :class:`OpRunUnaryNum`
        :raises RuntimeTypeError: unsupported *dtype*
        :raises ValueError: unknown *runtime_version*
        """
        OpRunUnaryNum.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)
        self._init(dtype=dtype, version=runtime_version)

    def _get_typed_attributes(self, k):
        """
        Returns attribute *k* through ``_get_typed_class_attribute``,
        given the class attribute ``atts`` (presumably cast to the type
        of the declared default -- TODO confirm).
        """
        return _get_typed_class_attribute(self, k, self.__class__.atts)

    def _find_custom_operator_schema(self, op_name):
        """
        Finds a custom operator defined by this runtime.

        :param op_name: operator name
        :return: schema instance
        :raises RuntimeError: no schema is known for *op_name*
        """
        if op_name == "TreeEnsembleRegressorDouble":
            return TreeEnsembleRegressorDoubleSchema()
        raise RuntimeError(  # pragma: no cover
            "Unable to find a schema for operator '{}'.".format(op_name))

    def _init(self, dtype, version):
        """
        Creates the C++ runtime ``self.rt_`` matching *dtype* and
        *version*, then initializes it with the values of every attribute
        listed in ``self.__class__.atts`` (in that order).

        :param dtype: ``numpy.float32`` or ``numpy.float64``
        :param version: 0 (sequential runtime), 1, 2 or 3 (parallel
            runtime, see ``_parallel_versions``)
        :raises RuntimeTypeError: unsupported *dtype*
        :raises ValueError: unknown *version*
        """
        # The dtype is checked first: an unsupported dtype is reported
        # whatever the version is.
        if dtype == numpy.float32:
            sequential = RuntimeTreeEnsembleRegressorFloat
            parallel = RuntimeTreeEnsembleRegressorPFloat
        elif dtype == numpy.float64:
            sequential = RuntimeTreeEnsembleRegressorDouble
            parallel = RuntimeTreeEnsembleRegressorPDouble
        else:
            raise RuntimeTypeError(  # pragma: no cover
                "Unsupported dtype={}.".format(dtype))

        if version == 0:
            self.rt_ = sequential()
        elif version in self._parallel_versions:
            self.rt_ = parallel(*self._parallel_versions[version])
        else:
            raise ValueError("Unknown version '{}'.".format(version))

        atts = [self._get_typed_attributes(k)
                for k in self.__class__.atts]
        self.rt_.init(*atts)

    def _run(self, x):  # pylint: disable=W0221
        """
        Computes the predictions with the C++ runtime created by ``_init``:
        class :class:`RuntimeTreeEnsembleRegressorFloat
        <mlprodict.onnxrt.ops_cpu.op_tree_ensemble_regressor_.RuntimeTreeEnsembleRegressorFloat>`,
        class :class:`RuntimeTreeEnsembleRegressorDouble
        <mlprodict.onnxrt.ops_cpu.op_tree_ensemble_regressor_.RuntimeTreeEnsembleRegressorDouble>`
        or their parallel counterparts ``RuntimeTreeEnsembleRegressorPFloat``
        and ``RuntimeTreeEnsembleRegressorPDouble``, depending on the
        runtime version. This C++ implementation comes from the tree
        ensemble kernels of :epkg:`onnxruntime`
        (`onnxruntime/core/providers/cpu/ml
        <https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/core/providers/cpu/ml>`_).

        :param x: features, a dense array or any object exposing
            ``todense`` (a sparse matrix for example)
        :return: tuple with one array; it is reshaped into one row per
            row of *x* when the runtime returns a flat vector holding
            several values per row
        """
        if hasattr(x, 'todense'):
            # scipy sparse matrices return numpy.matrix from todense();
            # asarray turns it into a regular ndarray.
            x = numpy.asarray(x.todense())
        pred = self.rt_.compute(x)
        if pred.shape[0] != x.shape[0]:
            pred = pred.reshape(x.shape[0], pred.shape[0] // x.shape[0])
        return (pred, )
class TreeEnsembleRegressor(TreeEnsembleRegressorCommon):
    """
    Runtime for operator *TreeEnsembleRegressor* in single precision
    (``numpy.float32``). ``atts`` maps every expected attribute to its
    default value; the parent class reads it in this exact order
    to initialize the C++ runtime.
    """

    atts = OrderedDict((
        ('aggregate_function', b'SUM'),
        ('base_values', numpy.empty(0, dtype=numpy.float32)),
        ('n_targets', 1),
        ('nodes_falsenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_featureids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_hitrates', numpy.empty(0, dtype=numpy.float32)),
        ('nodes_missing_value_tracks_true',
         numpy.empty(0, dtype=numpy.int64)),
        ('nodes_modes', []),
        ('nodes_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_truenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_values', numpy.empty(0, dtype=numpy.float32)),
        ('post_transform', b'NONE'),
        ('target_ids', numpy.empty(0, dtype=numpy.int64)),
        ('target_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('target_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('target_weights', numpy.empty(0, dtype=numpy.float32)),
    ))

    def __init__(self, onnx_node, desc=None, runtime_version=1, **options):
        """
        :param onnx_node: ONNX node
        :param desc: forwarded to the parent class
        :param runtime_version: 0 (sequential) or 1, 2, 3 (parallel)
        :param options: forwarded to the parent class
        """
        super().__init__(
            numpy.float32, onnx_node, desc=desc,
            runtime_version=runtime_version,
            expected_attributes=TreeEnsembleRegressor.atts,
            **options)
class TreeEnsembleRegressorDouble(TreeEnsembleRegressorCommon):
    """
    Runtime for the custom operator *TreeEnsembleRegressorDouble*,
    the double precision (``numpy.float64``) variant of
    *TreeEnsembleRegressor*. ``atts`` maps every expected attribute to
    its default value; the parent class reads it in this exact order
    to initialize the C++ runtime.
    """

    atts = OrderedDict((
        ('aggregate_function', b'SUM'),
        ('base_values', numpy.empty(0, dtype=numpy.float64)),
        ('n_targets', 1),
        ('nodes_falsenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_featureids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_hitrates', numpy.empty(0, dtype=numpy.float64)),
        ('nodes_missing_value_tracks_true',
         numpy.empty(0, dtype=numpy.int64)),
        ('nodes_modes', []),
        ('nodes_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_truenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_values', numpy.empty(0, dtype=numpy.float64)),
        ('post_transform', b'NONE'),
        ('target_ids', numpy.empty(0, dtype=numpy.int64)),
        ('target_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('target_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('target_weights', numpy.empty(0, dtype=numpy.float64)),
    ))

    def __init__(self, onnx_node, desc=None, runtime_version=1, **options):
        """
        :param onnx_node: ONNX node
        :param desc: forwarded to the parent class
        :param runtime_version: 0 (sequential) or 1, 2, 3 (parallel)
        :param options: forwarded to the parent class
        """
        super().__init__(
            numpy.float64, onnx_node, desc=desc,
            runtime_version=runtime_version,
            expected_attributes=TreeEnsembleRegressorDouble.atts,
            **options)
class TreeEnsembleRegressorDoubleSchema(OperatorSchema):
    """
    Schema of the custom operator *TreeEnsembleRegressorDouble*
    added by this package, see @see cl TreeEnsembleRegressorDouble.
    Its attributes are the ones declared by that runtime class.
    """

    def __init__(self):
        super().__init__('TreeEnsembleRegressorDouble')
        # Reuses the attribute table (names and defaults) of the runtime.
        self.attributes = TreeEnsembleRegressorDouble.atts