Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Machine Learning Post request 

4""" 

5import traceback 

6import json 

7import falcon 

8import numpy 

9from ..tools import json_loads, json_dumps 

10from .base_logging import BaseLogging 

11 

12 

class MachineLearningPost(BaseLogging):
    """
    Implements a simple :epkg:`REST API` which handles
    a post request, no authentication
    is required. The model ingests a vector *X*
    and outputs another one or a number *Y*.
    A basic example of an application is given by
    @see fn dummy_application.
    """

    # Maps the *ccall* parameter to the internal code stored in ``self._ccall``.
    _call_convention = {'single': 0, 'multi': 1, 'both': 2}

    def __init__(self, load_function, predict_function,
                 secret=None, folder='.',
                 log_features=True, log_prediction=True,
                 load_params=None, ccall='single', version=None):
        """
        @param      predict_function    predict function, called as
                                        ``predict_function(model, X)``
        @param      load_function       load function, called as
                                        ``load_function(**load_params)``
        @param      secret              see @see cl BaseLogging
        @param      folder              see @see cl BaseLogging
        @param      log_features        log the features
        @param      log_prediction      log the prediction
        @param      load_params         given to the loading function
        @param      ccall               see below
        @param      version             API REST version, added to the logs
                                        and to every answer when not None

        Some models can only compute predictions for a sequence
        of observations, not just one. Parameter *ccall* defines what
        the prediction function can ingest.

        * *single*: only one observation
        * *multi*: only multiple ones
        * *both*: the function determines what it must do

        Raises *ValueError* if *ccall* is unknown and *TypeError*
        if *load_params* is not a dictionary.
        """
        BaseLogging.__init__(self, secret=secret, folder=folder)
        self._predict_fct = predict_function
        self._log_features = log_features
        self._log_prediction = log_prediction
        self._load_fct = load_function
        self._load_params = {} if load_params is None else load_params
        # The model is loaded lazily by the first POST request.
        self._loaded_results = None
        self._version = version
        if ccall not in MachineLearningPost._call_convention:
            raise ValueError("ccall '{0}' must be in {1}".format(
                ccall, MachineLearningPost._call_convention))
        self._ccall = MachineLearningPost._call_convention[ccall]
        if not isinstance(self._load_params, dict):
            raise TypeError("load_params must be a dictionary.")

    @staticmethod
    def data2json(data):
        """
        :epkg:`numpy:array` cannot be converted into
        :epkg:`json`. We change the type into a list.

        @param      data    any object
        @return             ``dict(shape=..., data=...)`` for a
                            :epkg:`numpy:array`, *data* unchanged otherwise
        """
        if isinstance(data, numpy.ndarray):
            return dict(shape=data.shape, data=data.tolist())
        return data

    def _load(self):
        """
        Calls the load function with the stored loading parameters.
        """
        return self._load_fct(**(self._load_params))

    def _predict_single(self, obj, features):
        """
        Calls the prediction function on one observation. The observation
        is wrapped into a list when the function only accepts
        sequences (*ccall='multi'*).
        """
        if self._ccall == MachineLearningPost._call_convention['multi']:
            return self._predict_fct(obj, [features])
        return self._predict_fct(obj, features)

    def check_single(self, features):
        """
        Checks the sequence load + predict returns
        something with the given observations.
        A fresh model is loaded on every call; the cached one is not used.
        """
        obj = self._load()
        return self._predict_single(obj, features)

    @staticmethod
    def _truncate_message(text, limit):
        """
        Returns *text* cut to *limit* characters followed by ``'...'``
        when it is longer, *text* unchanged otherwise.
        """
        if len(text) > limit:
            return text[:limit] + '...'
        return text

    def _read_request_body(self, req, add_log_data):
        """
        Reads the whole request body by chunks of 64 KiB and returns it
        (``b''`` when the request has no body).
        """
        chunks = []
        try:
            while True:
                chunk = req.stream.read(2**16)
                if len(chunk) == 0:
                    break
                chunks.append(chunk)
        except AssertionError as e:
            excs = traceback.format_exc()
            log_data = dict(error=str(e))
            log_data.update(add_log_data)
            if self._load_params:
                log_data['load_params'] = self._load_params
            self.error("ML.load", log_data)
            raise falcon.HTTPBadRequest(
                title='Unable to retrieve request content due to: {0}'.format(
                    MachineLearningPost._truncate_message(str(e), 400)),
                description=excs)
        # WSGI input streams yield bytes (PEP 3333); joining once avoids
        # the quadratic cost of repeated concatenation.
        return b''.join(chunks)

    @staticmethod
    def _parse_features(body):
        """
        Decodes the JSON body and returns the value stored under key ``'X'``.
        Raises ``falcon.HTTPBadRequest`` when the body is empty, is not
        valid JSON or does not contain key ``'X'``.
        """
        if not body:
            raise falcon.HTTPBadRequest(
                title='Empty request',
                description="The request body must be a JSON object with a key 'X'.")
        try:
            args = json_loads(body)
        except (ValueError, TypeError) as e:
            raise falcon.HTTPBadRequest(
                title='Unable to decode the request as JSON',
                description=MachineLearningPost._truncate_message(str(e), 200)) from e
        try:
            return args["X"]
        except (KeyError, TypeError) as e:
            # TypeError covers bodies which decode to a list, a string, ...
            raise falcon.HTTPBadRequest(
                title="Missing key 'X' in the request",
                description="The request body must be a JSON object with a key 'X'.") from e

    def _load_model_once(self, add_log_data):
        """
        Loads the model on the first call only, logs the loading time
        and caches the result in ``self._loaded_results``.
        """
        if self._loaded_results is not None:
            return
        self.save_time()
        try:
            self._loaded_results = self._load()
        except Exception as e:
            excs = traceback.format_exc()
            log_data = dict(duration=self.duration(), error=str(e))
            log_data.update(add_log_data)
            if self._load_params:
                log_data['load_params'] = self._load_params
            self.error("ML.load", log_data)
            raise falcon.HTTPBadRequest(
                title='Unable to load due to: {0}'.format(
                    MachineLearningPost._truncate_message(str(e), 200)),
                description=excs)
        log_data = dict(duration=self.duration())
        log_data.update(add_log_data)
        if self._load_params:
            log_data['load_params'] = self._load_params
        self.info("ML.load", log_data)

    def _compute_prediction(self, X, add_log_data):
        """
        Runs the prediction on *X* and logs its duration. Depending on the
        constructor options, the features and the prediction are logged
        too. Returns the prediction.
        """
        self.save_time()
        try:
            res = self._predict_single(self._loaded_results, X)
        except Exception as e:
            excs = traceback.format_exc()
            log_data = dict(duration=self.duration())
            log_data.update(add_log_data)
            if self._log_features:
                log_data['X'] = MachineLearningPost.data2json(X)
            log_data["error"] = str(e)
            self.error("ML.predict", log_data)
            raise falcon.HTTPBadRequest(
                title='Unable to predict due to: {0}'.format(
                    MachineLearningPost._truncate_message(str(e), 200)),
                description=excs)
        log_data = dict(duration=self.duration())
        log_data.update(add_log_data)
        if self._log_features:
            log_data['X'] = MachineLearningPost.data2json(X)
        if self._log_prediction:
            log_data['Y'] = MachineLearningPost.data2json(res)
        self.info("ML.predict", log_data)
        return res

    @staticmethod
    def _serialize_answer(answer, res):
        """
        Serializes *answer* with ``json_dumps``. When it raises
        *OverflowError*, the standard :epkg:`json` module is used instead.
        *OverflowError* is raised only if both fail.
        """
        try:
            return json_dumps(answer)
        except OverflowError:
            try:
                return json.dumps(answer)
            except Exception as ee:
                raise OverflowError(
                    'res probably contains numpy arrays or numpy.types ({0}), they cannot be serialized.'.format(type(res))) from ee

    def on_post(self, req, resp):
        """
        Handles a POST request.

        The request body is read and decoded as :epkg:`json`, and the
        features are taken from key ``'X'``. The model is loaded on the
        first call. The answer ``{"Y": prediction}`` (plus ``".version"``
        when a version is set) is written with status 201.

        @param      req     request
        @param      resp    response

        Raises ``falcon.HTTPBadRequest`` when the body cannot be read or
        decoded, when the model cannot be loaded or when the prediction fails.

        NOTE(review): for read, load and predict failures, the server
        traceback is returned to the client in the error description.
        Consider removing it in production.
        """
        add_log_data = dict(user=req.get_header('uid'), ip=req.access_route)
        if self._version is not None:
            # Adds the version to the log data. Replacing the dictionary by
            # the version made every later ``log_data.update`` call fail.
            add_log_data['version'] = self._version

        # To get the parameters
        # req.get_params
        body = self._read_request_body(req, add_log_data)
        X = MachineLearningPost._parse_features(body)

        # load the model
        self._load_model_once(add_log_data)

        # predict
        res = self._compute_prediction(X, add_log_data)

        # see http://falcon.readthedocs.io/en/stable/api/request_and_response.html
        resp.status = falcon.HTTP_201
        answer = {"Y": res}
        if self._version is not None:
            answer[".version"] = self._version
        js = MachineLearningPost._serialize_answer(answer, res)
        # NOTE(review): ``resp.body`` is deprecated since falcon 3.0
        # (``resp.text``); kept until the supported falcon range is confirmed.
        resp.body = js