Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

"""
@file
@brief Implements TextVectorizerTransformer.
"""

import numpy
import pandas
from scipy import sparse
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.utils.validation import check_is_fitted

9 

10 

class TextVectorizerTransformer(BaseEstimator, TransformerMixin):
    """
    Wraps *TfidfVectorizer* or *CountVectorizer* so that it can be used
    in a pipeline on data holding several text columns: one clone of the
    vectorizer is trained on every column and the outputs are
    concatenated horizontally.

    Parameters
    ----------

    estimator: estimator to fit on every column

    Attributes
    ----------

    estimators_: trained estimators, one per column
    """

    def __init__(self, estimator):
        """
        @param      estimator       *TfidfVectorizer* or *CountVectorizer*
        """
        self.estimator = estimator

    @staticmethod
    def _check_input(X):
        """
        Raises *TypeError* unless *X* is a dataframe or a numpy array.
        It must run before ``X.shape`` is read, otherwise unsupported
        inputs (a list has no ``shape``, a 1-D series has no second
        dimension) fail with an unrelated AttributeError / IndexError.
        """
        if not isinstance(X, (pandas.DataFrame, numpy.ndarray)):
            raise TypeError(  # pragma: no cover
                "X must be an array or a dataframe.")

    @staticmethod
    def _get_column(X, i):
        """Returns column *i* of *X*, a dataframe or a numpy array."""
        if isinstance(X, pandas.DataFrame):
            return X.iloc[:, i]
        return X[:, i]

    def fit(self, X, y=None):
        """
        Trains a clone of *estimator* on every column of *X*.

        @param      X       dataframe or numpy array, one text column
                            per input feature
        @param      y       unused, kept for API compatibility
        @return             self
        @raises TypeError   if *X* is neither a dataframe nor an array
        """
        self._check_input(X)
        estimators = []
        for i in range(X.shape[1]):
            est = clone(self.estimator)
            est.fit(self._get_column(X, i))
            estimators.append(est)
        # Assigned only once every column is fitted so that a failure
        # does not leave a partially trained list behind.
        self.estimators_ = estimators
        return self

    def transform(self, X):
        """
        Applies the vectorizer fitted on each column to the matching
        column of *X* and concatenates the outputs horizontally.

        @param      X       dataframe or numpy array with as many columns
                            as the data given to *fit*
        @return             the single output when *X* has one column,
                            a dense array when every output is a numpy
                            array, a sparse matrix otherwise
        @raises NotFittedError  if *fit* was not called first
        @raises TypeError       if *X* is neither a dataframe nor an array
        @raises ValueError      if the number of columns differs from
                                the one seen by *fit*
        """
        # NotFittedError subclasses AttributeError and ValueError, so
        # callers relying on the former bare AttributeError still work.
        check_is_fitted(self, "estimators_")
        self._check_input(X)
        if len(self.estimators_) != X.shape[1]:
            raise ValueError(  # pragma: no cover
                "Unexpected number of columns {}, expecting {}".format(
                    X.shape[1], len(self.estimators_)))
        res = [est.transform(self._get_column(X, i))
               for i, est in enumerate(self.estimators_)]
        if len(res) == 1:
            return res[0]
        if all(isinstance(r, numpy.ndarray) for r in res):
            return numpy.hstack(res)
        return sparse.hstack(res)