Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Defines a search engine for images inspired from
4`Search images with deep learning <http://www.xavierdupre.fr/app/mlinsights/helpsphinx/notebooks/search_images.html
5#searchimagesrst>`_.
6It relies on :epkg:`lightmlrestapi`.
7"""
8import os
9import logging
10import shutil
11import numpy
12import falcon
13from pyensae.datasource import download_data
def search_images_dogcat(app=None, url_images=None, dest=None, module="torch"):
    """
    Builds a :epkg:`REST` application which, for a posted image,
    returns its nearest neighbors among a small set of
    images representing dogs and cats. The features are computed
    with :epkg:`torch` or :epkg:`keras`.

    @param      app         application, if None, creates one
    @param      url_images  url or path to the images
    @param      dest        destination of the images (where to unzip)
    @param      module      :epkg:`keras` or :epkg:`torch`
    @return                 app

    Any other value for *module* raises a *ValueError*.
    The application can be started by running:

    ::

        start_dogcatrestapi

    And then queried with:

    ::

        import requests
        import ujson
        from lightmlrestapi.args import image2base64
        img = "path_to_image"
        b64 = image2base64(img)[1]
        features = ujson.dumps({'X': b64})
        r = requests.post('http://127.0.0.1:8081', data=features)
        print(r)
        print(r.json())

    It should return something like:

    ::

        {'Y': [[[41, 4.8754486973, {'name': 'wiki.png', 'description': 'something'}]]]}
    """
    log_info = logging.getLogger('search_images_dogcat').info

    # Reject unknown backends first, then pick the matching builder.
    if module not in ("keras", "torch"):
        raise ValueError("Unexpected module '{0}'.".format(module))

    if module == "keras":
        builder = _search_images_dogcat_keras
    else:
        builder = _search_images_dogcat_torch
    return builder(app=app, url_images=url_images, dest=dest, fLOG=log_info)
def _search_images_dogcat_keras(app=None, url_images=None, dest=None, fLOG=None):
    """
    Builds the :epkg:`REST` application relying on :epkg:`keras`.

    The function downloads and unzips the images into *dest*, makes sure
    they are stored in at least one subfolder (one subfolder per class),
    loads a *MobileNet* model with *imagenet* weights and registers
    on route ``/`` a resource which returns the 5 nearest neighbors
    of a posted image.

    @param      app         application, if None, creates one
    @param      url_images  url or path to the zipped images,
                            ``dog-cat-pixabay.zip`` if None or empty
    @param      dest        destination of the images (where to unzip),
                            folder ``images`` in the current directory
                            if None or empty
    @param      fLOG        logging function taking one string,
                            None to disable logging
    @return                 app
    """
    if fLOG is None:
        # The default value used to be called as is, which raised a TypeError.
        def _silent(*args, **kwargs):  # pylint: disable=W0613
            pass
        fLOG = _silent

    fLOG("[_search_images_dogcat_keras] Use keras")
    if not url_images:
        url_images = "dog-cat-pixabay.zip"
    if not dest:
        dest = os.path.abspath("images")

    if not os.path.exists(dest):
        fLOG("Create folder '{0}'".format(dest))
        # makedirs also creates the missing intermediate folders.
        os.makedirs(dest)

    if not os.path.exists(dest):
        raise FileNotFoundError("Unable to find folder '{0}'".format(dest))

    # Downloads and unzips images.
    fLOG("Downloads images '{0}'".format(url_images))
    fLOG("Destination '{0}'".format(dest))
    if '/' in url_images:
        # Everything before the last '/' is the website, the rest the file name.
        website, zipname = url_images.rsplit('/', 1)
        fLOG("zipname '{0}'".format(zipname))
        download_data(zipname, whereTo=dest, website=website + "/")
    else:
        download_data(url_images, whereTo=dest)

    classes = [name for name in os.listdir(dest)
               if os.path.isdir(os.path.join(dest, name))]
    if not classes:
        # No subfolder: every file is moved into a single class folder
        # because flow_from_directory is given a list of class subfolders.
        entries = os.listdir(dest)
        cl = os.path.join(dest, "oneclass")
        os.mkdir(cl)
        for entry in entries:
            shutil.move(os.path.join(dest, entry), cl)
        fLOG("Moving all images to '{0}'".format(cl))
        classes = ['oneclass']

    fLOG("Discovering images in '{0}'".format(dest))

    # Iterator on images
    from keras.preprocessing.image import ImageDataGenerator
    augmenting_datagen = ImageDataGenerator(rescale=1. / 255)
    try:
        iterimf = augmenting_datagen.flow_from_directory(
            dest, batch_size=1, target_size=(224, 224),
            classes=classes, shuffle=False)
    except Exception as e:
        # Logs the error, then re-raises it with its original traceback.
        fLOG("ERROR '{0}'".format(str(e)))
        raise

    # Deep learning model.
    fLOG("Loading model '{0}'".format('MobileNet'))

    from keras.applications.mobilenet import MobileNet
    model = MobileNet(input_shape=None, alpha=1.0, depth_multiplier=1, dropout=1e-3,
                      include_top=True, weights='imagenet', input_tensor=None,
                      pooling=None, classes=1000)

    # Sets up the application.
    def predict_load():
        "Creates the search engine and fits it on the downloaded images."
        from mlinsights.search_rank import SearchEnginePredictionImages
        # layer: index of the second to last layer of the model,
        # presumably the one whose output is used as features (to confirm
        # against SearchEnginePredictionImages).
        se = SearchEnginePredictionImages(
            model, fct_params=dict(layer=len(model.layers) - 2),
            n_neighbors=5)
        # fit
        fLOG("Creating the neighbors")
        se.fit(iterimf)
        return se

    # prediction function
    from lightmlrestapi.args import base642image, image2array

    def mypredict(se, X):
        """
        Returns the neighbors of *X*. A string is converted into an image
        with *base642image*, a list is processed element by element,
        anything else is handled as an image exposing *convert* and *resize*.
        The result is a list of tuples *(score, index, metadata)*.
        """
        if isinstance(X, str):
            return mypredict(se, base642image(X))
        if isinstance(X, list):
            return [mypredict(se, x) for x in X]
        gen = ImageDataGenerator(rescale=1. / 255)
        X = image2array(X.convert('RGB').resize((224, 224)))
        # Adds a leading axis: the generator receives a batch of one image.
        iterim = gen.flow(X[numpy.newaxis, :, :, :], batch_size=1)
        score, ind, meta = se.kneighbors(iterim)
        return list(zip(map(float, score), map(int, ind),
                        meta.to_dict('records')))

    # Creates the application.
    fLOG("[_search_images_dogcat_keras] Setting the application")
    from lightmlrestapi.mlapp import MachineLearningPost
    if app is None:
        # falcon.App replaces falcon.API in recent versions of falcon,
        # falcon.API is kept as a fallback for older ones.
        app_class = getattr(falcon, "App", None)
        if app_class is None:
            app_class = falcon.API
        app = app_class()
    app.add_route('/',
                  MachineLearningPost(predict_load, mypredict))
    return app
def _search_images_dogcat_torch(app=None, url_images=None, dest=None, fLOG=None):
    """
    Builds the :epkg:`REST` application relying on :epkg:`torch`.

    The function downloads and unzips the images into *dest*, makes sure
    they are stored in at least one subfolder (one subfolder per class),
    loads a pretrained *squeezenet1_1* model and registers on route ``/``
    a resource which returns the 5 nearest neighbors of a posted image.

    @param      app         application, if None, creates one
    @param      url_images  url or path to the zipped images,
                            ``dog-cat-pixabay.zip`` if None or empty
    @param      dest        destination of the images (where to unzip),
                            folder ``images`` in the current directory
                            if None or empty
    @param      fLOG        logging function taking one string,
                            None to disable logging
    @return                 app
    """
    if fLOG is None:
        # The default value used to be called as is, which raised a TypeError.
        def _silent(*args, **kwargs):  # pylint: disable=W0613
            pass
        fLOG = _silent

    fLOG("[_search_images_dogcat_torch] Use torch")
    if not url_images:
        url_images = "dog-cat-pixabay.zip"
    if not dest:
        dest = os.path.abspath("images")

    if not os.path.exists(dest):
        fLOG("Create folder '{0}'".format(dest))
        # makedirs also creates the missing intermediate folders.
        os.makedirs(dest)

    if not os.path.exists(dest):
        raise FileNotFoundError("Unable to find folder '{0}'".format(dest))

    # Downloads and unzips images.
    fLOG("Downloads images '{0}'".format(url_images))
    fLOG("Destination '{0}'".format(dest))
    if '/' in url_images:
        # Everything before the last '/' is the website, the rest the file name.
        website, zipname = url_images.rsplit('/', 1)
        fLOG("zipname '{0}'".format(zipname))
        download_data(zipname, whereTo=dest, website=website + "/")
    else:
        download_data(url_images, whereTo=dest)

    classes = [name for name in os.listdir(dest)
               if os.path.isdir(os.path.join(dest, name))]
    if not classes:
        # No subfolder: every file is moved into a single class folder
        # so that the folder passed to ImageFolder contains a subfolder.
        entries = os.listdir(dest)
        cl = os.path.join(dest, "oneclass")
        os.mkdir(cl)
        for entry in entries:
            shutil.move(os.path.join(dest, entry), cl)
        fLOG("Moving all images to '{0}'".format(cl))
        classes = ['oneclass']

    fLOG("Discovering images in '{0}'".format(dest))

    # fit a model
    from torchvision import datasets, transforms  # pylint: disable=E0401
    trans = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.CenterCrop(224),
                                transforms.ToTensor()])
    iterim = datasets.ImageFolder(dest, trans)

    from torchvision.models import squeezenet1_1  # pylint: disable=E0401
    model = squeezenet1_1(pretrained=True)

    # Sets up the application.
    def predict_load():
        "Creates the search engine and fits it on the downloaded images."
        from mlinsights.search_rank import SearchEnginePredictionImages
        se = SearchEnginePredictionImages(
            model, fct_params={}, n_neighbors=5)
        # fit
        fLOG("Creating the neighbors")
        se.fit(iterim)
        fLOG("Creating the neighbors - end")
        return se

    # prediction function
    from lightmlrestapi.args import base642image, image2array

    def mypredict(se, X):
        """
        Returns the neighbors of *X*. A string is converted into an image
        with *base642image*, a list is processed element by element,
        anything else is handled as an image exposing *convert* and *resize*.
        The result is a list of tuples *(score, index, metadata)*.
        """
        if isinstance(X, str):
            return mypredict(se, base642image(X))
        if isinstance(X, list):
            return [mypredict(se, x) for x in X]

        X = image2array(X.convert('RGB').resize((224, 224)))
        # Moves the channels to the first axis unless they already are there.
        # Assumes image2array returns (height, width, channel) - to confirm.
        # The previous test looked at axis 1, which is never 3 for a 224x224
        # image, and would have scrambled an array already channel first.
        if X.shape[0] != 3:
            X = numpy.transpose(X, (2, 0, 1))
        if X.dtype.kind == 'u':
            # Unsigned integer pixels are rescaled into [0, 1].
            X = X / numpy.float32(255.0)
        score, ind, meta = se.kneighbors(X)
        try:
            res = list(zip(map(float, score), map(int, ind),
                           meta.to_dict('records')))
        except Exception as e:
            # Logs the error, then re-raises it with its original traceback.
            fLOG("ERROR: {}".format(e))
            raise
        return res

    # Creates the application.
    fLOG("[_search_images_dogcat_torch] Setting the application")
    from lightmlrestapi.mlapp import MachineLearningPost
    if app is None:
        # falcon.App replaces falcon.API in recent versions of falcon,
        # falcon.API is kept as a fallback for older ones.
        app_class = getattr(falcon, "App", None)
        if app_class is None:
            app_class = falcon.API
        app = app_class()
    app.add_route('/',
                  MachineLearningPost(predict_load, mypredict))
    return app