Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2""" 

3@file 

4@brief Shortcut to *onnx_conv*. 

5""" 

6import warnings 

7import numbers 

8import numpy 

9from skl2onnx import update_registered_converter 

10from skl2onnx.common.shape_calculator import ( 

11 calculate_linear_classifier_output_shapes, 

12 calculate_linear_regressor_output_shapes) 

13from .scorers import register_scorers 

14 

15 

16def _custom_parser_xgboost(scope, model, inputs, custom_parsers=None): 

17 """ 

18 Custom parser for *XGBClassifier* and *LGBMClassifier*. 

19 """ 

20 if custom_parsers is not None and model in custom_parsers: 

21 return custom_parsers[model]( 

22 scope, model, inputs, custom_parsers=custom_parsers) 

23 if not all(isinstance(i, (numbers.Real, bool, numpy.bool_)) 

24 for i in model.classes_): 

25 raise NotImplementedError( # pragma: no cover 

26 "Current converter does not support string labels.") 

27 try: 

28 from skl2onnx._parse import _parse_sklearn_classifier 

29 except ImportError as e: # pragma: no cover 

30 import skl2onnx 

31 raise ImportError( 

32 "Hidden API has changed in module 'skl2onnx=={}', " 

33 "installation path is '{}'.".format( 

34 skl2onnx.__version__, skl2onnx.__file__)) from e 

35 return _parse_sklearn_classifier(scope, model, inputs) 

36 

37 

38def _register_converters_lightgbm(exc=True): 

39 """ 

40 This functions registers additional converters 

41 for :epkg:`lightgbm`. 

42 

43 @param exc if True, raises an exception if a converter cannot 

44 registered (missing package for example) 

45 @return list of models supported by the new converters 

46 """ 

47 registered = [] 

48 

49 try: 

50 from lightgbm import LGBMClassifier 

51 except ImportError as e: # pragma: no cover 

52 if exc: 

53 raise e 

54 else: 

55 warnings.warn( 

56 "Cannot register LGBMClassifier due to '{}'.".format(e)) 

57 LGBMClassifier = None 

58 if LGBMClassifier is not None: 

59 try: 

60 from skl2onnx._parse import _parse_sklearn_classifier 

61 except ImportError as e: # pragma: no cover 

62 import skl2onnx 

63 raise ImportError( 

64 "Hidden API has changed in module 'skl2onnx=={}', " 

65 "installation path is '{}'.".format( 

66 skl2onnx.__version__, skl2onnx.__file__)) from e 

67 from .operator_converters.conv_lightgbm import ( 

68 convert_lightgbm, calculate_lightgbm_output_shapes) 

69 update_registered_converter( 

70 LGBMClassifier, 'LgbmClassifier', 

71 calculate_lightgbm_output_shapes, 

72 convert_lightgbm, parser=_parse_sklearn_classifier, 

73 options={'zipmap': [True, False], 'nocl': [True, False]}) 

74 registered.append(LGBMClassifier) 

75 

76 try: 

77 from lightgbm import LGBMRegressor 

78 except ImportError as e: # pragma: no cover 

79 if exc: 

80 raise e 

81 else: 

82 warnings.warn( 

83 "Cannot register LGBMRegressor due to '{}'.".format(e)) 

84 LGBMRegressor = None 

85 if LGBMRegressor is not None: 

86 from .operator_converters.conv_lightgbm import convert_lightgbm 

87 update_registered_converter( 

88 LGBMRegressor, 'LightGbmLGBMRegressor', 

89 calculate_linear_regressor_output_shapes, 

90 convert_lightgbm, options={'split': None}) 

91 registered.append(LGBMRegressor) 

92 

93 try: 

94 from lightgbm import Booster 

95 except ImportError as e: # pragma: no cover 

96 if exc: 

97 raise e 

98 else: 

99 warnings.warn( 

100 "Cannot register LGBMRegressor due to '{}'.".format(e)) 

101 Booster = None 

102 if Booster is not None: 

103 from .operator_converters.conv_lightgbm import ( 

104 convert_lightgbm, calculate_lightgbm_output_shapes) 

105 from .operator_converters.parse_lightgbm import ( 

106 lightgbm_parser, WrappedLightGbmBooster, 

107 WrappedLightGbmBoosterClassifier, 

108 shape_calculator_lightgbm_concat, 

109 converter_lightgbm_concat, 

110 MockWrappedLightGbmBoosterClassifier) 

111 update_registered_converter( 

112 Booster, 'LightGbmBooster', calculate_lightgbm_output_shapes, 

113 convert_lightgbm, parser=lightgbm_parser, 

114 options={'cast': [True, False]}) 

115 update_registered_converter( 

116 WrappedLightGbmBooster, 'WrapperLightGbmBooster', 

117 calculate_lightgbm_output_shapes, 

118 convert_lightgbm, parser=lightgbm_parser) 

119 update_registered_converter( 

120 WrappedLightGbmBoosterClassifier, 'WrappedLightGbmBoosterClassifier', 

121 calculate_lightgbm_output_shapes, 

122 convert_lightgbm, parser=lightgbm_parser, 

123 options={'zipmap': [True, False], 'nocl': [True, False]}) 

