Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Implements a transformer which wraps a predictor 

4to do transfer learning. 

5""" 

6import inspect 

7from sklearn.base import BaseEstimator, TransformerMixin 

8from .sklearn_testing import clone_with_fitted_parameters 

9 

10 

class TransferTransformer(BaseEstimator, TransformerMixin):
    """
    Wraps a predictor or a transformer in a transformer.
    By default the wrapped model is frozen: it is not trained
    again and only computes its predictions. When *trainable*
    is True, the wrapped model is trained whenever this
    transformer is fitted.

    .. index:: transfer learning, frozen model
    """

    def __init__(self, estimator, method=None, copy_estimator=True,
                 trainable=False):
        """
        :param estimator: estimator to wrap in a transformer; when
            *copy_estimator* is True, it is cloned together with its
            fitted parameters each time this transformer is fitted
        :param method: name of the estimator method called by
            :meth:`transform`; if None, the first method the estimator
            has among *transform*, *predict_proba*, *decision_function*
            and *predict* (checked in that order) is used
        :param copy_estimator: copy the model instead of taking a reference
        :param trainable: if True, the transferred model is trained
            when this transformer is fitted
        :raises AttributeError: if *method* is None and none of the
            candidate methods exist, or if *method* is not an attribute
            of *estimator*
        """
        TransformerMixin.__init__(self)
        BaseEstimator.__init__(self)
        self.estimator = estimator
        self.copy_estimator = copy_estimator
        self.trainable = trainable
        if method is None:
            # Order of preference: a transformer's own output first, then
            # probabilities, then scores, then raw predictions.
            candidates = ("transform", "predict_proba",
                          "decision_function", "predict")
            method = next(
                (name for name in candidates if hasattr(estimator, name)),
                None)
            if method is None:
                raise AttributeError(  # pragma: no cover
                    "Cannot find a method transform, predict_proba, decision_function, "
                    "predict in object {}".format(type(estimator)))
        if not hasattr(estimator, method):
            raise AttributeError(  # pragma: no cover
                "Cannot find method '{}' in object {}".format(
                    method, type(estimator)))
        self.method = method

    def fit(self, X=None, y=None, sample_weight=None):
        """
        Sets the wrapped estimator and, if *trainable* is True, trains it.

        If *trainable* is False, *X*, *y* and *sample_weight* are ignored.
        If *copy_estimator* is False and *trainable* is True, the
        estimator given to the constructor is trained in place.

        :param X: training data, only used if *trainable* is True
        :param y: training targets, forwarded only if *trainable* is True
            and the estimator's ``fit`` has a ``y`` parameter
        :param sample_weight: sample weights, forwarded by keyword only if
            *trainable* is True and the estimator's ``fit`` has a
            ``sample_weight`` parameter
        :return: self

        Fitted attributes:

        * `estimator_`: the wrapped estimator; a copy of *estimator*
          (fitted parameters included) if *copy_estimator* is True,
          *estimator* itself otherwise
        """
        if self.copy_estimator:
            self.estimator_ = clone_with_fitted_parameters(self.estimator)
            # Sanity check: the copy must match the original estimator,
            # fitted parameters included.
            from .sklearn_testing import assert_estimator_equal  # pylint: disable=C0415
            assert_estimator_equal(self.estimator_, self.estimator)
        else:
            self.estimator_ = self.estimator
        if self.trainable:
            params = inspect.signature(self.estimator_.fit).parameters
            args = [X]
            kwargs = {}
            if 'y' in params:
                args.append(y)
            if 'sample_weight' in params:
                # Passed by keyword: sample_weight is not guaranteed to be
                # the third positional parameter of the wrapped fit.
                kwargs['sample_weight'] = sample_weight
            self.estimator_.fit(*args, **kwargs)
        return self

    def transform(self, X):
        """
        Runs the predictions: calls the method named by *method*
        on the wrapped estimator.

        :param X: numpy array or sparse matrix of shape [n_samples, n_features],
            data to transform
        :return: transformed *X*, the output of ``estimator_.<method>(X)``
        """
        meth = getattr(self.estimator_, self.method)
        return meth(X)