Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Rewrites some of the converters implemented in
4:epkg:`sklearn-onnx`.
5"""
6from skl2onnx.common._registration import (
7 _converter_pool, _shape_calculator_pool)
8try:
9 from skl2onnx.common._registration import RegisteredConverter
10except ImportError: # pragma: no cover
11 # sklearn-onnx <= 1.6.0
12 RegisteredConverter = lambda fct, opts: fct
13from .sklconv.tree_converters import (
14 new_convert_sklearn_decision_tree_classifier,
15 new_convert_sklearn_decision_tree_regressor,
16 new_convert_sklearn_gradient_boosting_classifier,
17 new_convert_sklearn_gradient_boosting_regressor,
18 new_convert_sklearn_random_forest_classifier,
19 new_convert_sklearn_random_forest_regressor)
20from .sklconv.svm_converters import (
21 new_convert_sklearn_svm_classifier,
22 new_convert_sklearn_svm_regressor)
23from .sklconv.function_transformer_converters import (
24 new_calculate_sklearn_function_transformer_output_shapes,
25 new_convert_sklearn_function_transformer)
def _rewritten_converter(name, fct):
    """
    Builds the registry entry replacing the converter registered
    under *name* with *fct*, keeping the options the previous
    converter declared.

    When the fallback ``RegisteredConverter`` is in use
    (sklearn-onnx <= 1.6.0), pool entries are plain functions without
    ``get_allowed_options``. Options are then ``None``, which the
    fallback ignores anyway.

    :param name: skl2onnx alias, e.g. ``'SklearnSVR'``
    :param fct: new converter function
    :return: value to store in ``_converter_pool[name]``
    :raises KeyError: if *name* is not registered in skl2onnx
    """
    previous = _converter_pool[name]
    get_options = getattr(previous, 'get_allowed_options', None)
    options = None if get_options is None else get_options()
    return RegisteredConverter(fct, options)


# Maps every overwritten skl2onnx alias to its replacement converter.
# The table order is the registration order.
_overwritten_operators = {
    name: _rewritten_converter(name, fct)
    for name, fct in [
        # SVM
        ('SklearnOneClassSVM', new_convert_sklearn_svm_regressor),
        ('SklearnSVR', new_convert_sklearn_svm_regressor),
        ('SklearnSVC', new_convert_sklearn_svm_classifier),
        # decision trees
        ('SklearnDecisionTreeRegressor',
         new_convert_sklearn_decision_tree_regressor),
        ('SklearnDecisionTreeClassifier',
         new_convert_sklearn_decision_tree_classifier),
        # single extra trees share the decision tree converters
        ('SklearnExtraTreeRegressor',
         new_convert_sklearn_decision_tree_regressor),
        ('SklearnExtraTreeClassifier',
         new_convert_sklearn_decision_tree_classifier),
        # extra trees ensembles share the random forest converters
        ('SklearnExtraTreesRegressor',
         new_convert_sklearn_random_forest_regressor),
        ('SklearnExtraTreesClassifier',
         new_convert_sklearn_random_forest_classifier),
        # function transformer
        ('SklearnFunctionTransformer',
         new_convert_sklearn_function_transformer),
        # gradient boosting
        ('SklearnGradientBoostingRegressor',
         new_convert_sklearn_gradient_boosting_regressor),
        ('SklearnGradientBoostingClassifier',
         new_convert_sklearn_gradient_boosting_classifier),
        # histogram gradient boosting reuses the random forest converters
        ('SklearnHistGradientBoostingRegressor',
         new_convert_sklearn_random_forest_regressor),
        ('SklearnHistGradientBoostingClassifier',
         new_convert_sklearn_random_forest_classifier),
        # random forest
        ('SklearnRandomForestRegressor',
         new_convert_sklearn_random_forest_regressor),
        ('SklearnRandomForestClassifier',
         new_convert_sklearn_random_forest_classifier),
    ]
}
# Shape calculators replaced by register_rewritten_operators when it is
# called without an explicit ``new_shape_calculators`` mapping.
_overwritten_shape_calculator = dict(
    SklearnFunctionTransformer=(
        new_calculate_sklearn_function_transformer_output_shapes),
)
def _check_registered(pool, names):
    """
    Ensures every name in *names* is already registered in *pool*.

    :param pool: skl2onnx registry (converters or shape calculators)
    :param names: iterable of aliases about to be overwritten
    :raises KeyError: on the first alias missing from *pool*
    """
    for name in names:
        if name not in pool:
            raise KeyError(  # pragma: no cover
                "skl2onnx was not imported and '{}' was not registered."
                "".format(name))


def register_rewritten_operators(new_converters=None,
                                 new_shape_calculators=None):
    """
    Registers modified operators and returns the old values.

    :param new_converters: mapping alias -> converter to rewrite,
        or None to rewrite the default ones
    :param new_shape_calculators: mapping alias -> shape calculator
        to rewrite, or None to rewrite the default ones
    :return: tuple (old converters, old shape calculators), each a
        dictionary with the values previously registered for the
        rewritten aliases
    :raises KeyError: if one alias is not registered in skl2onnx;
        neither registry is modified in that case
    """
    if new_converters is None:
        new_converters = _overwritten_operators
    if new_shape_calculators is None:
        new_shape_calculators = _overwritten_shape_calculator

    # Validate both mappings before touching either pool, so that a
    # missing shape calculator cannot leave the converters already
    # replaced with their previous values lost to the caller.
    _check_registered(_converter_pool, new_converters)
    _check_registered(_shape_calculator_pool, new_shape_calculators)

    old_conv = {k: _converter_pool[k] for k in new_converters}
    old_shape = {k: _shape_calculator_pool[k]
                 for k in new_shape_calculators}
    _converter_pool.update(new_converters)
    _shape_calculator_pool.update(new_shape_calculators)
    return old_conv, old_shape