Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Rewrites some of the converters implemented in
4:epkg:`sklearn-onnx`.
5"""
6import numpy
7from skl2onnx.operator_converters.decision_tree import (
8 convert_sklearn_decision_tree_regressor,
9 convert_sklearn_decision_tree_classifier)
10from skl2onnx.operator_converters.gradient_boosting import (
11 convert_sklearn_gradient_boosting_regressor,
12 convert_sklearn_gradient_boosting_classifier)
13from skl2onnx.operator_converters.random_forest import (
14 convert_sklearn_random_forest_classifier,
15 convert_sklearn_random_forest_regressor_converter)
16from skl2onnx.common.data_types import guess_numpy_type
def _op_type_domain_regressor(dtype):
    """
    Selects the tree-ensemble regressor operator matching `dtype`.

    :param dtype: input numeric type, compared with ``==`` against
        ``numpy.float32`` and ``numpy.float64``
    :return: tuple ``(op_type, op_domain, op_version)``
    :raises RuntimeError: if `dtype` is neither float32 nor float64
    """
    # Linear scan with ``==`` (not a dict lookup) so that numpy dtype
    # objects such as ``numpy.dtype('float64')`` still match the scalar
    # types.
    candidates = (
        (numpy.float32, ('TreeEnsembleRegressor', 'ai.onnx.ml', 1)),
        (numpy.float64, ('TreeEnsembleRegressorDouble', 'mlprodict', 1)),
    )
    for known, spec in candidates:
        if dtype == known:
            return spec
    raise RuntimeError(  # pragma: no cover
        "Unsupported dtype {}.".format(dtype))
def _op_type_domain_classifier(dtype):
    """
    Selects the tree-ensemble classifier operator matching `dtype`.

    :param dtype: input numeric type, compared with ``==`` against
        ``numpy.float32`` and ``numpy.float64``
    :return: tuple ``(op_type, op_domain, op_version)``
    :raises RuntimeError: if `dtype` is neither float32 nor float64
    """
    # Linear scan with ``==`` (not a dict lookup) so that numpy dtype
    # objects such as ``numpy.dtype('float64')`` still match the scalar
    # types.
    candidates = (
        (numpy.float32, ('TreeEnsembleClassifier', 'ai.onnx.ml', 1)),
        (numpy.float64, ('TreeEnsembleClassifierDouble', 'mlprodict', 1)),
    )
    for known, spec in candidates:
        if dtype == known:
            return spec
    raise RuntimeError(  # pragma: no cover
        "Unsupported dtype {}.".format(dtype))
def new_convert_sklearn_decision_tree_classifier(scope, operator, container):
    """
    Wraps the decision tree classifier converter from
    :epkg:`sklearn-onnx` so that a double-precision tree operator is used
    when the first input is float64.

    :param scope: forwarded unchanged to the :epkg:`sklearn-onnx` converter
    :param operator: operator being converted; the type of its first input
        decides which tree operator is emitted
    :param container: forwarded unchanged to the :epkg:`sklearn-onnx`
        converter
    """
    guessed = guess_numpy_type(operator.inputs[0].type)
    # Every type other than float64 falls back to the float32 operator.
    target = guessed if guessed == numpy.float64 else numpy.float32
    name, domain, version = _op_type_domain_classifier(target)
    convert_sklearn_decision_tree_classifier(
        scope, operator, container,
        op_type=name, op_domain=domain, op_version=version)
def new_convert_sklearn_decision_tree_regressor(scope, operator, container):
    """
    Wraps the decision tree regressor converter from
    :epkg:`sklearn-onnx` so that a double-precision tree operator is used
    when the first input is float64.

    :param scope: forwarded unchanged to the :epkg:`sklearn-onnx` converter
    :param operator: operator being converted; the type of its first input
        decides which tree operator is emitted
    :param container: forwarded unchanged to the :epkg:`sklearn-onnx`
        converter
    """
    guessed = guess_numpy_type(operator.inputs[0].type)
    # Every type other than float64 falls back to the float32 operator.
    target = guessed if guessed == numpy.float64 else numpy.float32
    name, domain, version = _op_type_domain_regressor(target)
    convert_sklearn_decision_tree_regressor(
        scope, operator, container,
        op_type=name, op_domain=domain, op_version=version)
def new_convert_sklearn_gradient_boosting_classifier(scope, operator, container):
    """
    Wraps the gradient boosting classifier converter from
    :epkg:`sklearn-onnx` so that a double-precision tree operator is used
    when the first input is float64.

    :param scope: forwarded unchanged to the :epkg:`sklearn-onnx` converter
    :param operator: operator being converted; the type of its first input
        decides which tree operator is emitted
    :param container: forwarded unchanged to the :epkg:`sklearn-onnx`
        converter
    """
    guessed = guess_numpy_type(operator.inputs[0].type)
    # Every type other than float64 falls back to the float32 operator.
    target = guessed if guessed == numpy.float64 else numpy.float32
    name, domain, version = _op_type_domain_classifier(target)
    convert_sklearn_gradient_boosting_classifier(
        scope, operator, container,
        op_type=name, op_domain=domain, op_version=version)
def new_convert_sklearn_gradient_boosting_regressor(scope, operator, container):
    """
    Wraps the gradient boosting regressor converter from
    :epkg:`sklearn-onnx` so that a double-precision tree operator is used
    when the first input is float64.

    :param scope: forwarded unchanged to the :epkg:`sklearn-onnx` converter
    :param operator: operator being converted; the type of its first input
        decides which tree operator is emitted
    :param container: forwarded unchanged to the :epkg:`sklearn-onnx`
        converter
    """
    guessed = guess_numpy_type(operator.inputs[0].type)
    # Every type other than float64 falls back to the float32 operator.
    target = guessed if guessed == numpy.float64 else numpy.float32
    name, domain, version = _op_type_domain_regressor(target)
    convert_sklearn_gradient_boosting_regressor(
        scope, operator, container,
        op_type=name, op_domain=domain, op_version=version)
def new_convert_sklearn_random_forest_classifier(scope, operator, container):
    """
    Wraps the random forest classifier converter from
    :epkg:`sklearn-onnx` so that a double-precision tree operator is used
    when the first input is float64.

    :param scope: forwarded unchanged to the :epkg:`sklearn-onnx` converter
    :param operator: operator being converted; the type of its first input
        decides which tree operator is emitted
    :param container: forwarded unchanged to the :epkg:`sklearn-onnx`
        converter
    """
    guessed = guess_numpy_type(operator.inputs[0].type)
    # Every type other than float64 falls back to the float32 operator.
    target = guessed if guessed == numpy.float64 else numpy.float32
    name, domain, version = _op_type_domain_classifier(target)
    convert_sklearn_random_forest_classifier(
        scope, operator, container,
        op_type=name, op_domain=domain, op_version=version)
def new_convert_sklearn_random_forest_regressor(scope, operator, container):
    """
    Wraps the random forest regressor converter from
    :epkg:`sklearn-onnx` so that a double-precision tree operator is used
    when the first input is float64.

    :param scope: forwarded unchanged to the :epkg:`sklearn-onnx` converter
    :param operator: operator being converted; the type of its first input
        decides which tree operator is emitted
    :param container: forwarded unchanged to the :epkg:`sklearn-onnx`
        converter
    """
    guessed = guess_numpy_type(operator.inputs[0].type)
    # Every type other than float64 falls back to the float32 operator.
    target = guessed if guessed == numpy.float64 else numpy.float32
    name, domain, version = _op_type_domain_regressor(target)
    convert_sklearn_random_forest_regressor_converter(
        scope, operator, container,
        op_type=name, op_domain=domain, op_version=version)