"""
Machine Learning Post request
:githublink:`%|py|5`
"""
import traceback
import pickle
import base64
import falcon
import numpy
from ..args.args_images import string2bytes
from ..tools import json_loads, json_dumps
from .base_logging import BaseLogging
from .mlstorage import MLStorage
class MLStoragePost(BaseLogging):
    """
    Implements a simple :epkg:`REST API` to
    upload zip files and to call the prediction function
    of the uploaded models. The application assumes
    machine learning models are actionable through
    the following template: :ref:`l-template-ml`.

    A request is a :epkg:`json` object whose field ``cmd`` is
    either ``'upload'`` (see :meth:`_store`) or
    ``'predict'`` (see :meth:`_predict`).
    """

    # Maximum number of characters of an exception message
    # sent back to the client in the title of an HTTP error.
    _MAX_ERROR_LENGTH = 200

    def __init__(self, secret=None, folder='.',
                 folder_storage='.', version=None):
        """
        :param secret: see :class:`BaseLogging <lightmlrestapi.mlapp.base_logging.BaseLogging>`
        :param folder: see :class:`BaseLogging <lightmlrestapi.mlapp.base_logging.BaseLogging>`
        :param folder_storage: see :class:`MLStorage <lightmlrestapi.mlapp.mlstorage.MLStorage>`
        :param version: API REST version, added to the logged
            data when it is not None
        """
        BaseLogging.__init__(self, secret=secret, folder=folder)
        self._storage = MLStorage(folder_storage)
        self._version = version

    @staticmethod
    def data2json(data):
        """
        :epkg:`numpy:array` cannot be converted into
        :epkg:`json`. We change the type into a dictionary
        holding the shape and the values as nested lists.

        :param data: any object
        :return: ``dict(shape=..., data=...)`` if *data* is a
            :epkg:`numpy:array`, *data* itself otherwise
        """
        if isinstance(data, numpy.ndarray):
            return dict(shape=data.shape, data=data.tolist())
        return data

    @staticmethod
    def _shorten(text, limit=None):
        """
        Truncates a message which is too long to be sent back to the client.

        :param text: message
        :param limit: maximum length, ``_MAX_ERROR_LENGTH`` if None
        :return: *text* or its first *limit* characters followed by ``'...'``
        """
        if limit is None:
            limit = MLStoragePost._MAX_ERROR_LENGTH
        if len(text) > limit:
            return text[:limit] + '...'
        return text

    @staticmethod
    def _read_body(stream, chunk_size=2 ** 16):
        """
        Reads a stream until it is exhausted.

        :param stream: object with a method ``read(size)``
        :param chunk_size: number of bytes requested at every call
        :return: the concatenated content or None if the stream is empty
        """
        chunks = []
        while True:
            chunk = stream.read(chunk_size)
            if not chunk:
                break
            chunks.append(chunk)
        if not chunks:
            return None
        # chunks[0][:0] is the empty value of the chunk type (bytes or str),
        # the chunks are joined once instead of being appended one by one.
        return chunks[0][:0].join(chunks)

    def _bad_request(self, event, log_data, add_log_data, title,
                     description=None):
        """
        Logs an error and builds the exception to raise.

        :param event: event name given to the logger
        :param log_data: data specific to the error (modified inplace)
        :param add_log_data: data about the request added to *log_data*
        :param title: title of the HTTP error
        :param description: description of the HTTP error
        :return: instance of ``falcon.HTTPBadRequest``, the caller raises it
        """
        log_data.update(add_log_data)
        self.error(event, log_data)
        # Keyword arguments are accepted by every version of falcon,
        # positional ones are not in the most recent ones.
        return falcon.HTTPBadRequest(title=title, description=description)

    def on_post(self, req, resp):
        """
        Processes a :epkg:`POST` request.

        The body of the request must be a :epkg:`json` object
        with a field ``cmd``:

        * ``'upload'``: stores a model, the answer is ``{"name": name}``,
        * ``'predict'``: calls a model, the answer is
          ``{"output": prediction, "version": version}``.

        Both commands answer with status 201 when they succeed.

        :param req: request
        :param resp: response, attributes *status* and *body* are set
        :raises falcon.HTTPBadRequest: if the content cannot be read or
            parsed, if the command is unknown or if it fails
        """
        add_log_data = dict(user=req.get_header('uid'), ip=req.access_route)
        if self._version is not None:
            # The version is added to the logged data, it must not
            # replace the dictionary which is merged into every log entry.
            add_log_data['version'] = self._version

        try:
            js = self._read_body(req.stream)
        except AssertionError as e:
            excs = traceback.format_exc()
            raise self._bad_request(
                "ML.load", dict(error=str(e)), add_log_data,
                'Unable to retrieve request content due to: {0}'.format(
                    self._shorten(str(e))), excs)

        if not js:
            # An empty body cannot be parsed, this is a client error.
            es = "empty request body"
            raise self._bad_request(
                "ML.load", dict(error=es), add_log_data,
                'Unable to retrieve request content due to: {0}'.format(es))

        try:
            args = json_loads(js)
        except (TypeError, ValueError) as e:
            # NOTE(review): json_loads is assumed to raise ValueError or
            # TypeError on malformed content as json does - to confirm.
            excs = traceback.format_exc()
            raise self._bad_request(
                "ML.load", dict(error=str(e)), add_log_data,
                'Unable to retrieve request content due to: {0}'.format(
                    self._shorten(str(e))), excs)

        if not isinstance(args, dict):
            # Fields are retrieved by name, anything else than
            # a json object cannot hold a command.
            es = "expecting a json object, not {0}".format(type(args).__name__)
            raise self._bad_request(
                "ML.load", dict(error=es), add_log_data,
                'Unable to retrieve request content due to: {0}'.format(es))

        command = args.pop('cmd', None)

        if command == 'upload':
            self.save_time()
            try:
                name = self._store(args)
            except Exception as e:  # any failure becomes an HTTP error
                excs = traceback.format_exc()
                log_data = dict(error=str(e), duration=self.duration(),
                                data=args)
                raise self._bad_request(
                    "MLB.store", log_data, add_log_data,
                    "Unable to upload model due to: {}".format(
                        self._shorten(str(e))), excs)
            log_data = dict(duration=self.duration(), name=name)
            self.info("MLB.store '%s'" % name, log_data)
            resp.status = falcon.HTTP_201
            answer = {"name": name}

        elif command == 'predict':
            self.save_time()
            try:
                name, pred, version, loaded = self._predict(args)
            except Exception as e:  # any failure becomes an HTTP error
                excs = traceback.format_exc()
                name = args.get('name', None)
                # The format is the one sent by the user,
                # not the builtin function ``format``.
                form = args.get('format', None)
                log_data = dict(error=str(e), duration=self.duration(),
                                data=args)
                if name is not None:
                    log_data['name'] = name
                if form is not None:
                    log_data['format'] = form
                if 'input' in args:
                    log_data['input'] = args['input']
                raise self._bad_request(
                    "MLB.predict", log_data, add_log_data,
                    "Unable to predict with model '{0}' due to: {1}, "
                    "format='{2}'".format(name, self._shorten(str(e)), form),
                    excs)
            log_data = dict(duration=self.duration(), version=version,
                            name=name)
            if loaded:
                self.info("MLB.predict '%s' loaded again" % name, log_data)
            else:
                self.info("MLB.predict '%s'" % name, log_data)
            resp.status = falcon.HTTP_201
            answer = {"output": pred, "version": version}

        else:
            es = "Unknown command '{0}'".format(command)
            raise self._bad_request(
                "MLStorage", dict(msg=es), add_log_data,
                'Unable to retrieve request content due to: {0}'.format(es))

        try:
            js = json_dumps(answer)
        except OverflowError as e:
            raise falcon.HTTPBadRequest(
                title='Unable to retrieve request content due to: {0}'.format(e))
        # NOTE(review): recent versions of falcon rename this attribute
        # into ``text`` - to confirm against the supported versions.
        resp.body = js

    def _store(self, args):
        """
        Stores the model in the storage.

        :param args: request content, fields ``'name'`` and ``'zip'``
            are removed from the dictionary
        :return: name of the stored model
        :raises KeyError: if the name or the zip file is missing
        """
        name = args.pop('name', None)
        if name is None or name == '':
            keys = ", ".join(sorted(args))
            try:
                ms = str(args)
            except ValueError:
                ms = ""
            ms = self._shorten(ms, 300)
            raise KeyError(
                "Unable to find a model name in sent data, keys={0}, args={1}.".format(keys, ms))
        zipped = args.pop('zip', None)
        if zipped is None:
            raise KeyError(
                "The REST API expects to find a zip file data in field 'zip'.")
        self._storage.add(name, string2bytes(zipped))
        return name

    def _predict(self, args):
        """
        Calls the prediction function of a stored model.

        Field ``'format'`` (default is ``'json'``) tells how
        field ``'input'`` is converted before calling the model:

        * ``'json'``: the input is parsed as :epkg:`json`,
        * ``'img'``: the input is decoded from base64 then unpickled,
        * ``'bytes'``: the input is converted into bytes.

        :param args: request content, fields ``'name'``, ``'input'``
            and optionally ``'format'``
        :return: tuple *(name, prediction, version, loaded)*,
            a :epkg:`numpy:array` prediction is converted into a list
        :raises KeyError: if the name or the input is missing
        :raises ValueError: if the format is not recognized
        """
        name = args.get('name', None)
        if name is None:
            raise KeyError("Unable to find a model name in sent data.")
        form = args.get('format', 'json')
        data = args.get('input', None)
        if data is None:
            raise KeyError("The REST API expects to find field 'input' which contains stringified data the "
                           "machine learned model can process. Field 'format' indicates which "
                           "preprocessing to do before calling the model which is currently '{0}'".format(form))
        if form == 'json':
            data = json_loads(data)
        elif form == 'img':
            simg = base64.b64decode(data)
            # SECURITY: unpickling data sent by a client can execute
            # arbitrary code, this format should only be enabled
            # for trusted clients.
            data = pickle.loads(simg)
        elif form == 'bytes':
            data = string2bytes(data)
        else:
            raise ValueError("Unrecognized format '{0}'.".format(form))
        res, version, loaded = self._storage.call_predict(
            name, data, version=True, was_loaded=True)
        if isinstance(res, numpy.ndarray):
            res = res.tolist()
        return name, res, version, loaded