Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

@file

@brief Defines a search engine for images inspired from

`Search images with deep learning <http://www.xavierdupre.fr/app/mlinsights/helpsphinx/notebooks/search_images.html#searchimagesrst>`_.

It relies on :epkg:`lightmlrestapi`.

7""" 

8import os 

9import logging 

10import shutil 

11import numpy 

12import falcon 

13from pyensae.datasource import download_data 

14 

15 

def search_images_dogcat(app=None, url_images=None, dest=None, module="torch"):
    """
    Builds a :epkg:`REST` application which answers with the
    neighbors of an image among a small set of pictures
    of dogs and cats. The work is delegated to an implementation
    based on :epkg:`torch` or :epkg:`keras`.

    @param      app         application, if None, creates one
    @param      url_images  url or path to the images
    @param      dest        destination of the images (where to unzip)
    @param      module      :epkg:`keras` or :epkg:`torch`
    @return                 app

    A *ValueError* is raised when *module* is neither ``'keras'``
    nor ``'torch'``. Messages are logged with the logger
    ``'search_images_dogcat'``.

    The application can be started by running:

    ::

        start_dogcatrestapi

    It can then be queried with:

    ::

        import requests
        import ujson
        from lightmlrestapi.args import image2base64
        img = "path_to_image"
        b64 = image2base64(img)[1]
        features = ujson.dumps({'X': b64})
        r = requests.post('http://127.0.0.1:8081', data=features)
        print(r)
        print(r.json())

    The answer should look like:

    ::

        {'Y': [[[41, 4.8754486973, {'name': 'wiki.png', 'description': 'something'}]]]}
    """
    log = logging.getLogger('search_images_dogcat')

    # Guard clause: rejects unknown backends before selecting a builder.
    if module not in ("keras", "torch"):
        raise ValueError("Unexpected module '{0}'.".format(module))

    if module == "keras":
        builder = _search_images_dogcat_keras
    else:
        builder = _search_images_dogcat_torch
    return builder(app=app, url_images=url_images, dest=dest, fLOG=log.info)

64 

65 

def _search_images_dogcat_keras(app=None, url_images=None, dest=None, fLOG=None):
    """
    Builds the :epkg:`REST` application on top of a :epkg:`keras`
    model (*MobileNet* loaded with the *imagenet* weights).
    The images are downloaded with *download_data* into *dest*,
    the search engine itself is only fitted when the application
    calls the loading function.

    @param      app         application, if None, creates one
    @param      url_images  url or path to the zipped images,
                            ``'dog-cat-pixabay.zip'`` if None or empty
    @param      dest        destination of the images (where to unzip),
                            if None or empty, the folder ``images`` of the
                            current directory is used and created if missing
    @param      fLOG        logging function, if None, nothing is logged
    @return                 app

    A *FileNotFoundError* is raised if *dest* does not exist.
    """
    if fLOG is None:
        # The default value used to fail on the first logged message.
        def _silent(*args, **kwargs):  # pylint: disable=W0613
            "Ignores every message."

        fLOG = _silent

    fLOG("[_search_images_dogcat_keras] Use keras")
    if not url_images:
        url_images = "dog-cat-pixabay.zip"
    if not dest:
        dest = os.path.abspath("images")
        # Only the default folder is created, an explicit destination
        # must already exist.
        if not os.path.exists(dest):
            fLOG("Create folder '{0}'".format(dest))  # pylint: disable=W1202
            os.mkdir(dest)

    if not os.path.exists(dest):
        raise FileNotFoundError("Unable to find folder '{0}'".format(dest))

    # Downloads and unzips images.
    fLOG("Downloads images '{0}'".format(url_images))  # pylint: disable=W1202
    fLOG("Destination '{0}'".format(dest))  # pylint: disable=W1202
    if '/' in url_images:
        # Everything before the last '/' is the website,
        # what follows is the name of the archive.
        website, zipname = url_images.rsplit('/', 1)
        fLOG("zipname '{0}'".format(zipname))  # pylint: disable=W1202
        download_data(zipname, whereTo=dest, website=website + "/")
    else:
        download_data(url_images, whereTo=dest)

    # Every subfolder of dest is considered as a class.
    classes = [_ for _ in os.listdir(dest)
               if os.path.isdir(os.path.join(dest, _))]
    if not classes:
        # We move all images in a folder.
        imgs = os.listdir(dest)
        cl = os.path.join(dest, "oneclass")
        os.mkdir(cl)
        for img in imgs:
            shutil.move(os.path.join(dest, img), cl)
        fLOG("Moving all images to '{0}'".format(cl))  # pylint: disable=W1202
        classes = ['oneclass']

    fLOG("Discovering images in '{0}'".format(dest))  # pylint: disable=W1202

    # Iterator on images
    from keras.preprocessing.image import ImageDataGenerator
    augmenting_datagen = ImageDataGenerator(rescale=1. / 255)
    try:
        iterimf = augmenting_datagen.flow_from_directory(
            dest, batch_size=1, target_size=(224, 224),
            classes=classes, shuffle=False)
    except Exception as e:
        fLOG("ERROR '{0}'".format(str(e)))  # pylint: disable=W1202
        # Bare raise keeps the original traceback.
        raise

    # Deep learning model.
    fLOG("Loading model '{0}'".format('MobileNet'))  # pylint: disable=W1202

    from keras.applications.mobilenet import MobileNet
    model = MobileNet(input_shape=None, alpha=1.0, depth_multiplier=1, dropout=1e-3,
                      include_top=True, weights='imagenet', input_tensor=None,
                      pooling=None, classes=1000)

    # Sets up the application.
    def predict_load():
        "Creates the search engine and fits it on the images found in *dest*."
        from mlinsights.search_rank import SearchEnginePredictionImages
        # The index of the layer before the last one is given to the engine.
        se = SearchEnginePredictionImages(model, fct_params=dict(layer=len(model.layers) - 2),
                                          n_neighbors=5)
        # fit
        fLOG("Creating the neighbors")
        se.fit(iterimf)
        return se

    # prediction function
    from lightmlrestapi.args import base642image, image2array

    def mypredict(se, X):
        """
        Returns the neighbors of *X* as a list of tuples
        *(score, index, metadata)*. *X* can be a string (decoded with
        *base642image*), a list (processed item by item) or an image.
        """
        if isinstance(X, str):
            return mypredict(se, base642image(X))
        if isinstance(X, list):
            return [mypredict(se, x) for x in X]
        # Same rescaling and same size as the images used to fit.
        gen = ImageDataGenerator(rescale=1. / 255)
        X = image2array(X.convert('RGB').resize((224, 224)))
        # A new first axis is added to get a batch of one image.
        iterim = gen.flow(X[numpy.newaxis, :, :, :], batch_size=1)
        score, ind, meta = se.kneighbors(iterim)
        return list(zip(map(float, score), map(int, ind),
                        meta.to_dict('records')))

    # Creates the application.
    fLOG("[_search_images_dogcat_keras] Setting the application")
    from lightmlrestapi.mlapp import MachineLearningPost
    if app is None:
        # falcon.API is the deprecated name of falcon.App (falcon >= 3.0),
        # the new name is used when it is available.
        app_class = getattr(falcon, 'App', None) or falcon.API
        app = app_class()
    app.add_route('/', MachineLearningPost(predict_load, mypredict))
    return app

159 

160 

def _search_images_dogcat_torch(app=None, url_images=None, dest=None, fLOG=None):
    """
    Builds the :epkg:`REST` application on top of a :epkg:`torch`
    model (*squeezenet1_1* from *torchvision*).
    The images are downloaded with *download_data* into *dest*,
    the search engine itself is only fitted when the application
    calls the loading function.

    @param      app         application, if None, creates one
    @param      url_images  url or path to the zipped images,
                            ``'dog-cat-pixabay.zip'`` if None or empty
    @param      dest        destination of the images (where to unzip),
                            if None or empty, the folder ``images`` of the
                            current directory is used and created if missing
    @param      fLOG        logging function, if None, nothing is logged
    @return                 app

    A *FileNotFoundError* is raised if *dest* does not exist.
    """
    if fLOG is None:
        # The default value used to fail on the first logged message.
        def _silent(*args, **kwargs):  # pylint: disable=W0613
            "Ignores every message."

        fLOG = _silent

    fLOG("[_search_images_dogcat_torch] Use torch")
    if not url_images:
        url_images = "dog-cat-pixabay.zip"
    if not dest:
        dest = os.path.abspath("images")
        # Only the default folder is created, an explicit destination
        # must already exist.
        if not os.path.exists(dest):
            fLOG("Create folder '{0}'".format(dest))  # pylint: disable=W1202
            os.mkdir(dest)

    if not os.path.exists(dest):
        raise FileNotFoundError("Unable to find folder '{0}'".format(dest))

    # Downloads and unzips images.
    fLOG("Downloads images '{0}'".format(url_images))  # pylint: disable=W1202
    fLOG("Destination '{0}'".format(dest))  # pylint: disable=W1202
    if '/' in url_images:
        # Everything before the last '/' is the website,
        # what follows is the name of the archive.
        website, zipname = url_images.rsplit('/', 1)
        download_data(zipname, whereTo=dest, website=website + "/")
    else:
        download_data(url_images, whereTo=dest)

    # ImageFolder expects one subfolder per class.
    classes = [_ for _ in os.listdir(dest)
               if os.path.isdir(os.path.join(dest, _))]
    if not classes:
        # We move all images in a folder.
        imgs = os.listdir(dest)
        cl = os.path.join(dest, "oneclass")
        os.mkdir(cl)
        for img in imgs:
            shutil.move(os.path.join(dest, img), cl)
        fLOG("Moving all images to '{0}'".format(cl))  # pylint: disable=W1202

    fLOG("Discovering images in '{0}'".format(dest))  # pylint: disable=W1202

    # fit a model
    from torchvision import datasets, transforms  # pylint: disable=E0401
    trans = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])
    iterim = datasets.ImageFolder(dest, trans)

    from torchvision.models import squeezenet1_1  # pylint: disable=E0401
    model = squeezenet1_1(True)

    # Sets up the application.
    def predict_load():
        "Creates the search engine and fits it on the images found in *dest*."
        from mlinsights.search_rank import SearchEnginePredictionImages
        se = SearchEnginePredictionImages(
            model, fct_params={}, n_neighbors=5)
        # fit
        fLOG("Creating the neighbors")
        se.fit(iterim)
        fLOG("Creating the neighbors - end")
        return se

    # prediction function
    from lightmlrestapi.args import base642image, image2array

    def mypredict(se, X):
        """
        Returns the neighbors of *X* as a list of tuples
        *(score, index, metadata)*. *X* can be a string (decoded with
        *base642image*), a list (processed item by item) or an image.
        """
        if isinstance(X, str):
            return mypredict(se, base642image(X))
        if isinstance(X, list):
            return [mypredict(se, x) for x in X]

        X = image2array(X.convert('RGB').resize((224, 224)))
        # NOTE(review): presumably image2array returns (height, width, channel),
        # the transposition puts the channels first. The test looks at the
        # second dimension, it cannot detect a 3D array already stored
        # with the channels first - confirm it should not be X.shape[0].
        if X.shape[1] != 3:
            X = numpy.transpose(X, (2, 0, 1))
        if X.dtype.kind == 'u':
            # Unsigned integers are converted into floats in [0, 1].
            X = X / numpy.float32(255.0)
        score, ind, meta = se.kneighbors(X)
        try:
            res = list(zip(map(float, score), map(int, ind),
                           meta.to_dict('records')))
        except Exception as e:
            fLOG("ERROR: {}".format(e))
            # Bare raise keeps the original traceback.
            raise
        return res

    # Creates the application.
    fLOG("[_search_images_dogcat_torch] Setting the application")
    from lightmlrestapi.mlapp import MachineLearningPost
    if app is None:
        # falcon.API is the deprecated name of falcon.App (falcon >= 3.0),
        # the new name is used when it is available.
        app_class = getattr(falcon, 'App', None) or falcon.API
        app = app_class()
    app.add_route('/', MachineLearningPost(predict_load, mypredict))
    return app