Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2""" 

3@file 

4@brief Registers new converters. 

5""" 

6import copy 

7from sklearn.base import BaseEstimator, TransformerMixin 

8from sklearn import __all__ as sklearn__all__, __version__ as sklearn_version 

9from skl2onnx import ( 

10 update_registered_converter, 

11 update_registered_parser) 

12from skl2onnx.common.data_types import guess_tensor_type 

13from skl2onnx.common._apply_operation import apply_identity 

14 

15 

class CustomScorerTransform(BaseEstimator, TransformerMixin):
    """
    Wraps a scoring function into a transformer. Function @see fn
    register_scorers must be called to register the converter
    associated to this transform. It takes two inputs, expected values
    and predicted values and returns a score for each observation.

    The constructor parameters are stored under their own names
    (*name*, *fct*, *kwargs*) so that scikit-learn's ``get_params``,
    ``set_params`` and ``clone`` can introspect the estimator.
    The historical attribute names *name_fct* and *_fct* remain
    available as read/write aliases.
    """

    def __init__(self, name, fct, kwargs):
        """
        @param      name        function name
        @param      fct         python function
        @param      kwargs      parameters function
        """
        BaseEstimator.__init__(self)
        TransformerMixin.__init__(self)
        # scikit-learn's get_params reads getattr(self, <__init__ param>),
        # so every constructor argument must be stored under its own name.
        self.name = name
        self.fct = fct
        self.kwargs = kwargs

    @property
    def name_fct(self):
        "Alias of *name*, kept for backward compatibility."
        return self.name

    @name_fct.setter
    def name_fct(self, value):
        self.name = value

    @property
    def _fct(self):
        "Alias of *fct*, kept for backward compatibility."
        return self.fct

    @_fct.setter
    def _fct(self, value):
        self.fct = value

    def __repr__(self):  # pylint: disable=W0222
        # Not every callable has a __name__ (functools.partial, callable
        # objects), fall back on its repr instead of raising.
        fct_name = getattr(self.fct, '__name__', repr(self.fct))
        return "{}('{}', {}, {})".format(
            self.__class__.__name__, self.name, fct_name, self.kwargs)

40 

41 

def custom_scorer_transform_parser(scope, model, inputs, custom_parsers=None):
    """
    Declares, inside *scope*, the operator and the output variable
    standing for a @see cl CustomScorerTransform.

    :param scope: Scope object receiving the new operator and variable
    :param model: A scikit-learn object (e.g., *OneHotEncoder*
        or *LogisticRegression*), here the wrapped scorer
    :param inputs: A list of exactly two variables
        (expected values, predicted values)
    :param custom_parsers: dictionary
        ``{ type: fct_parser(scope, model, inputs, custom_parsers=None) }``
        overriding the default parsers; not supported, must be None
    :return: the list of output variables of the declared operator,
        passed on to the next stage
    :raises NotImplementedError: if *custom_parsers* is given
    :raises RuntimeError: if *model* is a string or if *inputs*
        does not contain exactly two variables
    """
    if custom_parsers is not None:  # pragma: no cover
        raise NotImplementedError(
            "Case custom_parsers not empty is not implemented yet.")
    if isinstance(model, str):
        raise RuntimeError(  # pragma: no cover
            f"Parameter model must be an object not a string '{model}'.")
    n_inputs = len(inputs)
    if n_inputs != 2:
        raise RuntimeError(  # pragma: no cover
            f"Two inputs expected not {n_inputs}.")

    # The alias must match the one given to update_registered_converter.
    new_operator = scope.declare_local_operator(
        f"Mlprodict{model.__class__.__name__}", model)
    new_operator.inputs = inputs

    score_type = guess_tensor_type(inputs[0].type)
    score_variable = scope.declare_local_variable('scores', score_type)
    new_operator.outputs.append(score_variable)
    return new_operator.outputs

77 

78 

def custom_scorer_transform_shape_calculator(operator):
    """
    Sets the output type of a @see cl CustomScorerTransform operator
    to a deep copy of its first input's type, reshaped to ``[N, 1]``
    where *N* is the first dimension of that input.

    :param operator: operator holding two inputs and one output
    :raises RuntimeError: if the operator does not have exactly two
        inputs and one output
    """
    if len(operator.inputs) != 2:
        raise RuntimeError("Two inputs expected.")  # pragma: no cover
    if len(operator.outputs) != 1:
        raise RuntimeError("One output expected.")  # pragma: no cover

    source_type = operator.inputs[0].type
    n_rows = source_type.shape[0]
    output = operator.outputs[0]
    # Deep copy so that changing the shape does not alter the input type.
    output.type = copy.deepcopy(source_type)
    output.type.shape = [n_rows, 1]

91 

92 

def custom_scorer_transform_converter(scope, operator, container):
    """
    Converts a @see cl CustomScorerTransform by delegating to the
    operator type ``'fct_' + <function name>`` (for instance
    ``'fct_score_cdist_sum'``, the alias used in register_scorers).

    A nested operator of that type is declared on the same inputs.
    Its result variable is then forwarded to this operator's output
    with ``apply_identity``.
    """
    scorer = operator.raw_operator
    nested = scope.declare_local_operator('fct_' + scorer.name_fct)
    nested.raw_operator = scorer
    nested.inputs = operator.inputs

    nested_scores = scope.declare_local_variable(
        'scores', operator.inputs[0].type)
    nested.outputs.append(nested_scores)
    apply_identity(scope, nested_scores.full_name,
                   operator.outputs[0].full_name, container)

108 

109 

def empty_shape_calculator(operator):
    """
    Shape calculator that leaves *operator* unchanged and returns None.
    register_scorers attaches it to the scoring-function operator.
    """

115 

116 

def register_scorers():
    """
    Registers the parser and converter of @see cl CustomScorerTransform
    and the converter of the scoring function *score_cdist_sum*.

    @return     list of the registered transformer classes,
                currently ``[CustomScorerTransform]``
    """
    from .cdist_score import score_cdist_sum, convert_score_cdist_sum

    update_registered_parser(
        CustomScorerTransform, custom_scorer_transform_parser)
    update_registered_converter(
        CustomScorerTransform, 'MlprodictCustomScorerTransform',
        custom_scorer_transform_shape_calculator,
        custom_scorer_transform_converter)

    # The wrapped function gets its own converter under the alias
    # 'fct_<function name>' expected by custom_scorer_transform_converter.
    update_registered_converter(
        score_cdist_sum, 'fct_score_cdist_sum',
        empty_shape_calculator, convert_score_cdist_sum,
        options={'cdist': [None, 'single-node']})

    return [CustomScorerTransform]