Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Machine Learning Post request
4"""
5import traceback
6import json
7import falcon
8import numpy
9from ..tools import json_loads, json_dumps
10from .base_logging import BaseLogging
class MachineLearningPost(BaseLogging):
    """
    Implements a simple :epkg:`REST API` which handles
    a post request, no authentication
    is required. The model ingests a vector *X*
    and outputs another one or a number *Y*.
    A basic example of an application is given by
    @see fn dummy_application.
    """

    # Maps the public *ccall* names to the integer codes stored in ``_ccall``.
    _call_convention = {'single': 0, 'multi': 1, 'both': 2}

    def __init__(self, load_function, predict_function,
                 secret=None, folder='.',
                 log_features=True, log_prediction=True,
                 load_params=None, ccall='single', version=None):
        """
        @param      predict_function    predict function, called as
                                        ``predict_function(obj, features)`` where
                                        *obj* is what *load_function* returned
        @param      load_function       load function, called with *load_params*
                                        as keyword arguments
        @param      secret              see @see cl BaseLogging
        @param      folder              see @see cl BaseLogging
        @param      log_features        log the features
        @param      log_prediction      log the prediction
        @param      load_params         given to the loading function
                                        (dictionary or None)
        @param      ccall               see below
        @param      version             API REST version

        Some models can only compute predictions for a sequence
        of observations, not just one. Parameter *ccall* defines what
        the prediction function can ingest.

        * *single*: only one observation
        * *multi*: only multiple ones
        * *both*: the function determines what it must do

        Raises *ValueError* if *ccall* is not one of the values above,
        *TypeError* if *load_params* is not a dictionary.
        """
        BaseLogging.__init__(self, secret=secret, folder=folder)
        self._predict_fct = predict_function
        self._log_features = log_features
        self._log_prediction = log_prediction
        self._load_fct = load_function
        self._load_params = {} if load_params is None else load_params
        # The model is loaded lazily, on the first request (see on_post).
        self._loaded_results = None
        self._version = version
        if ccall not in MachineLearningPost._call_convention:
            raise ValueError("ccall '{0}' must be in {1}".format(
                ccall, MachineLearningPost._call_convention))
        self._ccall = MachineLearningPost._call_convention[ccall]
        if not isinstance(self._load_params, dict):
            raise TypeError("load_params must be a dictionary.")

    @staticmethod
    def data2json(data):
        """
        :epkg:`numpy:array` cannot be converted into
        :epkg:`json`. We change the type into a dictionary
        holding the shape and the values as nested lists.
        Any other type is returned unchanged.
        """
        if isinstance(data, numpy.ndarray):
            return dict(shape=data.shape, data=data.tolist())
        return data

    @staticmethod
    def _shorten(text, limit):
        """
        Truncates *text* to *limit* characters and appends ``'...'``
        when something was cut.
        """
        if len(text) > limit:
            return text[:limit] + '...'
        return text

    def _load(self):
        """
        Calls the loading function with the stored parameters
        and returns its result.
        """
        return self._load_fct(**(self._load_params))

    def _predict_single(self, obj, features):
        """
        Calls the prediction function on one observation.
        If the prediction function only accepts sequences
        (*ccall* is *multi*), the observation is wrapped into a list.
        """
        if self._ccall == MachineLearningPost._call_convention['multi']:
            return self._predict_fct(obj, [features])
        return self._predict_fct(obj, features)

    def check_single(self, features):
        """
        Checks the sequence load + predict returns
        something with the given observations.
        The model is loaded again on every call, the result
        is not cached.
        """
        obj = self._load()
        return self._predict_single(obj, features)

    def on_post(self, req, resp):
        """
        Handles a POST request: reads the body, loads the model
        if it was not loaded yet, computes the prediction for ``X``
        and writes ``{"Y": ...}`` (plus ``".version"`` if a version
        was given) into the response with status 201.

        @param      req     request
        @param      resp    response

        Raises *falcon.HTTPBadRequest* if the body cannot be read,
        is empty, cannot be parsed, does not contain ``X``,
        or if loading or predicting fails.
        """
        add_log_data = dict(user=req.get_header('uid'), ip=req.access_route)
        if self._version is not None:
            # The version is added to the logged context. Replacing the
            # whole dictionary by the version would break every
            # ``log_data.update(add_log_data)`` below.
            add_log_data['version'] = self._version

        # To get the parameters
        # req.get_params
        js = None
        try:
            while True:
                chunk = req.stream.read(2**16)
                if not chunk:
                    break
                js = chunk if js is None else js + chunk
        except AssertionError as e:
            excs = traceback.format_exc()
            log_data = dict(error=str(e))
            log_data.update(add_log_data)
            if self._load_params:
                log_data['load_params'] = self._load_params
            # NOTE(review): this failure is logged under "ML.load" although
            # it happens while reading the request; kept as is because
            # log consumers may rely on the event name.
            self.error("ML.load", log_data)
            raise falcon.HTTPBadRequest(
                title='Unable to retrieve request content due to: {0}'.format(
                    MachineLearningPost._shorten(str(e), 400)),
                description=excs) from e

        if js is None:
            # Nothing was sent, there is nothing to parse.
            raise falcon.HTTPBadRequest(
                title='Unable to retrieve request content due to: empty body',
                description="The request body must be a JSON document with a key 'X'.")

        try:
            args = json_loads(js)
        except ValueError as e:
            # Assumes json_loads raises ValueError on invalid JSON
            # (as json and ujson do); other exceptions propagate unchanged.
            raise falcon.HTTPBadRequest(
                title='Unable to parse request content due to: {0}'.format(
                    MachineLearningPost._shorten(str(e), 200)),
                description=traceback.format_exc()) from e

        try:
            X = args["X"]
        except (KeyError, TypeError) as e:
            raise falcon.HTTPBadRequest(
                title="Unable to find key 'X' in the request content",
                description=traceback.format_exc()) from e

        # load the model
        if self._loaded_results is None:
            self.save_time()
            try:
                self._loaded_results = self._load()
            except Exception as e:
                excs = traceback.format_exc()
                duration = self.duration()
                log_data = dict(duration=duration, error=str(e))
                log_data.update(add_log_data)
                if self._load_params:
                    log_data['load_params'] = self._load_params
                self.error("ML.load", log_data)
                raise falcon.HTTPBadRequest(
                    title='Unable to load due to: {0}'.format(
                        MachineLearningPost._shorten(str(e), 200)),
                    description=excs) from e
            duration = self.duration()
            log_data = dict(duration=duration)
            log_data.update(add_log_data)
            if self._load_params:
                log_data['load_params'] = self._load_params
            self.info("ML.load", log_data)

        # predict
        self.save_time()
        try:
            res = self._predict_single(self._loaded_results, X)
        except Exception as e:
            excs = traceback.format_exc()
            duration = self.duration()
            log_data = dict(duration=duration)
            log_data.update(add_log_data)
            if self._log_features:
                log_data['X'] = MachineLearningPost.data2json(X)
            log_data["error"] = str(e)
            self.error("ML.predict", log_data)
            raise falcon.HTTPBadRequest(
                title='Unable to predict due to: {0}'.format(
                    MachineLearningPost._shorten(str(e), 200)),
                description=excs) from e
        duration = self.duration()

        # see http://falcon.readthedocs.io/en/stable/api/request_and_response.html
        log_data = dict(duration=duration)
        log_data.update(add_log_data)
        if self._log_features:
            log_data['X'] = MachineLearningPost.data2json(X)
        if self._log_prediction:
            log_data['Y'] = MachineLearningPost.data2json(res)
        self.info("ML.predict", log_data)

        resp.status = falcon.HTTP_201
        answer = {"Y": res}
        if self._version is not None:
            answer[".version"] = self._version
        try:
            js = json_dumps(answer)
        except OverflowError as e:
            # Tries the standard json module to tell apart what the
            # project serializer cannot handle from what json cannot
            # handle either, then raises a more explicit message.
            try:
                json.dumps(answer)
            except Exception as ee:
                raise OverflowError(
                    'res probably contains numpy arrays or numpy.types ({0}), they cannot be serialized.'.format(type(res))) from ee
            raise OverflowError(
                'res probably contains numpy arrays ({0}), they cannot be serialized with ujson but with json.'.format(type(res))) from e
        resp.body = js