Coverage for src/lightmlrestapi/mlapp/mlstorage_rest.py: 73%
139 statements
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-06 07:16 +0200
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-06 07:16 +0200
1"""
2@file
3@brief Machine Learning Post request
4"""
5import traceback
6import pickle
7import base64
8import falcon
9import numpy
10from ..args.args_images import string2bytes
11from ..tools import json_loads, json_dumps
12from .base_logging import BaseLogging
13from .mlstorage import MLStorage
class MLStoragePost(BaseLogging):
    """
    Implements a simple :epkg:`REST API` to
    upload zip files and to compute predictions
    with the uploaded models. The application assumes
    machine learning models are actionable through
    the following template: :ref:`l-template-ml`.

    The request body is a :epkg:`json` dictionary whose field
    ``cmd`` selects the action: ``'upload'`` or ``'predict'``.
    """

    def __init__(self, secret=None, folder='.',
                 folder_storage='.', version=None):
        """
        @param      secret          see @see cl BaseLogging
        @param      folder          see @see cl BaseLogging
        @param      folder_storage  see @see cl MLStorage
        @param      version         API REST version, added to the
                                    logged data when not None
        """
        BaseLogging.__init__(self, secret=secret, folder=folder)
        self._storage = MLStorage(folder_storage)
        self._version = version

    @staticmethod
    def data2json(data):
        """
        :epkg:`numpy:array` cannot be converted into
        :epkg:`json`. We change the type into a dictionary
        holding the shape and the values as a list.

        @param      data    any object
        @return             ``dict(shape=..., data=...)`` for an array,
                            *data* itself otherwise
        """
        if isinstance(data, numpy.ndarray):
            return dict(shape=data.shape, data=data.tolist())
        return data

    @staticmethod
    def _shorten(text, limit=200):
        """
        Truncates a message to *limit* characters and appends ``'...'``
        when something was removed.
        """
        if len(text) > limit:
            return text[:limit] + '...'
        return text

    def _load_args(self, req, add_log_data):
        """
        Reads the whole request body and parses it as :epkg:`json`.

        @param      req             request
        @param      add_log_data    data added to every logged error
        @return                     parsed content (a dictionary)

        Raises ``falcon.HTTPBadRequest`` if the body cannot be read,
        cannot be parsed or is not a dictionary.
        """
        chunks = []
        try:
            while True:
                chunk = req.stream.read(2 ** 16)
                if len(chunk) == 0:
                    break
                chunks.append(chunk)
        except AssertionError as e:
            excs = traceback.format_exc()
            log_data = dict(error=str(e))
            log_data.update(add_log_data)
            self.error("ML.load", log_data)
            raise falcon.HTTPBadRequest(
                title='Unable to retrieve request content due to: {0}'.format(
                    self._shorten(str(e))),
                description=excs)

        # An empty body is passed as None, the parser then fails
        # and the failure is reported as a bad request below.
        raw = b"".join(chunks) if chunks else None
        try:
            args = json_loads(raw)
        except Exception as e:  # request boundary: any parsing failure is a client error
            excs = traceback.format_exc()
            log_data = dict(error=str(e))
            log_data.update(add_log_data)
            self.error("ML.load", log_data)
            raise falcon.HTTPBadRequest(
                title='Unable to retrieve request content due to: {0}'.format(
                    self._shorten(str(e))),
                description=excs)

        if not isinstance(args, dict):
            es = "JSON content must be a dictionary not {0}".format(type(args))
            log_data = dict(error=es)
            log_data.update(add_log_data)
            self.error("ML.load", log_data)
            raise falcon.HTTPBadRequest(
                title='Unable to retrieve request content due to: {0}'.format(es))
        return args

    def _handle_upload(self, args, add_log_data):
        """
        Runs command ``'upload'``, logs the outcome and returns the answer.
        """
        self.save_time()
        try:
            name = self._store(args)
        except Exception as e:
            excs = traceback.format_exc()
            duration = self.duration()
            log_data = dict(error=str(e), duration=duration, data=args)
            log_data.update(add_log_data)
            self.error("MLB.store", log_data)
            raise falcon.HTTPBadRequest(
                title="Unable to upload model due to: {}".format(
                    self._shorten(str(e))),
                description=excs)

        duration = self.duration()
        log_data = dict(duration=duration, name=name)
        self.info("MLB.store '%s'" % name, log_data)
        return {"name": name}

    def _handle_predict(self, args, add_log_data):
        """
        Runs command ``'predict'``, logs the outcome and returns the answer.
        """
        self.save_time()
        try:
            name, pred, version, loaded = self._predict(args)
        except Exception as e:
            excs = traceback.format_exc()
            es = self._shorten(str(e))
            name = args.get('name', None)
            # The requested format comes from the request, not from the
            # builtin function ``format``.
            form = args.get('format', None)
            duration = self.duration()
            log_data = dict(error=str(e), duration=duration, data=args)
            log_data.update(add_log_data)
            if name is not None:
                log_data['name'] = name
            if form is not None:
                log_data['format'] = form
            if 'input' in args:
                log_data['input'] = args['input']
            self.error("MLB.predict", log_data)
            raise falcon.HTTPBadRequest(
                title="Unable to predict with model '{0}' due to: {1}, format='{2}'".format(
                    name, es, form),
                description=excs)

        duration = self.duration()
        log_data = dict(duration=duration, version=version, name=name)
        if loaded:
            self.info("MLB.predict '%s' loaded again" % name, log_data)
        else:
            self.info("MLB.predict '%s'" % name, log_data)
        return {"output": pred, "version": version}

    def on_post(self, req, resp):
        """
        Processes a :epkg:`POST` request.

        @param      req     request, its body is a :epkg:`json`
                            dictionary with a field ``cmd``
                            (``'upload'`` or ``'predict'``)
        @param      resp    response, its status and body are
                            filled by the method

        Raises ``falcon.HTTPBadRequest`` when the content cannot be
        read, the command is unknown, the command fails or the answer
        cannot be serialized.
        """
        add_log_data = dict(user=req.get_header('uid'), ip=req.access_route)
        if self._version is not None:
            # The version is one more logged field, it must not
            # replace the dictionary itself.
            add_log_data['version'] = self._version

        args = self._load_args(req, add_log_data)

        command = args.pop('cmd', None)
        if command == 'upload':
            answer = self._handle_upload(args, add_log_data)
            resp.status = falcon.HTTP_201
        elif command == 'predict':
            answer = self._handle_predict(args, add_log_data)
            resp.status = falcon.HTTP_201
        else:
            es = "Unknown command '{0}'".format(command)
            log_data = dict(msg=es)
            log_data.update(add_log_data)
            self.error("MLStorage", log_data)
            raise falcon.HTTPBadRequest(
                title='Unable to retrieve request content due to: {0}'.format(es))

        try:
            js = json_dumps(answer)
        except OverflowError as e:
            raise falcon.HTTPBadRequest(
                title='Unable to retrieve request content due to: {0}'.format(e))
        # NOTE(review): recent falcon versions prefer ``resp.text`` over
        # ``resp.body`` -- confirm the targeted version before switching.
        resp.body = js

    def _store(self, args):
        """
        Stores the model in the storage.

        @param      args    request content, fields ``name`` and
                            ``zip`` are removed from it
        @return             model name

        Raises ``KeyError`` if the name or the zip content is missing.
        """
        name = args.pop('name', None)
        if name is None or name == '':
            keys = ", ".join(sorted(args.keys()))
            try:
                ms = str(args)
            except ValueError:
                ms = ""
            ms = self._shorten(ms, 300)
            raise KeyError(
                "Unable to find a model name in sent data, keys={0}, args={1}.".format(keys, ms))
        zipped = args.pop('zip', None)
        if zipped is None:
            raise KeyError(
                "The REST API expects to find a zip file data in field 'zip'.")
        unstring = string2bytes(zipped)
        self._storage.add(name, unstring)
        return name

    def _predict(self, args):
        """
        Computes the prediction of a stored model.

        @param      args    request content, it uses fields ``name``,
                            ``input`` and ``format`` (``'json'`` by default,
                            ``'img'`` or ``'bytes'``)
        @return             name, prediction, version, loaded

        The last two values are returned by ``MLStorage.call_predict``.
        The prediction is converted into a list if it is
        a :epkg:`numpy:array`.

        Raises ``KeyError`` if the name or the input is missing,
        ``ValueError`` if the format is not recognized.
        """
        name = args.get('name', None)
        if name is None:
            raise KeyError("Unable to find a model name in sent data.")
        form = args.get('format', 'json')
        data = args.get('input', None)
        if data is None:
            raise KeyError("The REST API expects to find field 'input' which contains stringified data the "
                           "machine learned model can process. Field 'format' indicates which "
                           "preprocessing to do before calling the model which is currently '{0}'".format(form))
        if form == 'json':
            data = json_loads(data)
        elif form == 'img':
            simg = base64.b64decode(data)
            # SECURITY: unpickling data sent by a client can execute
            # arbitrary code, this format must only be exposed
            # to trusted callers.
            data = pickle.loads(simg)
        elif form == 'bytes':
            data = string2bytes(data)
        else:
            raise ValueError("Unrecognized format '{0}'.".format(form))
        res, version, loaded = self._storage.call_predict(
            name, data, version=True, was_loaded=True)
        if isinstance(res, numpy.ndarray):
            res = res.tolist()
        return name, res, version, loaded