Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Rewrites some of the converters implemented in 

4:epkg:`sklearn-onnx`. 

5""" 

6import numpy 

7from skl2onnx.operator_converters.decision_tree import ( 

8 convert_sklearn_decision_tree_regressor, 

9 convert_sklearn_decision_tree_classifier) 

10from skl2onnx.operator_converters.gradient_boosting import ( 

11 convert_sklearn_gradient_boosting_regressor, 

12 convert_sklearn_gradient_boosting_classifier) 

13from skl2onnx.operator_converters.random_forest import ( 

14 convert_sklearn_random_forest_classifier, 

15 convert_sklearn_random_forest_regressor_converter) 

16from skl2onnx.common.data_types import guess_numpy_type 

17 

18 

def _op_type_domain_regressor(dtype):
    """
    Picks the tree ensemble regressor operator associated to `dtype`.

    @param      dtype       type compared (with ``==``) against
                            *numpy.float32* and *numpy.float64*
    @return                 tuple *(op_type, op_domain, op_version)*

    A *RuntimeError* is raised when `dtype` matches neither type.
    """
    # Candidates are tested in this order with ``==`` (not a dictionary
    # lookup) so that numpy.dtype instances match as well.
    supported = (
        (numpy.float32, ('TreeEnsembleRegressor', 'ai.onnx.ml', 1)),
        (numpy.float64, ('TreeEnsembleRegressorDouble', 'mlprodict', 1)),
    )
    for known, signature in supported:
        if dtype == known:
            return signature
    raise RuntimeError(  # pragma: no cover
        "Unsupported dtype {}.".format(dtype))

29 

30 

def _op_type_domain_classifier(dtype):
    """
    Picks the tree ensemble classifier operator associated to `dtype`.

    @param      dtype       type compared (with ``==``) against
                            *numpy.float32* and *numpy.float64*
    @return                 tuple *(op_type, op_domain, op_version)*

    A *RuntimeError* is raised when `dtype` matches neither type.
    """
    # Candidates are tested in this order with ``==`` (not a dictionary
    # lookup) so that numpy.dtype instances match as well.
    supported = (
        (numpy.float32, ('TreeEnsembleClassifier', 'ai.onnx.ml', 1)),
        (numpy.float64, ('TreeEnsembleClassifierDouble', 'mlprodict', 1)),
    )
    for known, signature in supported:
        if dtype == known:
            return signature
    raise RuntimeError(  # pragma: no cover
        "Unsupported dtype {}.".format(dtype))

41 

42 

def new_convert_sklearn_decision_tree_classifier(scope, operator, container):
    """
    Wraps the decision tree classifier converter from
    :epkg:`sklearn-onnx` and selects the ONNX operator according to
    the type of the first input, so that doubles can be handled.

    @param      scope       forwarded unchanged to the wrapped converter
    @param      operator    forwarded unchanged; the type of
                            ``operator.inputs[0]`` drives the choice
    @param      container   forwarded unchanged to the wrapped converter
    """
    guessed = guess_numpy_type(operator.inputs[0].type)
    # Every type other than float64 is processed as float32.
    float_type = numpy.float32 if guessed != numpy.float64 else guessed
    name, domain, version = _op_type_domain_classifier(float_type)
    convert_sklearn_decision_tree_classifier(
        scope, operator, container,
        op_type=name, op_domain=domain, op_version=version)

56 

57 

def new_convert_sklearn_decision_tree_regressor(scope, operator, container):
    """
    Wraps the decision tree regressor converter from
    :epkg:`sklearn-onnx` and selects the ONNX operator according to
    the type of the first input, so that doubles can be handled.

    @param      scope       forwarded unchanged to the wrapped converter
    @param      operator    forwarded unchanged; the type of
                            ``operator.inputs[0]`` drives the choice
    @param      container   forwarded unchanged to the wrapped converter
    """
    guessed = guess_numpy_type(operator.inputs[0].type)
    # Every type other than float64 is processed as float32.
    float_type = numpy.float32 if guessed != numpy.float64 else guessed
    name, domain, version = _op_type_domain_regressor(float_type)
    convert_sklearn_decision_tree_regressor(
        scope, operator, container,
        op_type=name, op_domain=domain, op_version=version)

71 

72 

def new_convert_sklearn_gradient_boosting_classifier(scope, operator, container):
    """
    Wraps the gradient boosting classifier converter from
    :epkg:`sklearn-onnx` and selects the ONNX operator according to
    the type of the first input, so that doubles can be handled.

    @param      scope       forwarded unchanged to the wrapped converter
    @param      operator    forwarded unchanged; the type of
                            ``operator.inputs[0]`` drives the choice
    @param      container   forwarded unchanged to the wrapped converter
    """
    guessed = guess_numpy_type(operator.inputs[0].type)
    # Every type other than float64 is processed as float32.
    float_type = numpy.float32 if guessed != numpy.float64 else guessed
    name, domain, version = _op_type_domain_classifier(float_type)
    convert_sklearn_gradient_boosting_classifier(
        scope, operator, container,
        op_type=name, op_domain=domain, op_version=version)

86 

87 

def new_convert_sklearn_gradient_boosting_regressor(scope, operator, container):
    """
    Wraps the gradient boosting regressor converter from
    :epkg:`sklearn-onnx` and selects the ONNX operator according to
    the type of the first input, so that doubles can be handled.

    @param      scope       forwarded unchanged to the wrapped converter
    @param      operator    forwarded unchanged; the type of
                            ``operator.inputs[0]`` drives the choice
    @param      container   forwarded unchanged to the wrapped converter
    """
    guessed = guess_numpy_type(operator.inputs[0].type)
    # Every type other than float64 is processed as float32.
    float_type = numpy.float32 if guessed != numpy.float64 else guessed
    name, domain, version = _op_type_domain_regressor(float_type)
    convert_sklearn_gradient_boosting_regressor(
        scope, operator, container,
        op_type=name, op_domain=domain, op_version=version)

101 

102 

def new_convert_sklearn_random_forest_classifier(scope, operator, container):
    """
    Wraps the random forest classifier converter from
    :epkg:`sklearn-onnx` and selects the ONNX operator according to
    the type of the first input, so that doubles can be handled.

    @param      scope       forwarded unchanged to the wrapped converter
    @param      operator    forwarded unchanged; the type of
                            ``operator.inputs[0]`` drives the choice
    @param      container   forwarded unchanged to the wrapped converter
    """
    guessed = guess_numpy_type(operator.inputs[0].type)
    # Every type other than float64 is processed as float32.
    float_type = numpy.float32 if guessed != numpy.float64 else guessed
    name, domain, version = _op_type_domain_classifier(float_type)
    convert_sklearn_random_forest_classifier(
        scope, operator, container,
        op_type=name, op_domain=domain, op_version=version)

116 

117 

def new_convert_sklearn_random_forest_regressor(scope, operator, container):
    """
    Wraps the random forest regressor converter from
    :epkg:`sklearn-onnx` and selects the ONNX operator according to
    the type of the first input, so that doubles can be handled.

    @param      scope       forwarded unchanged to the wrapped converter
    @param      operator    forwarded unchanged; the type of
                            ``operator.inputs[0]`` drives the choice
    @param      container   forwarded unchanged to the wrapped converter
    """
    guessed = guess_numpy_type(operator.inputs[0].type)
    # Every type other than float64 is processed as float32.
    float_type = numpy.float32 if guessed != numpy.float64 else guessed
    name, domain, version = _op_type_domain_regressor(float_type)
    convert_sklearn_random_forest_regressor_converter(
        scope, operator, container,
        op_type=name, op_domain=domain, op_version=version)