Source code for mlinsights.mlmodel.transfer_transformer

"""
Implements a transformer which wraps a predictor
to do transfer learning.


:githublink:`%|py|6`
"""
import inspect
from sklearn.base import BaseEstimator, TransformerMixin
from .sklearn_testing import clone_with_fitted_parameters


[docs]class TransferTransformer(BaseEstimator, TransformerMixin): """ Wraps a predictor or a transformer in a transformer. This model is frozen: it cannot be trained and only computes the predictions. .. index:: transfer learning, frozen model :githublink:`%|py|18` """
[docs] def __init__(self, estimator, method=None, copy_estimator=True, trainable=False): """ :param estimator: estimator to wrap in a transformer, it is cloned with the training data (deep copy) when fitted :param method: if None, guess what method should be called, *transform* for a transformer, *predict_proba* for a classifier, *decision_function* if found, *predict* otherwiser :param copy_estimator: copy the model instead of taking a reference :param trainable: the transfered model must be trained :githublink:`%|py|32` """ TransformerMixin.__init__(self) BaseEstimator.__init__(self) self.estimator = estimator self.copy_estimator = copy_estimator self.trainable = trainable if method is None: if hasattr(estimator, "transform"): method = "transform" elif hasattr(estimator, "predict_proba"): method = "predict_proba" elif hasattr(estimator, "decision_function"): method = "decision_function" elif hasattr(estimator, "predict"): method = "predict" else: raise AttributeError( # pragma: no cover "Cannot find a method transform, predict_proba, decision_function, " "predict in object {}".format(type(estimator))) if not hasattr(estimator, method): raise AttributeError( # pragma: no cover "Cannot find method '{}' in object {}".format( method, type(estimator))) self.method = method
[docs] def fit(self, X=None, y=None, sample_weight=None): """ The function does nothing. :param X: unused :param y: unused :param sample_weight: unused :return: self: returns an instance of self. Fitted attributes: * `estimator_`: already trained estimator :githublink:`%|py|69` """ if self.copy_estimator: self.estimator_ = clone_with_fitted_parameters(self.estimator) from .sklearn_testing import assert_estimator_equal # pylint: disable=C0415 assert_estimator_equal(self.estimator_, self.estimator) else: self.estimator_ = self.estimator if self.trainable: insp = inspect.signature(self.estimator_.fit) pars = insp.parameters if 'y' in pars and 'sample_weight' in pars: self.estimator_.fit(X, y, sample_weight) elif 'y' in pars: self.estimator_.fit(X, y) elif 'sample_weight' in pars: self.estimator_.fit(X, sample_weight=sample_weight) else: self.estimator_.fit(X) return self
[docs] def transform(self, X): """ Runs the predictions. :param X: numpy array or sparse matrix of shape [n_samples,n_features] Training data :return: tranformed *X* :githublink:`%|py|96` """ meth = getattr(self.estimator_, self.method) return meth(X)