Hot-keys on this page:
r, m, x, p: toggle line displays
j, k: next / previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1# -*- encoding: utf-8 -*-
2"""
3@file
4@brief Shortcut to *onnx_conv*.
5"""
6import warnings
7import numbers
8import numpy
9from skl2onnx import update_registered_converter
10from skl2onnx.common.shape_calculator import (
11 calculate_linear_classifier_output_shapes,
12 calculate_linear_regressor_output_shapes)
13from .scorers import register_scorers
16def _custom_parser_xgboost(scope, model, inputs, custom_parsers=None):
17 """
18 Custom parser for *XGBClassifier* and *LGBMClassifier*.
19 """
20 if custom_parsers is not None and model in custom_parsers:
21 return custom_parsers[model](
22 scope, model, inputs, custom_parsers=custom_parsers)
23 if not all(isinstance(i, (numbers.Real, bool, numpy.bool_))
24 for i in model.classes_):
25 raise NotImplementedError( # pragma: no cover
26 "Current converter does not support string labels.")
27 try:
28 from skl2onnx._parse import _parse_sklearn_classifier
29 except ImportError as e: # pragma: no cover
30 import skl2onnx
31 raise ImportError(
32 "Hidden API has changed in module 'skl2onnx=={}', "
33 "installation path is '{}'.".format(
34 skl2onnx.__version__, skl2onnx.__file__)) from e
35 return _parse_sklearn_classifier(scope, model, inputs)
38def _register_converters_lightgbm(exc=True):
39 """
40 This functions registers additional converters
41 for :epkg:`lightgbm`.
43 @param exc if True, raises an exception if a converter cannot
44 registered (missing package for example)
45 @return list of models supported by the new converters
46 """
47 registered = []
49 try:
50 from lightgbm import LGBMClassifier
51 except ImportError as e: # pragma: no cover
52 if exc:
53 raise e
54 else:
55 warnings.warn(
56 "Cannot register LGBMClassifier due to '{}'.".format(e))
57 LGBMClassifier = None
58 if LGBMClassifier is not None:
59 try:
60 from skl2onnx._parse import _parse_sklearn_classifier
61 except ImportError as e: # pragma: no cover
62 import skl2onnx
63 raise ImportError(
64 "Hidden API has changed in module 'skl2onnx=={}', "
65 "installation path is '{}'.".format(
66 skl2onnx.__version__, skl2onnx.__file__)) from e
67 from .operator_converters.conv_lightgbm import (
68 convert_lightgbm, calculate_lightgbm_output_shapes)
69 update_registered_converter(
70 LGBMClassifier, 'LgbmClassifier',
71 calculate_lightgbm_output_shapes,
72 convert_lightgbm, parser=_parse_sklearn_classifier,
73 options={'zipmap': [True, False], 'nocl': [True, False]})
74 registered.append(LGBMClassifier)
76 try:
77 from lightgbm import LGBMRegressor
78 except ImportError as e: # pragma: no cover
79 if exc:
80 raise e
81 else:
82 warnings.warn(
83 "Cannot register LGBMRegressor due to '{}'.".format(e))
84 LGBMRegressor = None
85 if LGBMRegressor is not None:
86 from .operator_converters.conv_lightgbm import convert_lightgbm
87 update_registered_converter(
88 LGBMRegressor, 'LightGbmLGBMRegressor',
89 calculate_linear_regressor_output_shapes,
90 convert_lightgbm, options={'split': None})
91 registered.append(LGBMRegressor)
93 try:
94 from lightgbm import Booster
95 except ImportError as e: # pragma: no cover
96 if exc:
97 raise e
98 else:
99 warnings.warn(
100 "Cannot register LGBMRegressor due to '{}'.".format(e))
101 Booster = None
102 if Booster is not None:
103 from .operator_converters.conv_lightgbm import (
104 convert_lightgbm, calculate_lightgbm_output_shapes)
105 from .operator_converters.parse_lightgbm import (
106 lightgbm_parser, WrappedLightGbmBooster,
107 WrappedLightGbmBoosterClassifier,
108 shape_calculator_lightgbm_concat,
109 converter_lightgbm_concat,
110 MockWrappedLightGbmBoosterClassifier)
111 update_registered_converter(
112 Booster, 'LightGbmBooster', calculate_lightgbm_output_shapes,
113 convert_lightgbm, parser=lightgbm_parser,
114 options={'cast': [True, False]})
115 update_registered_converter(
116 WrappedLightGbmBooster, 'WrapperLightGbmBooster',
117 calculate_lightgbm_output_shapes,
118 convert_lightgbm, parser=lightgbm_parser)
119 update_registered_converter(
120 WrappedLightGbmBoosterClassifier, 'WrappedLightGbmBoosterClassifier',
121 calculate_lightgbm_output_shapes,
122 convert_lightgbm, parser=lightgbm_parser,
123 options={'zipmap': [True, False], 'nocl': [True, False]})
124 update_registered_converter(
125 MockWrappedLightGbmBoosterClassifier, 'MockWrappedLightGbmBoosterClassifier',
126 calculate_lightgbm_output_shapes,
127 convert_lightgbm, parser=lightgbm_parser)
128 update_registered_converter(
129 None, 'LightGBMConcat',
130 shape_calculator_lightgbm_concat,
131 converter_lightgbm_concat)
132 registered.append(Booster)
133 registered.append(WrappedLightGbmBooster)
134 registered.append(WrappedLightGbmBoosterClassifier)
136 return registered
139def _register_converters_xgboost(exc=True):
140 """
141 This functions registers additional converters
142 for :epkg:`xgboost`.
144 @param exc if True, raises an exception if a converter cannot
145 registered (missing package for example)
146 @return list of models supported by the new converters
147 """
148 registered = []
150 try:
151 from xgboost import XGBClassifier
152 except ImportError as e: # pragma: no cover
153 if exc:
154 raise e
155 else:
156 warnings.warn(
157 "Cannot register XGBClassifier due to '{}'.".format(e))
158 XGBClassifier = None
159 if XGBClassifier is not None:
160 from .operator_converters.conv_xgboost import convert_xgboost
161 update_registered_converter(
162 XGBClassifier, 'XGBoostXGBClassifier',
163 calculate_linear_classifier_output_shapes,
164 convert_xgboost, parser=_custom_parser_xgboost,
165 options={'zipmap': [True, False], 'raw_scores': [True, False],
166 'nocl': [True, False]})
167 registered.append(XGBClassifier)
169 try:
170 from xgboost import XGBRegressor
171 except ImportError as e: # pragma: no cover
172 if exc:
173 raise e
174 else:
175 warnings.warn(
176 "Cannot register LGBMRegressor due to '{}'.".format(e))
177 XGBRegressor = None
178 if XGBRegressor is not None:
179 from .operator_converters.conv_xgboost import convert_xgboost
180 update_registered_converter(XGBRegressor, 'XGBoostXGBRegressor',
181 calculate_linear_regressor_output_shapes,
182 convert_xgboost)
183 registered.append(XGBRegressor)
184 return registered
187def _register_converters_mlinsights(exc=True):
188 """
189 This functions registers additional converters
190 for :epkg:`mlinsights`.
192 @param exc if True, raises an exception if a converter cannot
193 registered (missing package for example)
194 @return list of models supported by the new converters
195 """
196 registered = []
198 try:
199 from mlinsights.mlmodel import TransferTransformer
200 except ImportError as e: # pragma: no cover
201 if exc:
202 raise e
203 else:
204 warnings.warn(
205 "Cannot register models from 'mlinsights' due to '{}'.".format(e))
206 TransferTransformer = None
208 if TransferTransformer is not None:
209 from .operator_converters.conv_transfer_transformer import (
210 shape_calculator_transfer_transformer, convert_transfer_transformer,
211 parser_transfer_transformer)
212 update_registered_converter(
213 TransferTransformer, 'MlInsightsTransferTransformer',
214 shape_calculator_transfer_transformer,
215 convert_transfer_transformer,
216 parser=parser_transfer_transformer,
217 options='passthrough')
218 registered.append(TransferTransformer)
220 return registered
223def _register_converters_skl2onnx(exc=True):
224 """
225 This functions registers additional converters
226 for :epkg:`skl2onnx`.
228 @param exc if True, raises an exception if a converter cannot
229 registered (missing package for example)
230 @return list of models supported by the new converters
231 """
232 registered = []
234 try:
235 import skl2onnx.sklapi.register # pylint: disable=W0611
236 from skl2onnx.sklapi import WOETransformer
237 model = [WOETransformer]
238 except ImportError as e: # pragma: no cover
239 try:
240 import skl2onnx
241 from pyquickhelper.texthelper.version_helper import (
242 compare_module_version)
243 if compare_module_version(skl2onnx.__version__, '1.9.3') < 0:
244 # Too old version of skl2onnx.
245 return []
246 except ImportError:
247 pass
248 if exc:
249 raise e
250 else:
251 warnings.warn(
252 "Cannot register models from 'skl2onnx' due to %r." % e)
253 model = None
255 if model is not None:
256 registered.extend(model)
257 return registered
def register_converters(exc=True):
    """
    Registers additional converters on top of the ones
    :epkg:`sklearn-onnx` already declares (lightgbm, xgboost,
    mlinsights, skl2onnx extensions), then registers the scorers.

    @param exc if True, raise an exception when a converter cannot be
               registered (missing package for example)
    @return list of models supported by the new converters, followed by
            whatever ``register_scorers`` returns
    """
    # Order matters only for the order of the returned list.
    registrars = (
        _register_converters_lightgbm,
        _register_converters_xgboost,
        _register_converters_mlinsights,
        _register_converters_skl2onnx,
    )
    ext = []
    for registrar in registrars:
        ext.extend(registrar(exc=exc))
    ext.extend(register_scorers())
    return ext