Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Implements TextVectorizerTransformer.
4"""
import numpy
import pandas
from scipy import sparse
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.utils.validation import check_is_fitted
class TextVectorizerTransformer(BaseEstimator, TransformerMixin):
    """
    Wraps a text vectorizer such as *TfidfVectorizer* or
    *CountVectorizer* so that it can be used in a pipeline on an
    input holding several text columns: one clone of the vectorizer
    is trained per column and the outputs are stacked horizontally.

    Parameters
    ----------

    estimator: estimator cloned and fitted on every column

    Attributes
    ----------

    estimators_: trained estimators, one per column
    """

    def __init__(self, estimator):
        """
        @param      estimator   *TfidfVectorizer* or *CountVectorizer*
        """
        self.estimator = estimator

    @staticmethod
    def _split_columns(X):
        """
        Returns the columns of *X* as a list, one item per column.

        @param      X       two-dimensional *numpy.ndarray* or
                            *pandas.DataFrame*
        @return             list of columns, in column order
        @raises     TypeError if *X* is neither an array nor a dataframe
        """
        if isinstance(X, pandas.DataFrame):
            # Positional access so that duplicated or non-string column
            # labels do not matter.
            return [X.iloc[:, i] for i in range(X.shape[1])]
        if isinstance(X, numpy.ndarray):
            return [X[:, i] for i in range(X.shape[1])]
        raise TypeError(  # pragma: no cover
            "X must be an array or a dataframe.")

    def fit(self, X, y=None):
        """
        Trains a clone of the estimator on every column of *X*.

        @param      X       two-dimensional *numpy.ndarray* or
                            *pandas.DataFrame*, one text column per
                            estimator
        @param      y       unused, present for pipeline compatibility
        @return             self
        @raises     TypeError if *X* is neither an array nor a dataframe
        """
        self.estimators_ = []
        for col in self._split_columns(X):
            # Every column gets its own independent copy so the
            # estimator given to the constructor is never modified.
            est = clone(self.estimator)
            est.fit(col)
            self.estimators_.append(est)
        return self

    def transform(self, X):
        """
        Applies the trained vectorizers on the columns of *X* and
        stacks the results horizontally.

        @param      X       two-dimensional *numpy.ndarray* or
                            *pandas.DataFrame* with as many columns
                            as the data used in *fit*
        @return             the single result when there is only one
                            column, a dense array if every result is a
                            *numpy.ndarray*, otherwise the output of
                            *scipy.sparse.hstack*
        @raises     NotFittedError if *fit* was not called before
        @raises     ValueError if the number of columns differs from
                            the number of trained estimators
        @raises     TypeError if *X* is neither an array nor a dataframe
        """
        # NotFittedError derives from both ValueError and AttributeError,
        # callers catching the former AttributeError keep working.
        check_is_fitted(self, "estimators_")
        if len(self.estimators_) != X.shape[1]:
            raise ValueError(  # pragma: no cover
                "Unexpected number of columns {}, expecting {}".format(
                    X.shape[1], len(self.estimators_)))
        res = [est.transform(col)
               for est, col in zip(self.estimators_,
                                   self._split_columns(X))]
        if len(res) == 1:
            return res[0]
        # Dense outputs are stacked with numpy, anything else goes
        # through scipy.sparse.hstack.
        if all(isinstance(r, numpy.ndarray) for r in res):
            return numpy.hstack(res)
        return sparse.hstack(res)