Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Implements a way to get close examples based 

4on the output of a machine learned model. 

5""" 

6import numpy 

7from .search_engine_predictions import SearchEnginePredictions 

8 

9 

class SearchEnginePredictionImages(SearchEnginePredictions):
    """
    Extends class @see cl SearchEnginePredictions.
    Vectors are coming from images. The metadata must contain
    information about path names. We assume all images can hold
    in memory. An example can be found in notebook
    :ref:`searchimageskerasrst` or :ref:`searchimagestorchrst`.
    Another example can be found there:
    `search_images_dogcat.py
    <https://github.com/sdpython/ensae_projects/blob/master/src/
    ensae_projects/restapi/search_images_dogcat.py>`_.
    """

    @staticmethod
    def _check_keras_iterator(iter_images):  # pragma: no cover
        """
        Checks that *iter_images* is a keras iterator whose batches
        contain exactly one image.

        @param      iter_images     object to validate

        Raises *NotImplementedError* if the object is not one of the
        supported keras iterator classes and *ValueError* if its
        ``batch_size`` is different from 1.
        """
        # We delay the import as keras backend is not necessarily installed.
        from keras.preprocessing.image import Iterator  # pylint: disable=E0401,C0415,E0611
        from keras_preprocessing.image import DirectoryIterator, NumpyArrayIterator  # pylint: disable=E0401,C0415,E0611
        if not isinstance(iter_images, (Iterator, DirectoryIterator, NumpyArrayIterator)):
            raise NotImplementedError(
                "iter_images must be a keras Iterator. No option implemented for type {0}."
                "".format(type(iter_images)))
        if iter_images.batch_size != 1:
            raise ValueError(
                "batch_size must be 1 not {0}".format(
                    iter_images.batch_size))

    def _prepare_fit(self, data=None, features=None, metadata=None,
                     transform=None, n=None, fLOG=None):
        """
        Stores data in the class itself.

        @param      data        images to index, an object whose type name
                                contains ``torch`` (it must expose ``samples``
                                and is wrapped into a ``DataLoader`` with
                                ``batch_size=1``) or ``keras`` (an iterator
                                with ``batch_size == 1`` and an attribute
                                ``filenames``)
        @param      features    unused by this implementation, kept to match
                                the signature of the overridden method
        @param      metadata    unused by this implementation, kept to match
                                the signature of the overridden method
        @param      transform   transform each vector before using it
        @param      n           takes *n* images (or ``len(data)``)
        @param      fLOG        logging function

        The method sets ``self.module_`` (``'torch'`` or ``'keras'``) and
        ``self.iter_images_``, then forwards a generator of pairs
        *(image, metadata)* to the parent implementation where the
        metadata is ``dict(name=<file name>, i=<index>)``.
        It raises *TypeError* if *data* is neither a torch nor a keras object.
        """
        data_type = str(type(data))
        if "torch" in data_type:
            self.module_ = "torch"
            from torch.utils.data import DataLoader  # pylint: disable=E0401,C0415,E0611
            dataloader = DataLoader(
                data, batch_size=1, shuffle=False, num_workers=0)
            # Every batch is paired with the matching entry of *data.samples*
            # whose first element is used as the image name.
            self.iter_images_ = iter_images = iter(
                zip(dataloader, data.samples))
            if n is None:
                n = len(data)
        elif "keras" in data_type:  # pragma: no cover
            self.module_ = "keras"
            iter_images = data
            self._check_keras_iterator(iter_images)
            self.iter_images_ = iter_images
            if n is None:
                n = len(iter_images)
            if not hasattr(iter_images, "filenames"):
                raise NotImplementedError(  # pragma: no cover
                    "Iterator does not iterate on images but numpy arrays (not implemented).")
        else:
            raise TypeError(  # pragma: no cover
                "Unexpected data type {0}.".format(type(data)))

        def get_current_index(flow):
            "get current index"
            # The index is read after the batch was produced, hence the
            # step back by one position (modulo the number of images).
            return flow.index_array[(flow.batch_index + flow.n - 1) % flow.n]

        def iterator_feature_meta():
            "iterators on metadata"
            # Only keras iterators expose *filenames*, the torch iterator
            # built above is a plain ``zip`` object.
            from_keras = hasattr(iter_images, 'filenames')

            # *range(n)* comes first so that no extra image is consumed
            # once *n* images were produced.
            for i, it in zip(range(n), iter_images):
                if from_keras:
                    im = it
                    name = iter_images.filenames[get_current_index(iter_images)]
                else:
                    im, name = it[0], it[1][0]
                if not isinstance(name, str):
                    raise TypeError(  # pragma: no cover
                        "name should be a string, not {0}".format(type(name)))
                yield im[0], dict(name=name, i=i)
                if fLOG and i % 10000 == 0:
                    fLOG(
                        '[SearchEnginePredictionImages.fit] i={}/{} - {}'.format(i, n, name))

        super()._prepare_fit(data=iterator_feature_meta(), transform=transform)

    def fit(self, iter_images, n=None, fLOG=None):  # pylint: disable=W0237
        """
        Processes images through the model and fits a *k-nn*.

        @param      iter_images `Iterator <https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py#L719>`_
                                or a torch dataset, see @see me _prepare_fit
        @param      n           takes *n* images (or ``len(iter_images)``)
        @param      fLOG        logging function
        @return                 whatever ``self._fit_knn()`` returns
        """
        self._prepare_fit(data=iter_images, transform=self.fct, n=n, fLOG=fLOG)
        return self._fit_knn()

    def kneighbors(self, iter_images, n_neighbors=None):  # pylint: disable=W0237
        """
        Searches for neighbors close to the first image
        returned by *iter_images*. It returns the neighbors
        only for the first image.

        @param      iter_images `Iterator <https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py#L719>`_,
                                a numpy array (torch only, a new first axis
                                is added before the conversion into a tensor),
                                a torch object or a list of any of these
        @param      n_neighbors number of neighbors, forwarded to the
                                parent implementation
        @return                 score, ind, meta

        *score* is an array representing the lengths to points,
        *ind* contains the indices of the nearest points in the population matrix,
        *meta* is the metadata.
        With a list, the three results obtained for every element
        are stacked with ``numpy.vstack``.

        The method raises *RuntimeError* if the input does not come from
        the module used to train the k-nn, *ValueError* if a keras
        iterator has a batch size different from 1 or produces no image,
        *TypeError* if the type of *iter_images* is not supported.
        """
        if isinstance(iter_images, numpy.ndarray):
            if self.module_ == "keras":  # pragma: no cover
                raise NotImplementedError("Not yet implemented for Keras.")
            if self.module_ == "torch":
                from torch import from_numpy  # pylint: disable=E0611,E0401,C0415
                X = from_numpy(iter_images[numpy.newaxis, :, :, :])
                return super().kneighbors(X, n_neighbors=n_neighbors)
            raise RuntimeError(  # pragma: no cover
                "Unknown module '{0}'.".format(self.module_))

        # The dispatch relies on the type name and not on the string
        # representation of the value: the latter is costly for big
        # objects and wrongly matches containers holding keras objects.
        type_name = str(type(iter_images))

        if "keras" in type_name:  # pragma: no cover
            if self.module_ != "keras":
                raise RuntimeError(  # pragma: no cover
                    "Keras object but {0} was used to train the KNN.".format(self.module_))
            # keras, it expects an iterator.
            self._check_keras_iterator(iter_images)
            for img in iter_images:
                X = img[0]
                break
            else:
                raise ValueError(  # pragma: no cover
                    "iter_images does not produce any image.")
            return super().kneighbors(X, n_neighbors=n_neighbors)

        if "torch" in type_name:
            if self.module_ != "torch":
                raise RuntimeError(  # pragma: no cover
                    "Torch object but {0} was used to train the KNN.".format(self.module_))
            # torch: it expects a tensor
            X = iter_images
            return super().kneighbors(X, n_neighbors=n_neighbors)

        if isinstance(iter_images, list):
            res = [self.kneighbors(it, n_neighbors=n_neighbors)
                   for it in iter_images]
            return (numpy.vstack([_[0] for _ in res]),
                    numpy.vstack([_[1] for _ in res]),
                    numpy.vstack([_[2] for _ in res]))

        raise TypeError(  # pragma: no cover
            "Unexpected type {0} in SearchEnginePredictionImages.kneighbors".format(
                type(iter_images)))