Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Rewrites some of the converters implemented in 

4:epkg:`sklearn-onnx`. 

5""" 

6from skl2onnx.common._registration import ( 

7 _converter_pool, _shape_calculator_pool) 

8try: 

9 from skl2onnx.common._registration import RegisteredConverter 

10except ImportError: # pragma: no cover 

11 # sklearn-onnx <= 1.6.0 

12 RegisteredConverter = lambda fct, opts: fct 

13from .sklconv.tree_converters import ( 

14 new_convert_sklearn_decision_tree_classifier, 

15 new_convert_sklearn_decision_tree_regressor, 

16 new_convert_sklearn_gradient_boosting_classifier, 

17 new_convert_sklearn_gradient_boosting_regressor, 

18 new_convert_sklearn_random_forest_classifier, 

19 new_convert_sklearn_random_forest_regressor) 

20from .sklconv.svm_converters import ( 

21 new_convert_sklearn_svm_classifier, 

22 new_convert_sklearn_svm_regressor) 

23from .sklconv.function_transformer_converters import ( 

24 new_calculate_sklearn_function_transformer_output_shapes, 

25 new_convert_sklearn_function_transformer) 

26 

27 

# Rewritten converters, keyed by the skl2onnx alias they replace.
# Each entry wraps the new converter function in RegisteredConverter and
# reuses the allowed options of the converter currently registered under
# the same alias. The lookup ``_converter_pool[alias]`` runs at import
# time, so a KeyError is raised here if skl2onnx has not registered that
# alias. With sklearn-onnx <= 1.6.0, RegisteredConverter is the fallback
# defined in the import block, which returns the bare function.
_overwritten_operators = {
    alias: RegisteredConverter(
        converter, _converter_pool[alias].get_allowed_options())
    for alias, converter in (
        # support vector machines
        ('SklearnOneClassSVM', new_convert_sklearn_svm_regressor),
        ('SklearnSVR', new_convert_sklearn_svm_regressor),
        ('SklearnSVC', new_convert_sklearn_svm_classifier),
        # single decision trees
        ('SklearnDecisionTreeRegressor',
         new_convert_sklearn_decision_tree_regressor),
        ('SklearnDecisionTreeClassifier',
         new_convert_sklearn_decision_tree_classifier),
        # single extra trees share the decision tree converters
        ('SklearnExtraTreeRegressor',
         new_convert_sklearn_decision_tree_regressor),
        ('SklearnExtraTreeClassifier',
         new_convert_sklearn_decision_tree_classifier),
        # extra-trees ensembles share the random forest converters
        ('SklearnExtraTreesRegressor',
         new_convert_sklearn_random_forest_regressor),
        ('SklearnExtraTreesClassifier',
         new_convert_sklearn_random_forest_classifier),
        # function transformer
        ('SklearnFunctionTransformer',
         new_convert_sklearn_function_transformer),
        # gradient boosting
        ('SklearnGradientBoostingRegressor',
         new_convert_sklearn_gradient_boosting_regressor),
        ('SklearnGradientBoostingClassifier',
         new_convert_sklearn_gradient_boosting_classifier),
        # histogram gradient boosting is routed to the random forest
        # converters
        ('SklearnHistGradientBoostingRegressor',
         new_convert_sklearn_random_forest_regressor),
        ('SklearnHistGradientBoostingClassifier',
         new_convert_sklearn_random_forest_classifier),
        # random forests
        ('SklearnRandomForestRegressor',
         new_convert_sklearn_random_forest_regressor),
        ('SklearnRandomForestClassifier',
         new_convert_sklearn_random_forest_classifier),
    )
}

# Rewritten shape calculators, keyed by the skl2onnx alias they replace.
_overwritten_shape_calculator = {
    "SklearnFunctionTransformer": (
        new_calculate_sklearn_function_transformer_output_shapes),
}

91 

92 

def _check_registered_in_pool(pool, names):
    """
    Raises KeyError if any name in *names* is not a key of *pool*.

    :param pool: a skl2onnx registry (converter or shape calculator pool)
    :param names: iterable of aliases expected to be registered in *pool*
    """
    for name in names:
        if name not in pool:
            raise KeyError(  # pragma: no cover
                "skl2onnx was not imported and '{}' was not registered."
                "".format(name))


def register_rewritten_operators(new_converters=None,
                                 new_shape_calculators=None):
    """
    Registers modified operators and returns the old values.

    Both mappings are validated before either registry is modified. A
    missing alias therefore raises KeyError and leaves both
    ``_converter_pool`` and ``_shape_calculator_pool`` unchanged.

    :param new_converters: converters to rewrite or None
        to rewrite default ones (``_overwritten_operators``)
    :param new_shape_calculators: shape calculators to rewrite or
        None to rewrite default ones (``_overwritten_shape_calculator``)
    :raises KeyError: if an alias in either mapping is not registered
        in the corresponding skl2onnx pool
    @return old converters, old shape calculators. Each is a dictionary
        holding the previously registered values of the replaced aliases.
        Passing them back to this function restores the previous
        registrations.
    """
    if new_converters is None:
        new_converters = _overwritten_operators
    if new_shape_calculators is None:
        new_shape_calculators = _overwritten_shape_calculator

    # Validate everything first. Previously the converter pool was updated
    # before the shape calculators were checked, so a failure on the second
    # check left the registry half-rewritten, and the caller never received
    # the old converters needed to undo it.
    _check_registered_in_pool(_converter_pool, new_converters)
    _check_registered_in_pool(_shape_calculator_pool, new_shape_calculators)

    # Snapshot the current registrations so the caller can restore them.
    old_conv = {k: _converter_pool[k] for k in new_converters}
    old_shape = {k: _shape_calculator_pool[k]
                 for k in new_shape_calculators}

    _converter_pool.update(new_converters)
    _shape_calculator_pool.update(new_shape_calculators)

    return old_conv, old_shape