Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from collections import OrderedDict
8import numpy
9from ._op_helper import _get_typed_class_attribute
10from ._op import OpRunClassifierProb, RuntimeTypeError
11from ._op_classifier_string import _ClassifierCommon
12from ._new_ops import OperatorSchema
13from .op_tree_ensemble_classifier_ import ( # pylint: disable=E0611,E0401
14 RuntimeTreeEnsembleClassifierDouble,
15 RuntimeTreeEnsembleClassifierFloat,
16)
17from .op_tree_ensemble_classifier_p_ import ( # pylint: disable=E0611,E0401
18 RuntimeTreeEnsembleClassifierPFloat,
19 RuntimeTreeEnsembleClassifierPDouble,
20)
class TreeEnsembleClassifierCommon(OpRunClassifierProb, _ClassifierCommon):
    """
    Shared runtime for the *TreeEnsembleClassifier* operators.
    Subclasses define the class attribute ``atts`` (attribute names
    and defaults) and the floating point type given to the compiled
    runtime (``numpy.float32`` or ``numpy.float64``).
    """

    def __init__(self, dtype, onnx_node, desc=None,
                 expected_attributes=None,
                 runtime_version=3, **options):
        OpRunClassifierProb.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=expected_attributes, **options)
        self._init(dtype=dtype, version=runtime_version)

    def _get_typed_attributes(self, k):
        # Delegates the conversion of attribute *k* to the helper,
        # using the class-level ``atts`` as the reference table.
        return _get_typed_class_attribute(self, k, self.__class__.atts)

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of a custom operator implemented by this
        runtime; only ``TreeEnsembleClassifierDouble`` is known.

        :raises RuntimeError: for any other operator name
        """
        if op_name != "TreeEnsembleClassifierDouble":
            raise RuntimeError(  # pragma: no cover
                "Unable to find a schema for operator '{}'.".format(op_name))
        return TreeEnsembleClassifierDoubleSchema()

    def _init(self, dtype, version):
        """
        Creates the compiled runtime ``self.rt_`` and initializes it
        with the typed attributes, in the order of ``atts``.

        :param dtype: ``numpy.float32`` or ``numpy.float64``
        :param version: 0 selects the non-``P`` runtime class,
            1, 2, 3 select the ``P`` runtime class with different
            boolean flags
        :raises RuntimeTypeError: unsupported *dtype* (checked first)
        :raises ValueError: unknown *version*
        """
        self._post_process_label_attributes()

        # The dtype only decides which pair of compiled classes is used.
        if dtype == numpy.float32:
            single_cls = RuntimeTreeEnsembleClassifierFloat
            parallel_cls = RuntimeTreeEnsembleClassifierPFloat
        elif dtype == numpy.float64:
            single_cls = RuntimeTreeEnsembleClassifierDouble
            parallel_cls = RuntimeTreeEnsembleClassifierPDouble
        else:
            raise RuntimeTypeError(  # pragma: no cover
                "Unsupported dtype={}.".format(dtype))

        # NOTE(review): the meaning of the constants 60, 20 and of the
        # two boolean flags is defined by the C++ runtime, not visible here.
        if version == 0:
            self.rt_ = single_cls()
        elif version == 1:
            self.rt_ = parallel_cls(60, 20, False, False)
        elif version == 2:
            self.rt_ = parallel_cls(60, 20, True, False)
        elif version == 3:
            self.rt_ = parallel_cls(60, 20, True, True)
        else:
            raise ValueError("Unknown version '{}'.".format(version))

        typed = list(map(self._get_typed_attributes, self.__class__.atts))
        self.rt_.init(*typed)

    def _run(self, x):  # pylint: disable=W0221
        """
        Computes labels and scores with a C++ implementation coming
        from :epkg:`onnxruntime`,
        `tree_ensemble_classifier.cc
        <https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/providers/cpu/ml/tree_ensemble_classifier.cc>`_.
        The compiled classes are
        :class:`RuntimeTreeEnsembleClassifierFloat
        <mlprodict.onnxrt.ops_cpu.op_tree_ensemble_classifier_.RuntimeTreeEnsembleClassifierFloat>`
        and :class:`RuntimeTreeEnsembleClassifierDouble
        <mlprodict.onnxrt.ops_cpu.op_tree_ensemble_classifier_.RuntimeTreeEnsembleClassifierDouble>`
        (and their ``P`` counterparts).
        """
        label, scores = self.rt_.compute(x)
        n_rows = label.shape[0]
        # A first dimension different from the number of labels means the
        # scores came back flattened: rebuild one row per label.
        if scores.shape[0] != n_rows:
            scores = scores.reshape(n_rows, scores.shape[0] // n_rows)
        return self._post_process_predicted_label(label, scores)
class TreeEnsembleClassifier(TreeEnsembleClassifierCommon):
    """
    Runtime for the ONNX operator *TreeEnsembleClassifier*,
    computed in single precision (``numpy.float32``).
    """

    # Attribute names with their default values; the iteration order is
    # the order in which typed attributes are passed to ``rt_.init``.
    atts = OrderedDict((
        ('base_values', numpy.empty(0, dtype=numpy.float32)),
        ('class_ids', numpy.empty(0, dtype=numpy.int64)),
        ('class_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('class_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('class_weights', numpy.empty(0, dtype=numpy.float32)),
        ('classlabels_int64s', numpy.empty(0, dtype=numpy.int64)),
        ('classlabels_strings', []),
        ('nodes_falsenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_featureids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_hitrates', numpy.empty(0, dtype=numpy.float32)),
        ('nodes_missing_value_tracks_true',
         numpy.empty(0, dtype=numpy.int64)),
        ('nodes_modes', []),
        ('nodes_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_truenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_values', numpy.empty(0, dtype=numpy.float32)),
        ('post_transform', b'NONE'),
    ))

    def __init__(self, onnx_node, desc=None, **options):
        TreeEnsembleClassifierCommon.__init__(
            self, numpy.float32, onnx_node, desc=desc,
            expected_attributes=TreeEnsembleClassifier.atts,
            **options)
class TreeEnsembleClassifierDouble(TreeEnsembleClassifierCommon):
    """
    Runtime for the custom operator *TreeEnsembleClassifierDouble*,
    the double precision (``numpy.float64``) variant of
    *TreeEnsembleClassifier*.
    """

    # Attribute names with their default values (float64 arrays where the
    # float32 operator uses float32); the iteration order is the order in
    # which typed attributes are passed to ``rt_.init``.
    atts = OrderedDict([
        ('base_values', numpy.empty(0, dtype=numpy.float64)),
        ('class_ids', numpy.empty(0, dtype=numpy.int64)),
        ('class_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('class_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('class_weights', numpy.empty(0, dtype=numpy.float64)),
        ('classlabels_int64s', numpy.empty(0, dtype=numpy.int64)),
        ('classlabels_strings', []),
        ('nodes_falsenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_featureids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_hitrates', numpy.empty(0, dtype=numpy.float64)),
        ('nodes_missing_value_tracks_true', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_modes', []),
        ('nodes_nodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_treeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_truenodeids', numpy.empty(0, dtype=numpy.int64)),
        ('nodes_values', numpy.empty(0, dtype=numpy.float64)),
        ('post_transform', b'NONE')
    ])

    def __init__(self, onnx_node, desc=None, **options):
        # Use this class's own float64 defaults; the float32 table of
        # TreeEnsembleClassifier was passed here before.
        TreeEnsembleClassifierCommon.__init__(
            self, numpy.float64, onnx_node, desc=desc,
            expected_attributes=TreeEnsembleClassifierDouble.atts,
            **options)
class TreeEnsembleClassifierDoubleSchema(OperatorSchema):
    """
    Schema of a custom operator added by this package,
    see @see cl TreeEnsembleClassifierDouble. Its attributes are
    the ``atts`` table of that class.
    """

    def __init__(self):
        op_name = 'TreeEnsembleClassifierDouble'
        OperatorSchema.__init__(self, op_name)
        self.attributes = TreeEnsembleClassifierDouble.atts