Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- coding: utf-8 -*-
2"""
3@file
4@brief Builds a nearest neighbors model on images in order to find close images.
5"""
6import os
7import numpy
8from PIL.Image import Image
9from sklearn.neighbors import NearestNeighbors
10from .image_helper import img2gray, enumerate_image_class, read_image, image_zoom
class ImageNearestNeighbors(NearestNeighbors):
    """
    Builds a model on the top of :epkg:`NearestNeighbors`
    in order to find close images.

    Every image is preprocessed (see ``transform``), zoomed to
    ``image_size`` and flattened into one row of the feature matrix
    handed to :epkg:`NearestNeighbors`.
    """

    def __init__(self, transform='gray', size=(10, 10), **kwargs):
        """
        @param      transform   preprocessing applied to every image,
                                ``'gray'`` or ``None`` (no preprocessing)
        @param      size        every image is zoomed to keep the same dimension
        @param      kwargs      see :epkg:`NearestNeighbors`
        """
        NearestNeighbors.__init__(self, **kwargs)
        self.image_size = size
        self.transform = transform
        # The result is discarded: the call only validates *transform*
        # and raises ValueError early on an unknown value.
        self._get_transform()

    def _get_transform(self):
        """
        Returns the transform function associated with ``self.transform``.

        @return     function taking an image, applying the preprocessing
                    and zooming the result to ``self.image_size``
        @raise      ValueError if ``self.transform`` is neither
                    ``'gray'`` nor ``None``
        """
        if self.transform == "gray":
            pre = img2gray
        elif self.transform is None:
            pre = None
        else:
            raise ValueError(
                "No transform assicated with value '{0}'.".format(self.transform))

        if pre is None:
            return lambda img: image_zoom(img, new_size=self.image_size)
        return lambda img: image_zoom(pre(img), new_size=self.image_size)

    def _folder2matrix(self, folder, fLOG):
        """
        Converts images stored in a folder into a matrix of features.

        @param      folder      folder walked with @see fct enumerate_image_class
        @param      fLOG        logging function or None, called every 1000 images
        @return                 matrix of features (one row per image),
                                list of image names, list of image classes
        """
        transform = self._get_transform()
        imgs = []
        subs = []
        stack = []
        for i, (name, sub) in enumerate(enumerate_image_class(folder, abspath=False)):
            if fLOG is not None and i % 1000 == 0:
                fLOG("[ImageNearestNeighbors] processing image {0}: "
                     "'{1}' - class '{2}'".format(i, name, sub))
            # Names are stored with forward slashes whatever the platform.
            imgs.append(name.replace("\\", "/"))
            subs.append(sub)
            img = read_image(os.path.join(folder, name))
            trimg = transform(img)
            stack.append(numpy.array(trimg).ravel())
        X = numpy.vstack(stack)
        return X, imgs, subs

    def _imglist2matrix(self, list_of_images, fLOG):
        """
        Converts a list of images into a matrix of features.

        @param      list_of_images  list of items, each item is an image,
                                    a filename, or a tuple
                                    *(image or filename, class)*
        @param      fLOG            logging function or None,
                                    called every 1000 images
        @return                     matrix of features (one row per image),
                                    list of image names (None for an image
                                    given as an object), list of image classes
                                    (None when no class was given)
        """
        transform = self._get_transform()
        imgs = []
        subs = []
        stack = []
        for i, name in enumerate(list_of_images):
            if isinstance(name, tuple):
                name, sub = name
            else:
                sub = None
            if fLOG is not None and i % 1000 == 0:
                # Logs *name*: the loaded image is not available yet at this
                # point (formatting it raised UnboundLocalError at i == 0).
                fLOG("[ImageNearestNeighbors] processing image {0}: "
                     "'{1}' - class '{2}'".format(i, name, sub))
            if isinstance(name, Image):
                imgs.append(None)
                img = name
            else:
                imgs.append(name.replace("\\", "/"))
                img = read_image(name)
            subs.append(sub)
            trimg = transform(img)
            stack.append(numpy.array(trimg).ravel())
        X = numpy.vstack(stack)
        return X, imgs, subs

    def fit(self, X, y=None, fLOG=None):  # pylint: disable=W0221
        """
        Fits the model. *X* can be a folder.

        @param      X       matrix, str for a subfolder of images,
                            list of images, list of filenames or
                            list of tuples *(filename, class)*
        @param      y       unused
        @param      fLOG    logging function
        @return             self

        If *X* is a folder, the method relies on function
        @see fct enumerate_image_class. In that case, the method
        also creates attributes:

        * ``image_names_``: all image names
        * ``image_classes_``: subfolder the image belongs to

        Both attributes are also created when *X* is a list of
        filenames or a list of tuples.
        """
        if isinstance(X, str):
            if not os.path.exists(X):
                raise FileNotFoundError("Folder '{0}' not found.".format(X))
            X, imgs, subs = self._folder2matrix(X, fLOG)
            self.image_names_ = imgs  # pylint: disable=W0201
            self.image_classes_ = subs  # pylint: disable=W0201

        elif isinstance(X, list):
            if isinstance(X[0], Image):
                transform = self._get_transform()
                X = numpy.array([numpy.array(transform(img)).ravel()
                                 for img in X])
            elif isinstance(X[0], str):
                # image names
                X, imgs, subs = self._imglist2matrix(X, fLOG)
                self.image_names_ = imgs  # pylint: disable=W0201
                self.image_classes_ = subs  # pylint: disable=W0201
            elif isinstance(X[0], tuple):
                # (image name, class)
                self.image_classes_ = [  # pylint: disable=W0201
                    item[1] for item in X]
                X, imgs, _ = self._imglist2matrix(
                    [item[0] for item in X], fLOG)
                self.image_names_ = imgs  # pylint: disable=W0201
            else:
                raise TypeError(
                    "X should be a list of PIL.Image not {0}".format(type(X[0])))

        super().fit(X, y)
        return self

    def _private_kn(self, method, X, *args, fLOG=None, **kwargs):
        """
        Common private function to handle the same kind of
        inputs in all transform functions.

        @param      method      name of the :epkg:`NearestNeighbors` method to run
        @param      X           inputs, matrix, folder, filename,
                                image or list of images
        @param      args        additional positional arguments
        @param      fLOG        logging function
        @param      kwargs      additional named arguments
        @return                 depends on *method*
        """
        if isinstance(X, str):
            if not os.path.exists(X):
                raise FileNotFoundError("Folder '{0}' not found.".format(X))
            if os.path.isfile(X):
                # A single file is handled as a list of one filename.
                return self._private_kn(method, [X], *args, fLOG=fLOG, **kwargs)
            X = self._folder2matrix(X, fLOG=fLOG)[0]

        elif isinstance(X, list):
            if isinstance(X[0], Image):
                transform = self._get_transform()
                X = numpy.array([numpy.array(transform(img)).ravel()
                                 for img in X])
            elif isinstance(X[0], str):
                # image names
                X = self._imglist2matrix(X, fLOG=fLOG)[0]
            elif isinstance(X[0], tuple):
                # (image name, class)
                X = self._imglist2matrix([item[0] for item in X], fLOG=fLOG)[0]
            else:
                raise TypeError("X should be a list of Image")
        elif isinstance(X, Image):
            # A single image is handled as a list of one image.
            return self._private_kn(method, [X], *args, fLOG=fLOG, **kwargs)

        # Calls the parent implementation explicitly: looking the name up on
        # *self* would resolve to the overridden method and recurse.
        method = getattr(NearestNeighbors, method)
        return method(self, X, *args, **kwargs)

    def kneighbors(self, X=None, n_neighbors=None, return_distance=True, fLOG=None):  # pylint: disable=W0221
        """
        See :epkg:`NearestNeighbors`, method :epkg:`kneighbors`.
        Parameter *X* can be a file, the image is then loaded and converted
        with the same transform. *X* can also be an *Image* from :epkg:`PIL`.
        """
        return self._private_kn("kneighbors", X=X, n_neighbors=n_neighbors,
                                return_distance=return_distance, fLOG=fLOG)

    def kneighbors_graph(self, X=None, n_neighbors=None, mode='connectivity', fLOG=None):  # pylint: disable=W0221
        """
        See :epkg:`NearestNeighbors`, method :epkg:`kneighbors_graph`.
        Parameter *X* can be a file, the image is then loaded and converted
        with the same transform. *X* can also be an *Image* from :epkg:`PIL`.
        """
        return self._private_kn("kneighbors_graph", X=X, n_neighbors=n_neighbors,
                                mode=mode, fLOG=fLOG)

    def radius_neighbors(self, X=None, radius=None, return_distance=True, fLOG=None):  # pylint: disable=W0221,W0237
        """
        See :epkg:`NearestNeighbors`, method :epkg:`radius_neighbors`.
        Parameter *X* can be a file, the image is then loaded and converted
        with the same transform. *X* can also be an *Image* from :epkg:`PIL`.
        """
        return self._private_kn("radius_neighbors", X=X, radius=radius,
                                return_distance=return_distance, fLOG=fLOG)

    def get_image_names(self, indices):
        """
        Returns images names for the given list of indices.

        @param      indices     indices can be a single array or a matrix.
        @return                 same shape
        @raise      RuntimeError if no image name was stored during training
        """
        if not hasattr(self, "image_names_"):
            raise RuntimeError("No image names were stored during training.")
        new_indices = indices.ravel()
        res = numpy.array([self.image_names_[i] for i in new_indices])
        return res.reshape(indices.shape)

    def get_image_classes(self, indices):
        """
        Returns images classes for the given list of indices.

        @param      indices     indices can be a single array or a matrix.
        @return                 same shape
        @raise      RuntimeError if no image class was stored during training
        """
        if not hasattr(self, "image_classes_"):
            raise RuntimeError("No image classes were stored during training.")
        new_indices = indices.ravel()
        res = numpy.array([self.image_classes_[i] for i in new_indices])
        return res.reshape(indices.shape)

    def plot_neighbors(self, neighbors, distances=None, obs=None, return_figure=False,
                       format_distance="%1.2f", folder_or_images=None):
        """
        Calls :epkg:`plot_gallery_images`
        with information on the neighbors.

        :param neighbors: matrix of indices
        :param distances: distances to display
        :param obs: original image, if not None, will be placed
            on the first row
        :param return_figure: returns ``fig, ax`` instead of ``ax``
        :param format_distance: used to format distances
        :param folder_or_images: image paths may be relative
            to some folder, in that case, they should be relative
            to this folder, it can also be a list of images
        :return: *ax* or *fig, ax* if *return_figure* is True
        """
        from mlinsights.plotting import plot_gallery_images
        names = self.get_image_names(neighbors)
        if hasattr(self, "image_classes_"):
            subs = self.get_image_classes(neighbors)
        else:
            subs = numpy.array([["" for i in range(names.shape[1])]
                                for j in range(names.shape[0])])

        # One label per neighbor, built in row-major order.
        labels = []
        if distances is not None:
            for i in range(names.shape[0]):
                for j in range(names.shape[1]):
                    labels.append("{0} d={1}".format(
                        subs[i, j], format_distance % distances[i, j]))
        else:
            for i in range(names.shape[0]):
                for j in range(names.shape[1]):
                    # Formats the class instead of concatenating it: the class
                    # is None when the model was fitted on plain filenames.
                    labels.append("%s i=%d" % (subs[i, j], neighbors[i, j]))
        subs = numpy.array(labels).reshape(subs.shape)

        if obs is not None:
            if isinstance(obs, str):
                obs = read_image(obs)
            # First row: the observed image followed by placeholder objects.
            row = numpy.array([object() for i in range(names.shape[1])])
            row[0] = obs
            names = numpy.vstack([row, names])
            text = numpy.array(["" for i in range(names.shape[1])])
            text[0] = "-"
            subs = numpy.vstack([text, subs])

        # A list of images is not forwarded, only a folder is.
        fi = None if isinstance(folder_or_images, list) else folder_or_images
        return plot_gallery_images(names, subs, return_figure=return_figure,
                                   folder_image=fi)