Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Implements a transform which modifies the target
4and applies the reverse transformation on the target.
5"""
6import numpy
7from sklearn.exceptions import NotFittedError
8from sklearn.neighbors import NearestNeighbors
9from .sklearn_transform_inv import BaseReciprocalTransformer
class FunctionReciprocalTransformer(BaseReciprocalTransformer):
    """
    The transform is used to apply a function on the target,
    predict, then transform the target back before scoring.
    The transform implements a series of predefined functions:

    .. runpython::
        :showcode:

        import pprint
        from mlinsights.mlmodel.sklearn_transform_inv_fct import FunctionReciprocalTransformer
        pprint.pprint(FunctionReciprocalTransformer.available_fcts())
    """

    @staticmethod
    def available_fcts():
        """
        Returns the dictionary of predefined functions.
        Every key is a function name, every value is a tuple
        ``(function, name of the reciprocal function)``.
        Every reciprocal name is itself a key of the dictionary,
        which lets @see me get_fct_inv build the inverse transform.
        """
        return {
            'log': (numpy.log, 'exp'),
            'exp': (numpy.exp, 'log'),
            'log(1+x)': (lambda x: numpy.log(x + 1), 'exp(x)-1'),
            'log1p': (numpy.log1p, 'expm1'),
            # log(exp(x) - 1 + 1) == x: the reciprocal is log(1+x), not log
            'exp(x)-1': (lambda x: numpy.exp(x) - 1, 'log(1+x)'),
            'expm1': (numpy.expm1, 'log1p'),
        }

    def __init__(self, fct, fct_inv=None):
        """
        @param      fct         function name (a key of @see me available_fcts)
                                or numerical function
        @param      fct_inv     must be None if *fct* is a function name,
                                reciprocal function otherwise
        @raises     ValueError  if *fct* is an unknown name, if *fct* is a name
                                and *fct_inv* is given, or if *fct* is a
                                function and *fct_inv* is missing
        """
        BaseReciprocalTransformer.__init__(self)
        if isinstance(fct, str):
            if fct_inv is not None:
                raise ValueError(  # pragma: no cover
                    "If fct is a function name, fct_inv must not be specified.")
            opts = self.__class__.available_fcts()
            if fct not in opts:
                raise ValueError(  # pragma: no cover
                    "Unknown fct '{}', it should in {}.".format(
                        fct, list(sorted(opts))))
        else:
            if fct_inv is None:
                raise ValueError(
                    "If fct is callable, fct_inv must be specified.")
        self.fct = fct
        self.fct_inv = fct_inv

    def fit(self, X=None, y=None, sample_weight=None):
        """
        Just defines *fct_* and *fct_inv_*. When *fct* is a name,
        both are taken from @see me available_fcts, in which case
        *fct_inv_* is the name of the reciprocal function.

        @param      X               ignored
        @param      y               ignored
        @param      sample_weight   ignored
        @return                     self
        """
        if callable(self.fct):
            self.fct_ = self.fct
            self.fct_inv_ = self.fct_inv
        else:
            opts = self.__class__.available_fcts()
            self.fct_, self.fct_inv_ = opts[self.fct]
        return self

    def get_fct_inv(self):
        """
        Returns a trained transform which reverses the target
        after a predictor.
        """
        if isinstance(self.fct_inv_, str):
            # predefined function: the reciprocal is looked up by name
            res = FunctionReciprocalTransformer(self.fct_inv_)
        else:
            # user functions: swap the function and its reciprocal
            res = FunctionReciprocalTransformer(self.fct_inv_, self.fct_)
        return res.fit()

    def transform(self, X, y):
        """
        Transforms *X* and *y*.
        Returns transformed *X* and *y*.
        *X* is returned unchanged, *y* becomes ``fct_(y)``.
        If *y* is None, the returned value for *y*
        is None as well.
        """
        if y is None:
            return X, None
        return X, self.fct_(y)
class PermutationReciprocalTransformer(BaseReciprocalTransformer):
    """
    The transform is used to permute targets,
    predict, then permute the target back before scoring.
    nan values remain nan values. Once fitted, the transform
    has attribute ``permutation_`` which keeps
    track of the permutation to apply.
    """

    def __init__(self, random_state=None, closest=False):
        """
        @param      random_state    random state
        @param      closest         if True, finds the closest permuted element
        """
        BaseReciprocalTransformer.__init__(self)
        self.random_state = random_state
        self.closest = closest

    def fit(self, X=None, y=None, sample_weight=None):
        """
        Defines a random permutation over the targets.
        Every distinct non-nan value of *y* is mapped to another
        distinct value of *y*, so ``permutation_`` is a bijection
        of the set of observed targets onto itself.

        @param      X               ignored
        @param      y               targets (numpy array)
        @param      sample_weight   ignored
        @return                     self
        @raises     RuntimeError    if *y* is None
        """
        if y is None:
            raise RuntimeError(  # pragma: no cover
                "targets cannot be empty.")
        num = numpy.issubdtype(y.dtype, numpy.floating)
        # distinct targets in order of first appearance, nan excluded
        seen = {}
        for u in y.ravel():
            if num and numpy.isnan(u):
                continue
            if u not in seen:
                seen[u] = len(seen)
        keys = list(seen)

        lin = numpy.arange(len(keys))
        if self.random_state is None:
            lin = numpy.random.permutation(lin)
        else:
            rs = numpy.random.RandomState(  # pylint: disable=E1101
                self.random_state)  # pylint: disable=E1101
            lin = rs.permutation(lin)

        # map every target onto another target of the same set
        # (not onto an integer index, which would break string labels)
        self.permutation_ = {u: keys[lin[i]] for i, u in enumerate(keys)}
        # the nearest-neighbor cache depends on the permutation keys,
        # it must be rebuilt after a new fit
        for attr in ('knn_', 'knn_perm_'):
            if hasattr(self, attr):
                delattr(self, attr)
        return self

    def _check_is_fitted(self):
        """
        Raises NotFittedError if ``permutation_`` is not defined.
        """
        if not hasattr(self, 'permutation_'):
            raise NotFittedError(  # pragma: no cover
                "This instance {} is not fitted yet. Call 'fit' with "
                "appropriate arguments before using this method.".format(
                    type(self)))

    def get_fct_inv(self):
        """
        Returns a trained transform which reverses the target
        after a predictor.
        """
        self._check_is_fitted()
        res = PermutationReciprocalTransformer(
            self.random_state, closest=self.closest)
        res.permutation_ = {v: k for k, v in self.permutation_.items()}
        return res

    def _find_closest(self, cl):
        """
        Returns the key of ``permutation_`` closest to *cl*
        as a python float or int. The nearest-neighbor model is
        built on first use and cached in ``knn_`` and ``knn_perm_``.

        @raises     NotImplementedError if the keys are neither
                                        floats nor integers
        """
        if not hasattr(self, 'knn_'):
            self.knn_ = NearestNeighbors(n_neighbors=1, algorithm='kd_tree')
            self.knn_perm_ = numpy.array(list(self.permutation_))
            self.knn_perm_ = self.knn_perm_.reshape((len(self.knn_perm_), 1))
            self.knn_.fit(self.knn_perm_)
        ind = self.knn_.kneighbors([[cl]], return_distance=False)
        # ind has shape (1, 1): extract a true scalar, calling float()
        # or int() on a 2D array is deprecated in numpy
        res = self.knn_perm_[ind[0, 0], 0]
        if numpy.issubdtype(self.knn_perm_.dtype, numpy.floating):
            return float(res)
        if numpy.issubdtype(self.knn_perm_.dtype, numpy.integer):
            return int(res)
        raise NotImplementedError(  # pragma: no cover
            "The function does not work for type {}.".format(
                self.knn_perm_.dtype))

    def transform(self, X, y):
        """
        Transforms *X* and *y*.
        Returns transformed *X* and *y*.
        If *y* is None, the returned value for *y*
        is None as well.
        A 1-D *y* or an array of labels (strings or integers) is
        permuted element-wise, a 2-D float *y* (probabilities or
        raw scores) has its columns permuted.
        """
        if y is None:
            return X, None
        self._check_is_fitted()
        # numpy.str was removed in numpy 1.24, check the dtype kind instead
        is_label = (y.dtype.kind in ('U', 'S') or
                    y.dtype in (numpy.int32, numpy.int64))
        if len(y.shape) == 1 or is_label:
            # permutes classes
            yp = y.copy().ravel()
            num = numpy.issubdtype(y.dtype, numpy.floating)
            for i, value in enumerate(yp):
                if num and numpy.isnan(value):
                    continue
                if value not in self.permutation_:
                    if self.closest:
                        key = self._find_closest(value)
                    else:
                        raise RuntimeError("Unable to find key '{}' in {}.".format(
                            value, list(sorted(self.permutation_))))
                else:
                    key = value
                yp[i] = self.permutation_[key]
            return X, yp.reshape(y.shape)

        # y is probababilies or raw score
        if len(y.shape) != 2:
            raise RuntimeError(
                "yp should be a matrix but has shape {}.".format(y.shape))
        # column of class k moves to the rank of permutation_[k]
        ordered = sorted((v, k) for k, v in self.permutation_.items())
        new_perm = {current: rank for rank, (_, current) in enumerate(ordered)}
        yp = y.copy()
        for i in range(y.shape[1]):
            yp[:, new_perm[i]] = y[:, i]
        return X, yp