Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- coding: utf-8 -*- 

2""" 

3@file 

4@brief Builds a knn classifier for image in order to find close images. 

5""" 

6import os 

7import numpy 

8from PIL.Image import Image 

9from sklearn.neighbors import NearestNeighbors 

10from .image_helper import img2gray, enumerate_image_class, read_image, image_zoom 

11 

12 

class ImageNearestNeighbors(NearestNeighbors):
    """
    Builds a model on the top of :epkg:`NearestNeighbors`
    in order to find close images.

    Every image is optionally preprocessed (see parameter *transform*),
    zoomed to *size* with ``image_zoom`` and flattened into one row of
    features before being given to :epkg:`NearestNeighbors`.
    """

    def __init__(self, transform='gray', size=(10, 10), **kwargs):
        """
        @param      transform   preprocessing applied to every image,
                                ``'gray'`` (calls ``img2gray``) or *None*
                                (no preprocessing)
        @param      size        every image is zoomed to keep the same dimension
        @param      kwargs      see :epkg:`NearestNeighbors`
        """
        NearestNeighbors.__init__(self, **kwargs)
        self.image_size = size
        self.transform = transform
        # The returned function is discarded: the call only checks that
        # *transform* is a supported value (it raises ValueError otherwise).
        self._get_transform()

    def _get_transform(self):
        """
        Returns the transform function associated with ``self.transform``.

        @return     function taking an image, applying the optional
                    preprocessing and zooming the result to ``self.image_size``
        @raise      ValueError if ``self.transform`` is neither ``'gray'``
                    nor *None*
        """
        if self.transform == "gray":
            pre = img2gray
        elif self.transform is None:
            pre = None
        else:
            raise ValueError(
                "No transform assicated with value '{0}'.".format(self.transform))
        # ``self.image_size`` is read when the returned function is called,
        # not when it is created.
        if pre is None:
            return lambda img: image_zoom(img, new_size=self.image_size)
        return lambda img: image_zoom(pre(img), new_size=self.image_size)

    def _images2matrix(self, images):
        """
        Converts a list of already loaded images into a matrix of features,
        one flattened transformed image per row.

        @param      images      list of images
        @return                 matrix of features
        """
        transform = self._get_transform()
        return numpy.array([numpy.array(transform(img)).ravel()
                            for img in images])

    def _folder2matrix(self, folder, fLOG):
        """
        Converts images stored in a folder into a matrix of features.

        @param      folder      folder walked through with
                                @see fct enumerate_image_class
        @param      fLOG        logging function or None,
                                called once every 1000 images
        @return                 matrix of features, image names
                                (``/`` as separator), image classes
        """
        transform = self._get_transform()
        imgs = []
        subs = []
        stack = []
        for i, (name, sub) in enumerate(enumerate_image_class(folder, abspath=False)):
            if fLOG is not None and i % 1000 == 0:
                fLOG("[ImageNearestNeighbors] processing image {0}: "
                     "'{1}' - class '{2}'".format(i, name, sub))
            imgs.append(name.replace("\\", "/"))
            subs.append(sub)
            # names are enumerated with abspath=False, the folder is joined
            # back to load the image
            img = read_image(os.path.join(folder, name))
            stack.append(numpy.array(transform(img)).ravel())
        X = numpy.vstack(stack)
        return X, imgs, subs

    def _imglist2matrix(self, list_of_images, fLOG):
        """
        Converts a list of images into a matrix of features.

        @param      list_of_images  list of file names, of *Image*,
                                    or of tuples ``(name or Image, class)``
        @param      fLOG            logging function or None,
                                    called once every 1000 images
        @return                     matrix of features, image names
                                    (None for an *Image*), image classes
                                    (None when not given as a tuple)
        """
        transform = self._get_transform()
        imgs = []
        subs = []
        stack = []
        for i, name in enumerate(list_of_images):
            if isinstance(name, tuple):
                name, sub = name
            else:
                sub = None
            if fLOG is not None and i % 1000 == 0:
                # Logs *name*: the image itself is not loaded yet at this
                # point, so it cannot be referenced here.
                fLOG("[ImageNearestNeighbors] processing image {0}: "
                     "'{1}' - class '{2}'".format(i, name, sub))
            if isinstance(name, Image):
                imgs.append(None)
                img = name
            else:
                imgs.append(name.replace("\\", "/"))
                img = read_image(name)
            subs.append(sub)
            stack.append(numpy.array(transform(img)).ravel())
        X = numpy.vstack(stack)
        return X, imgs, subs

    def fit(self, X, y=None, fLOG=None):  # pylint: disable=W0221
        """
        Fits the model. *X* can be a folder.

        @param      X       matrix, str for a subfolder of images,
                            list of *Image*, list of file names
                            or list of tuples ``(file name, class)``
        @param      y       unused
        @param      fLOG    logging function
        @return             self

        If *X* is a folder, the method relies on function
        @see fct enumerate_image_class. In that case, the method
        also creates attributes:

        * ``image_names_``: all image names
        * ``image_classes_``: subfolder the image belongs too

        Both attributes are also created when *X* is a list of file names
        or a list of tuples. They are left untouched when *X* is a matrix
        or a list of *Image*.
        """
        if isinstance(X, str):
            if not os.path.exists(X):
                raise FileNotFoundError("Folder '{0}' not found.".format(X))
            X, imgs, subs = self._folder2matrix(X, fLOG)
            self.image_names_ = imgs  # pylint: disable=W0201
            self.image_classes_ = subs  # pylint: disable=W0201

        elif isinstance(X, list):
            # the type of the first element decides how the list is read
            if isinstance(X[0], Image):
                X = self._images2matrix(X)
            elif isinstance(X[0], str):
                # image names
                X, imgs, subs = self._imglist2matrix(X, fLOG)
                self.image_names_ = imgs  # pylint: disable=W0201
                self.image_classes_ = subs  # pylint: disable=W0201
            elif isinstance(X[0], tuple):
                # (image name, class)
                self.image_classes_ = [  # pylint: disable=W0201
                    t[1] for t in X]
                X, imgs, _ = self._imglist2matrix([t[0] for t in X], fLOG)
                self.image_names_ = imgs  # pylint: disable=W0201
            else:
                raise TypeError(
                    "X should be a list of PIL.Image not {0}".format(type(X[0])))

        super(ImageNearestNeighbors, self).fit(X, y)
        return self

    def _private_kn(self, method, X, *args, fLOG=None, **kwargs):
        """
        Common private function to handle the same kind of
        inputs in all neighbors functions.

        @param      method      name of the :epkg:`NearestNeighbors`
                                method to run
        @param      X           inputs, matrix, folder, file name,
                                *Image* or list of images
        @param      args        additional positional arguments
        @param      fLOG        logging function
        @param      kwargs      additional named arguments
        @return                 depends on *method*
        """
        if isinstance(X, str):
            if not os.path.exists(X):
                raise FileNotFoundError("Folder '{0}' not found.".format(X))
            if os.path.isfile(X):
                # a single file is processed as a list of one file name
                return self._private_kn(method, [X], *args, fLOG=fLOG, **kwargs)
            X = self._folder2matrix(X, fLOG=fLOG)[0]

        elif isinstance(X, list):
            # the type of the first element decides how the list is read
            if isinstance(X[0], Image):
                X = self._images2matrix(X)
            elif isinstance(X[0], str):
                # image names
                X = self._imglist2matrix(X, fLOG=fLOG)[0]
            elif isinstance(X[0], tuple):
                # (image name, class), the class is ignored
                X = self._imglist2matrix([t[0] for t in X], fLOG=fLOG)[0]
            else:
                raise TypeError("X should be a list of Image")
        elif isinstance(X, Image):
            # a single image is processed as a list of one image
            return self._private_kn(method, [X], *args, fLOG=fLOG, **kwargs)

        # The method is taken from the parent class: the methods overloaded
        # in this class call this function.
        method = getattr(NearestNeighbors, method)
        return method(self, X, *args, **kwargs)

    def kneighbors(self, X=None, n_neighbors=None, return_distance=True, fLOG=None):  # pylint: disable=W0221
        """
        See :epkg:`NearestNeighbors`, method :epkg:`kneighbors`.
        Parameter *X* can be a file, the image is then loaded and converted
        with the same transform. *X* can also be an *Image* from :epkg:`PIL`.
        """
        return self._private_kn("kneighbors", X=X, n_neighbors=n_neighbors,
                                return_distance=return_distance, fLOG=fLOG)

    def kneighbors_graph(self, X=None, n_neighbors=None, mode='connectivity', fLOG=None):  # pylint: disable=W0221
        """
        See :epkg:`NearestNeighbors`, method :epkg:`kneighbors_graph`.
        Parameter *X* can be a file, the image is then loaded and converted
        with the same transform. *X* can also be an *Image* from :epkg:`PIL`.
        """
        return self._private_kn("kneighbors_graph", X=X, n_neighbors=n_neighbors,
                                mode=mode, fLOG=fLOG)

    def radius_neighbors(self, X=None, radius=None, return_distance=True, fLOG=None):  # pylint: disable=W0221,W0237
        """
        See :epkg:`NearestNeighbors`, method :epkg:`radius_neighbors`.
        Parameter *X* can be a file, the image is then loaded and converted
        with the same transform. *X* can also be an *Image* from :epkg:`PIL`.
        """
        return self._private_kn("radius_neighbors", X=X, radius=radius,
                                return_distance=return_distance, fLOG=fLOG)

    def get_image_names(self, indices):
        """
        Returns images names for the given list of indices.

        @param      indices     indices can be a single array or a matrix.
        @return                 same shape
        @raise      RuntimeError if no image names were stored during training
        """
        if not hasattr(self, "image_names_"):
            raise RuntimeError("No image names were stored during training.")
        # accepts any array-like, not only numpy arrays
        indices = numpy.asarray(indices)
        res = numpy.array([self.image_names_[i] for i in indices.ravel()])
        return res.reshape(indices.shape)

    def get_image_classes(self, indices):
        """
        Returns images classes for the given list of indices.

        @param      indices     indices can be a single array or a matrix.
        @return                 same shape
        @raise      RuntimeError if no image classes were stored during training
        """
        if not hasattr(self, "image_classes_"):
            raise RuntimeError("No image classes were stored during training.")
        # accepts any array-like, not only numpy arrays
        indices = numpy.asarray(indices)
        res = numpy.array([self.image_classes_[i] for i in indices.ravel()])
        return res.reshape(indices.shape)

    def plot_neighbors(self, neighbors, distances=None, obs=None, return_figure=False,
                       format_distance="%1.2f", folder_or_images=None):
        """
        Calls :epkg:`plot_gallery_images`
        with information on the neighbors.

        :param neighbors: matrix of indices
        :param distances: distances to display
        :param obs: original image, if not None, will be placed
            on the first row
        :param return_figure: returns ``fig, ax`` instead of ``ax``
        :param format_distance: used to format distances
        :param folder_or_images: image paths may be relative
            to some folder, in that case, they should be relative
            to this folder, it can also be a list of images
        :return: *ax* or *fix, ax* if *return_figure* is True
        """
        from mlinsights.plotting import plot_gallery_images
        names = self.get_image_names(neighbors)
        n_rows, n_cols = names.shape[0], names.shape[1]
        if hasattr(self, "image_classes_"):
            subs = self.get_image_classes(neighbors)
        else:
            subs = numpy.array([["" for i in range(n_cols)]
                                for j in range(n_rows)])

        # One label per neighbor: the class followed by the distance if
        # available, by the index of the neighbor otherwise.
        if distances is not None:
            labels = ["{0} d={1}".format(subs[i, j],
                                         format_distance % distances[i, j])
                      for i in range(n_rows) for j in range(n_cols)]
        else:
            # %s and not a concatenation: the class may be None
            labels = ["%s i=%d" % (subs[i, j], neighbors[i, j])
                      for i in range(n_rows) for j in range(n_cols)]
        subs = numpy.array(labels).reshape(subs.shape)

        if obs is not None:
            if isinstance(obs, str):
                obs = read_image(obs)
            # Extra first row: the observation in the first column,
            # placeholder objects in the others.
            row = numpy.array([object() for i in range(names.shape[1])])
            row[0] = obs
            names = numpy.vstack([row, names])
            text = numpy.array(["" for i in range(names.shape[1])])
            text[0] = "-"
            subs = numpy.vstack([text, subs])

        # only a folder is forwarded, a list of images is not
        fi = None if isinstance(folder_or_images, list) else folder_or_images
        return plot_gallery_images(names, subs, return_figure=return_figure,
                                   folder_image=fi)