Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Implements a way to get close examples based
4on the output of a machine learned model.
5"""
6import numpy
7from .search_engine_predictions import SearchEnginePredictions
class SearchEnginePredictionImages(SearchEnginePredictions):
    """
    Extends class @see cl SearchEnginePredictions.
    Vectors are coming from images. The metadata must contain
    information about path names. We assume all images can be held
    in memory. An example can be found in notebook
    :ref:`searchimageskerasrst` or :ref:`searchimagestorchrst`.
    Another example can be found there:
    `search_images_dogcat.py
    <https://github.com/sdpython/ensae_projects/blob/master/src/
    ensae_projects/restapi/search_images_dogcat.py>`_.
    """

    def _prepare_fit(self, data=None, features=None, metadata=None,
                     transform=None, n=None, fLOG=None):
        """
        Stores data in the class itself.

        @param      data        source of images, dispatched on its type name:
                                an object whose type name contains ``torch``
                                (wrapped into a ``DataLoader``, it must expose
                                ``samples``) or one whose type name contains
                                ``keras`` (an iterator with ``batch_size == 1``
                                and an attribute ``filenames``)
        @param      features    not used by this implementation (presumably kept
                                to match the signature of the parent class)
        @param      metadata    not used by this implementation (see *features*)
        @param      transform   transform each vector before using it
        @param      n           takes *n* images (or ``len(data)``)
        @param      fLOG        logging function

        Raises ``TypeError`` if *data* is neither a torch nor a keras object,
        ``NotImplementedError`` or ``ValueError`` if the keras iterator
        is not supported.
        """
        if "torch" in str(type(data)):
            self.module_ = "torch"
            from torch.utils.data import DataLoader  # pylint: disable=E0401,C0415,E0611
            dataloader = DataLoader(
                data, batch_size=1, shuffle=False, num_workers=0)
            # Every batch is paired with the matching entry of data.samples,
            # the first item of that entry is later used as the image name.
            self.iter_images_ = iter_images = iter(
                zip(dataloader, data.samples))
            if n is None:
                n = len(data)
        elif "keras" in str(type(data)):  # pragma: no cover
            self.module_ = "keras"
            iter_images = data
            # We delay the import as keras backend is not necessarily installed.
            from keras.preprocessing.image import Iterator  # pylint: disable=E0401,C0415,E0611
            from keras_preprocessing.image import DirectoryIterator, NumpyArrayIterator  # pylint: disable=E0401,C0415
            if not isinstance(iter_images, (Iterator, DirectoryIterator, NumpyArrayIterator)):
                raise NotImplementedError(  # pragma: no cover
                    "iter_images must be a keras Iterator. No option implemented for type {0}."
                    "".format(type(iter_images)))
            if iter_images.batch_size != 1:
                raise ValueError(  # pragma: no cover
                    "batch_size must be 1 not {0}".format(
                        iter_images.batch_size))
            self.iter_images_ = iter_images
            if n is None:
                n = len(iter_images)
            if not hasattr(iter_images, "filenames"):
                raise NotImplementedError(  # pragma: no cover
                    "Iterator does not iterate on images but numpy arrays (not implemented).")
        else:
            raise TypeError(  # pragma: no cover
                "Unexpected data type {0}.".format(type(data)))

        def get_current_index(flow):
            "get current index"
            # NOTE(review): presumably the position of the batch which was just
            # produced (batch_index is assumed to be already incremented)
            # -- confirm against the keras Iterator implementation.
            return flow.index_array[(flow.batch_index + flow.n - 1) % flow.n]

        def iterator_feature_meta():
            "iterators on features and metadata"
            if hasattr(iter_images, 'filenames'):
                # keras: the name comes from the iterator itself
                def split(item):
                    return item, iter_images.filenames[get_current_index(iter_images)]
            else:
                # torch: items are pairs (batch, sample)
                def split(item):
                    return item[0], item[1][0]

            # range(n) comes first so that zip does not pull an extra
            # element from iter_images once n images were produced.
            for i, it in zip(range(n), iter_images):
                im, name = split(it)
                if not isinstance(name, str):
                    raise TypeError(  # pragma: no cover
                        "name should be a string, not {0}".format(type(name)))
                yield im[0], dict(name=name, i=i)
                if fLOG and i % 10000 == 0:
                    fLOG(
                        '[SearchEnginePredictionImages.fit] i={}/{} - {}'.format(i, n, name))

        super()._prepare_fit(data=iterator_feature_meta(), transform=transform)

    def fit(self, iter_images, n=None, fLOG=None):  # pylint: disable=W0237
        """
        Processes images through the model and fits a *k-nn*.

        @param      iter_images `Iterator <https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py#L719>`_
                                or a torch dataset, see method *_prepare_fit*
        @param      n           takes *n* images (or ``len(iter_images)``)
        @param      fLOG        logging function
        @return                 what ``self._fit_knn()`` returns
        """
        self._prepare_fit(data=iter_images, transform=self.fct, n=n, fLOG=fLOG)
        return self._fit_knn()

    def kneighbors(self, iter_images, n_neighbors=None):  # pylint: disable=W0237
        """
        Searches for neighbors close to the first image
        returned by *iter_images*. It returns the neighbors
        only for the first image.

        @param      iter_images `Iterator <https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py#L719>`_,
                                a torch object, a numpy array (torch only,
                                indexed with three dimensions, a leading
                                dimension is added) or a list of such objects
        @param      n_neighbors number of neighbors, forwarded to the parent class
        @return                 score, ind, meta

        *score* is an array representing the lengths to points,
        *ind* contains the indices of the nearest points in the population matrix,
        *meta* is the metadata.

        With a list, the method is called on every element and the
        three results are stacked with ``numpy.vstack``.
        Raises ``RuntimeError`` if the input does not match the module
        used to train the *k-nn*, ``TypeError`` for an unexpected type.
        """
        if isinstance(iter_images, numpy.ndarray):
            if self.module_ == "keras":  # pragma: no cover
                raise NotImplementedError("Not yet implemented for Keras.")
            if self.module_ == "torch":
                from torch import from_numpy  # pylint: disable=E0611,E0401,C0415
                X = from_numpy(iter_images[numpy.newaxis, :, :, :])
                return super().kneighbors(X, n_neighbors=n_neighbors)
            raise RuntimeError(  # pragma: no cover
                "Unknown module '{0}'.".format(self.module_))

        # The dispatch relies on the type name, not on the string
        # representation of the value: the content of the object
        # (paths, list items...) must not influence the chosen branch.
        type_name = str(type(iter_images))

        if "keras" in type_name:  # pragma: no cover
            if self.module_ != "keras":
                raise RuntimeError(  # pragma: no cover
                    "Keras object but {0} was used to train the KNN.".format(self.module_))
            # We delay the import as keras backend is not necessarily installed.
            # keras, it expects an iterator.
            from keras.preprocessing.image import Iterator  # pylint: disable=E0401,C0415,E0611
            from keras_preprocessing.image import DirectoryIterator, NumpyArrayIterator  # pylint: disable=E0401,C0415,E0611
            if not isinstance(iter_images, (Iterator, DirectoryIterator, NumpyArrayIterator)):
                raise NotImplementedError(  # pragma: no cover
                    "iter_images must be a keras Iterator. No option implemented for type {0}.".format(type(iter_images)))
            if iter_images.batch_size != 1:
                raise ValueError(  # pragma: no cover
                    "batch_size must be 1 not {0}".format(
                        iter_images.batch_size))
            # Only the first batch is used.
            for img in iter_images:
                X = img[0]
                break
            return super().kneighbors(X, n_neighbors=n_neighbors)

        if "torch" in type_name:
            if self.module_ != "torch":
                raise RuntimeError(  # pragma: no cover
                    "Torch object but {0} was used to train the KNN.".format(self.module_))
            # torch: it expects a tensor
            return super().kneighbors(iter_images, n_neighbors=n_neighbors)

        if isinstance(iter_images, list):
            results = [self.kneighbors(it, n_neighbors=n_neighbors)
                       for it in iter_images]
            return (numpy.vstack([r[0] for r in results]),
                    numpy.vstack([r[1] for r in results]),
                    numpy.vstack([r[2] for r in results]))

        raise TypeError(  # pragma: no cover
            "Unexpected type {0} in SearchEnginePredictionImages.kneighbors".format(
                type(iter_images)))