Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
"""
@file
@brief Implements a way to get close examples based
on the output of a machine learned model.
"""
6import json
7import zipfile
8import pandas
9import numpy
10from sklearn.neighbors import NearestNeighbors
11from pandas_streaming.df import to_zip, read_zip
12from ..helpers.parameters import format_function_call
class SearchEngineVectors:
    """
    Implements a kind of local search engine which
    looks for similar results assuming they are vectors.
    The class is using
    :epkg:`sklearn:neighborsNearestNeighbors` to find
    the nearest neighbors of a vector and follows
    the same API.
    The class populates members:

    * ``pknn``: keyword parameters given to the constructor and
      forwarded to :epkg:`sklearn:neighborsNearestNeighbors`
    * ``features_``: vectors used to compute the neighbors
    * ``knn_``: the fitted :epkg:`sklearn:neighborsNearestNeighbors`
    * ``metadata_``: metadata, can be None
    """

    def __init__(self, **pknn):
        """
        @param      pknn        list of parameters, see :epkg:`sklearn:neighborsNearestNeighbors`
        """
        self.pknn = pknn

    def __repr__(self):
        """
        Returns the class name followed by the constructor parameters.
        """
        return format_function_call(self.__class__.__name__, self.pknn)

    def _is_iterable(self, data):
        """
        Tells if an object is an iterator or not.
        Lists, tuples, dataframes and arrays can be iterated on
        but are deliberately not considered as iterators.

        @param      data        any object
        @return                 boolean
        """
        try:
            iter(data)
        except TypeError:
            return False
        return not isinstance(data, (list, tuple, pandas.DataFrame, numpy.ndarray))

    def _prepare_fit(self, data=None, features=None, metadata=None, transform=None):
        """
        Stores data in the class itself
        (members ``features_`` and ``metadata_``).

        @param      data        a :epkg:`dataframe`, an iterator on tuples
                                *(vector, dictionary)*, or None if
                                the features and the metadata
                                are specified with an array and a
                                dictionary
        @param      features    features columns or an array
        @param      metadata    data
        @param      transform   transform each vector before using it

        *transform* is a function whose signature::

            def transform(vec, many):
                # Many tells is the functions receives many vectors
                # or just one (many=False).

        Function *transform* is only called when *data* is an iterator,
        once per vector with ``many=False``.

        The method raises *TypeError* or *ValueError* if the
        arguments are inconsistent or if the iterator is empty.
        """
        if self._is_iterable(data):
            # _is_iterable(None) is False, data cannot be None in this branch.
            if features is not None:
                raise ValueError(  # pragma: no cover
                    "iterator is True, features must be None.")
            if metadata is not None:
                raise ValueError(  # pragma: no cover
                    "iterator is True, metadata must be None.")
            metas = []
            arrays = []
            for row in data:
                if not isinstance(row, tuple):
                    raise TypeError(  # pragma: no cover
                        'data must be an iterator on tuple')
                if len(row) != 2:
                    raise ValueError(  # pragma: no cover
                        'data must be an iterator on tuple on two elements')
                arr, meta = row
                if not isinstance(meta, dict):
                    raise TypeError(  # pragma: no cover
                        'Second element of the tuple must be a dictionary')
                metas.append(meta)
                tradd = arr if transform is None else transform(arr, False)
                if not isinstance(tradd, numpy.ndarray):
                    if transform is None:
                        raise TypeError(  # pragma: no cover
                            "feature should be of type numpy.array not {}".format(type(tradd)))
                    raise TypeError(  # pragma: no cover
                        "output of method transform ({}) should be of type numpy.array not {}".format(
                            transform, type(tradd)))
                arrays.append(tradd)
            if not arrays:
                # numpy.vstack fails with an obscure message on an empty list.
                raise ValueError(
                    "data is an empty iterator, at least one vector is needed.")
            self.features_ = numpy.vstack(arrays)
            self.metadata_ = pandas.DataFrame(metas)
        elif data is None:
            if not isinstance(features, numpy.ndarray):
                raise TypeError(  # pragma: no cover
                    "features must be an array if data is None")
            self.features_ = features
            self.metadata_ = metadata
        else:
            if not isinstance(data, pandas.DataFrame):
                raise ValueError(  # pragma: no cover
                    "data should be a dataframe")
            self.features_ = data[features]
            self.metadata_ = data[metadata] if metadata else None

    def fit(self, data=None, features=None, metadata=None):
        """
        Every vector comes with a list of metadata.

        @param      data        a dataframe or None if the
                                the features and the metadata
                                are specified with an array and a
                                dictionary
        @param      features    features columns or an array
        @param      metadata    data
        @return                 self
        """
        self._prepare_fit(data=data, features=features, metadata=metadata)
        return self._fit_knn()

    def _fit_knn(self):
        """
        Fits the nearest neighbors on ``features_``
        and stores the estimator in ``knn_``.

        @return                 self
        """
        self.knn_ = NearestNeighbors(**self.pknn)
        self.knn_.fit(self.features_)
        return self

    def _first_pass(self, X, n_neighbors=None):
        """
        Finds the closest *n_neighbors*.

        @param      X               features, a single vector
                                    (a flat list or an array with one row)
        @param      n_neighbors     number of neighbors to get (default is the value passed to the constructor)
        @return                     *dist*, *ind*

        *dist* is an array representing the lengths to points,
        *ind* contains the indices of the nearest points in the population matrix.
        Both are flattened.
        """
        if isinstance(X, list):
            if len(X) == 0 or isinstance(X[0], (list, tuple)):
                raise TypeError(  # pragma: no cover
                    "X must be a list or a vector (1)")
            # The estimator expects a 2D input: one row holding the vector.
            X = [X]
        if isinstance(X, numpy.ndarray) and (len(X.shape) > 1 and X.shape[0] != 1):
            raise TypeError(  # pragma: no cover
                "X must be a list or a vector (2)")
        dist, ind = self.knn_.kneighbors(
            X, n_neighbors=n_neighbors, return_distance=True)
        return dist.ravel(), ind.ravel()

    def _second_pass(self, X, dist, ind):
        """
        Reorders the closest *n_neighbors*.
        This implementation returns *dist* and *ind* unchanged.

        @param      X           features
        @param      dist        array representing the lengths to points
        @param      ind         indices of the nearest points in the population matrix
        @return                 *score*, *ind*

        *score* is an array representing the lengths to points,
        *ind* contains the indices of the nearest points in the population matrix.
        """
        return dist, ind

    def kneighbors(self, X, n_neighbors=None):
        """
        Searches for neighbors close to *X*.

        @param      X               features
        @param      n_neighbors     number of neighbors to get (default is the value passed to the constructor)
        @return                     score, ind, meta

        *score* is an array representing the lengths to points,
        *ind* contains the indices of the nearest points in the population matrix,
        *meta* is the metadata (None if no metadata was stored)
        """
        dist, ind = self._first_pass(X, n_neighbors=n_neighbors)
        score, ind = self._second_pass(X, dist, ind)
        if self.metadata_ is None:
            rmeta = None
        elif hasattr(self.metadata_, 'iloc'):
            # dataframe: positional selection of the rows
            rmeta = self.metadata_.iloc[ind, :]
        elif len(self.metadata_.shape) == 1:
            rmeta = self.metadata_[ind]
        else:
            rmeta = self.metadata_[ind, :]
        return score, ind, rmeta

    def to_zip(self, zipfilename, **kwargs):
        """
        Saves the features and the metadata into a zipfile.
        The function does not save the *k-nn*, only its parameters.

        @param      zipfilename     a :epkg:`*py:zipfile:ZipFile` or a filename
        @param      kwargs          parameters for :epkg:`pandas:to_csv` (for the metadata),
                                    *index* is set to False unless specified
        @return                     zipfilename

        The function relies on function
        `to_zip <http://www.xavierdupre.fr/app/pandas_streaming/helpsphinx/pandas_streaming/df/
        dataframe_io.html#pandas_streaming.df.dataframe_io.to_zip>`_.
        It only works for :epkg:`Python` 3.6+.
        """
        if isinstance(zipfilename, str):
            zf = zipfile.ZipFile(zipfilename, 'w')
            close = True
        else:
            zf = zipfilename
            close = False
        # Membership test: an explicit *index* given by the caller is kept.
        if 'index' not in kwargs:
            kwargs['index'] = False
        try:
            to_zip(self.features_, zf, 'SearchEngineVectors-features.npy')
            to_zip(self.metadata_, zf, 'SearchEngineVectors-metadata.csv', **kwargs)
            js = json.dumps(self.pknn)
            zf.writestr('SearchEngineVectors-knn.json', js)
        finally:
            # Only closes the archive this method opened itself.
            if close:
                zf.close()
        return zipfilename

    @staticmethod
    def read_zip(zipfilename, **kwargs):
        """
        Restore the features, the metadata to a @see cl SearchEngineVectors.
        The nearest neighbors are fitted again on the restored features.

        @param      zipfilename     a :epkg:`*py:zipfile:ZipFile` or a filename
        @param      kwargs          parameters for :epkg:`pandas:read_csv`
        @return                     @see cl SearchEngineVectors

        It only works for :epkg:`Python` 3.6+.
        """
        if isinstance(zipfilename, str):
            zf = zipfile.ZipFile(zipfilename, 'r')
            close = True
        else:
            zf = zipfilename
            close = False
        try:
            feat = read_zip(zf, 'SearchEngineVectors-features.npy')
            meta = read_zip(zf, 'SearchEngineVectors-metadata.csv', **kwargs)
            js = zf.read('SearchEngineVectors-knn.json')
        finally:
            # Only closes the archive this method opened itself.
            if close:
                zf.close()
        knn = json.loads(js)

        obj = SearchEngineVectors(**knn)
        obj.fit(features=feat, metadata=meta)
        return obj