124 update_registered_converter( 

125 MockWrappedLightGbmBoosterClassifier, 'MockWrappedLightGbmBoosterClassifier', 

126 calculate_lightgbm_output_shapes, 

127 convert_lightgbm, parser=lightgbm_parser) 

128 update_registered_converter( 

129 None, 'LightGBMConcat', 

130 shape_calculator_lightgbm_concat, 

131 converter_lightgbm_concat) 

132 registered.append(Booster) 

133 registered.append(WrappedLightGbmBooster) 

134 registered.append(WrappedLightGbmBoosterClassifier) 

135 

136 return registered 

137 

138 

139def _register_converters_xgboost(exc=True): 

140 """ 

141 This functions registers additional converters 

142 for :epkg:`xgboost`. 

143 

144 @param exc if True, raises an exception if a converter cannot 

145 registered (missing package for example) 

146 @return list of models supported by the new converters 

147 """ 

148 registered = [] 

149 

150 try: 

151 from xgboost import XGBClassifier 

152 except ImportError as e: # pragma: no cover 

153 if exc: 

154 raise e 

155 else: 

156 warnings.warn( 

157 "Cannot register XGBClassifier due to '{}'.".format(e)) 

158 XGBClassifier = None 

159 if XGBClassifier is not None: 

160 from .operator_converters.conv_xgboost import convert_xgboost 

161 update_registered_converter( 

162 XGBClassifier, 'XGBoostXGBClassifier', 

163 calculate_linear_classifier_output_shapes, 

164 convert_xgboost, parser=_custom_parser_xgboost, 

165 options={'zipmap': [True, False], 'raw_scores': [True, False], 

166 'nocl': [True, False]}) 

167 registered.append(XGBClassifier) 

168 

169 try: 

170 from xgboost import XGBRegressor 

171 except ImportError as e: # pragma: no cover 

172 if exc: 

173 raise e 

174 else: 

175 warnings.warn( 

176 "Cannot register LGBMRegressor due to '{}'.".format(e)) 

177 XGBRegressor = None 

178 if XGBRegressor is not None: 

179 from .operator_converters.conv_xgboost import convert_xgboost 

180 update_registered_converter(XGBRegressor, 'XGBoostXGBRegressor', 

181 calculate_linear_regressor_output_shapes, 

182 convert_xgboost) 

183 registered.append(XGBRegressor) 

184 return registered 

185 

186 

187def _register_converters_mlinsights(exc=True): 

188 """ 

189 This functions registers additional converters 

190 for :epkg:`mlinsights`. 

191 

192 @param exc if True, raises an exception if a converter cannot 

193 registered (missing package for example) 

194 @return list of models supported by the new converters 

195 """ 

196 registered = [] 

197 

198 try: 

199 from mlinsights.mlmodel import TransferTransformer 

200 except ImportError as e: # pragma: no cover 

201 if exc: 

202 raise e 

203 else: 

204 warnings.warn( 

205 "Cannot register models from 'mlinsights' due to '{}'.".format(e)) 

206 TransferTransformer = None 

207 

208 if TransferTransformer is not None: 

209 from .operator_converters.conv_transfer_transformer import ( 

210 shape_calculator_transfer_transformer, convert_transfer_transformer, 

211 parser_transfer_transformer) 

212 update_registered_converter( 

213 TransferTransformer, 'MlInsightsTransferTransformer', 

214 shape_calculator_transfer_transformer, 

215 convert_transfer_transformer, 

216 parser=parser_transfer_transformer, 

217 options='passthrough') 

218 registered.append(TransferTransformer) 

219 

220 return registered 

221 

222 

223def _register_converters_skl2onnx(exc=True): 

224 """ 

225 This functions registers additional converters 

226 for :epkg:`skl2onnx`. 

227 

228 @param exc if True, raises an exception if a converter cannot 

229 registered (missing package for example) 

230 @return list of models supported by the new converters 

231 """ 

232 registered = [] 

233 

234 try: 

235 import skl2onnx.sklapi.register # pylint: disable=W0611 

236 from skl2onnx.sklapi import WOETransformer 

237 model = [WOETransformer] 

238 except ImportError as e: # pragma: no cover 

239 try: 

240 import skl2onnx 

241 from pyquickhelper.texthelper.version_helper import ( 

242 compare_module_version) 

243 if compare_module_version(skl2onnx.__version__, '1.9.3') < 0: 

244 # Too old version of skl2onnx. 

245 return [] 

246 except ImportError: 

247 pass 

248 if exc: 

249 raise e 

250 else: 

251 warnings.warn( 

252 "Cannot register models from 'skl2onnx' due to %r." % e) 

253 model = None 

254 

255 if model is not None: 

256 registered.extend(model) 

257 return registered 

258 

259 

260def register_converters(exc=True): 

261 """ 

262 This functions registers additional converters 

263 to the list of converters :epkg:`sklearn-onnx` declares. 

264 

265 @param exc if True, raises an exception if a converter cannot 

266 registered (missing package for example) 

267 @return list of models supported by the new converters 

268 """ 

269 ext = _register_converters_lightgbm(exc=exc) 

270 ext += _register_converters_xgboost(exc=exc) 

271 ext += _register_converters_mlinsights(exc=exc) 

272 ext += _register_converters_skl2onnx(exc=exc) 

273 ext += register_scorers() 

274 return ext