Hot-keys on this page:
r, m, x, p: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1# -*- encoding: utf-8 -*-
2"""
3@file
4@brief Registers new converters.
5"""
6import copy
7from sklearn.base import BaseEstimator, TransformerMixin
8from sklearn import __all__ as sklearn__all__, __version__ as sklearn_version
9from skl2onnx import (
10 update_registered_converter,
11 update_registered_parser)
12from skl2onnx.common.data_types import guess_tensor_type
13from skl2onnx.common._apply_operation import apply_identity
class CustomScorerTransform(BaseEstimator, TransformerMixin):
    """
    Wraps a scoring function into a transformer. Function @see fn
    register_scorers must be called to register the converter
    associated to this transform. It takes two inputs, expected values
    and predicted values and returns a score for each observation.

    The constructor parameters are also exposed under their own names
    (``name``, ``fct``, ``kwargs``). scikit-learn's ``get_params``,
    ``set_params`` and ``clone`` look up attributes named after the
    ``__init__`` parameters. ``name_fct`` and ``_fct`` remain the
    underlying storage, which the ONNX converter reads.
    """

    def __init__(self, name, fct, kwargs):
        """
        @param      name        function name
        @param      fct         python function
        @param      kwargs      parameters function
        """
        BaseEstimator.__init__(self)
        TransformerMixin.__init__(self)
        self.name_fct = name
        self._fct = fct
        self.kwargs = kwargs

    @property
    def name(self):
        """Function name (``__init__`` parameter *name*), alias of ``name_fct``."""
        return self.name_fct

    @name.setter
    def name(self, value):
        # Keeps set_params(name=...) consistent with the converter's view.
        self.name_fct = value

    @property
    def fct(self):
        """Python function (``__init__`` parameter *fct*), alias of ``_fct``."""
        return self._fct

    @fct.setter
    def fct(self, value):
        self._fct = value

    def __repr__(self):  # pylint: disable=W0222
        return "{}('{}', {}, {})".format(
            self.__class__.__name__, self.name_fct,
            self._fct.__name__, self.kwargs)
def custom_scorer_transform_parser(scope, model, inputs, custom_parsers=None):
    """
    Declares the operator and the output variable used to convert
    a @see cl CustomScorerTransform.

    :param scope: Scope object
    :param model: A scikit-learn object (e.g., *OneHotEncoder*
        or *LogisticRegression*)
    :param inputs: A list of exactly two variables (expected values,
        then predicted values)
    :param custom_parsers: parsers determine which outputs are expected
        for a particular task; default parsers are defined for
        classifiers, regressors and pipelines but can be overwritten.
        *custom_parsers* is a dictionary
        ``{ type: fct_parser(scope, model, inputs, custom_parsers=None) }``.
        Only ``None`` is supported here.
    :return: A list of output variables which will be passed to the next
        stage
    :raises NotImplementedError: if *custom_parsers* is not None
    :raises RuntimeError: if *model* is a string or *inputs* does not
        contain exactly two variables
    """
    if custom_parsers is not None:  # pragma: no cover
        raise NotImplementedError(
            "Case custom_parsers not empty is not implemented yet.")
    if isinstance(model, str):
        raise RuntimeError(  # pragma: no cover
            "Parameter model must be an object not a "
            "string '{0}'.".format(model))
    n_inputs = len(inputs)
    if n_inputs != 2:
        raise RuntimeError(  # pragma: no cover
            "Two inputs expected not {}.".format(n_inputs))

    # The operator type is 'Mlprodict' + class name, the same alias
    # register_scorers uses for 'MlprodictCustomScorerTransform'.
    operator_type = 'Mlprodict{}'.format(model.__class__.__name__)
    local_op = scope.declare_local_operator(operator_type, model)
    local_op.inputs = inputs

    # The output tensor type is guessed from the first input's type.
    score_type = guess_tensor_type(inputs[0].type)
    local_op.outputs.append(
        scope.declare_local_variable('scores', score_type))
    return local_op.outputs
def custom_scorer_transform_shape_calculator(operator):
    """
    Infers the output type of a @see cl CustomScorerTransform.

    The output type is a deep copy of the first input's type, with shape
    ``[N, 1]``, where ``N`` is the first dimension of the first input.

    :param operator: operator with exactly two inputs and one output
    :raises RuntimeError: if the operator does not have two inputs
        and one output
    """
    inputs, outputs = operator.inputs, operator.outputs
    if len(inputs) != 2:
        raise RuntimeError("Two inputs expected.")  # pragma: no cover
    if len(outputs) != 1:
        raise RuntimeError("One output expected.")  # pragma: no cover

    source_type = inputs[0].type
    n_rows = source_type.shape[0]
    # Deep copy so the output type never aliases the input type object.
    outputs[0].type = copy.deepcopy(source_type)
    outputs[0].type.shape = [n_rows, 1]
def custom_scorer_transform_converter(scope, operator, container):
    """
    Converts a @see cl CustomScorerTransform by declaring a sub-operator
    whose type is ``'fct_' + name_fct``. A converter must be registered
    under that alias; register_scorers registers ``'fct_score_cdist_sum'``.
    The sub-operator's result is then copied into this operator's output
    with an Identity node.
    """
    wrapped = operator.raw_operator
    alias = 'fct_' + wrapped.name_fct
    sub_op = scope.declare_local_operator(alias)
    sub_op.raw_operator = wrapped
    sub_op.inputs = operator.inputs

    # NOTE(review): the intermediate variable reuses (does not copy) the
    # first input's type object; confirm this aliasing is intended.
    intermediate = scope.declare_local_variable(
        'scores', operator.inputs[0].type)
    sub_op.outputs.append(intermediate)
    apply_identity(scope, intermediate.full_name,
                   operator.outputs[0].full_name, container)
def empty_shape_calculator(operator):
    """
    Shape calculator that leaves *operator* untouched and returns None.

    register_scorers uses it for the ``'fct_score_cdist_sum'`` converter.
    """
def register_scorers():
    """
    Registers operators for @see cl CustomScorerTransform.

    Registers the parser and converter for @see cl CustomScorerTransform
    under the alias ``'MlprodictCustomScorerTransform'``. It also registers
    the converter for ``score_cdist_sum`` under ``'fct_score_cdist_sum'``,
    with option ``cdist`` in ``[None, 'single-node']``.

    :return: list containing only @see cl CustomScorerTransform
        (``score_cdist_sum`` is registered but not listed)
    """
    from .cdist_score import score_cdist_sum, convert_score_cdist_sum

    update_registered_parser(
        CustomScorerTransform, custom_scorer_transform_parser)
    update_registered_converter(
        CustomScorerTransform, 'MlprodictCustomScorerTransform',
        custom_scorer_transform_shape_calculator,
        custom_scorer_transform_converter)
    registered = [CustomScorerTransform]

    cdist_options = {'cdist': [None, 'single-node']}
    update_registered_converter(
        score_cdist_sum, 'fct_score_cdist_sum',
        empty_shape_calculator, convert_score_cdist_sum,
        options=cdist_options)

    return registered