Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

"""
@file
@brief Machine Learning Post request
"""

5import traceback 

6import pickle 

7import base64 

8import falcon 

9import numpy 

10from ..args.args_images import string2bytes 

11from ..tools import json_loads, json_dumps 

12from .base_logging import BaseLogging 

13from .mlstorage import MLStorage 

14 

15 

class MLStoragePost(BaseLogging):
    """
    Implements a simple :epkg:`REST API` to
    upload zip files. The application assumes
    machine learning models are actionable through
    the following template: :ref:`l-template-ml`.

    Two commands are handled by @see me on_post through field
    ``cmd`` of the posted :epkg:`json`: ``'upload'`` and ``'predict'``.
    """

    def __init__(self, secret=None, folder='.',
                 folder_storage='.', version=None):
        """
        @param      secret          see @see cl BaseLogging
        @param      folder          see @see cl BaseLogging
        @param      folder_storage  see @see cl MLStorage
        @param      version         API REST version
        """
        BaseLogging.__init__(self, secret=secret, folder=folder)
        self._storage = MLStorage(folder_storage)
        self._version = version

    @staticmethod
    def data2json(data):
        """
        :epkg:`numpy:array` cannot be converted into
        :epkg:`json`. We change the type into a list.

        @param      data    any object
        @return             a dictionary with keys ``shape`` and ``data``
                            if *data* is a :epkg:`numpy:array`,
                            *data* itself otherwise
        """
        if isinstance(data, numpy.ndarray):
            return dict(shape=data.shape, data=data.tolist())
        return data

    @staticmethod
    def _shorten(text, limit=200):
        """
        Truncates *text* to *limit* characters and appends ``'...'``
        when it was longer, so that error messages stay readable.

        @param      text    string to shorten
        @param      limit   maximum number of characters kept
        @return             shortened string
        """
        if len(text) > limit:
            return text[:limit] + '...'
        return text

    def on_post(self, req, resp):
        """
        Processes a :epkg:`POST` request.

        @param      req     request, its body is a :epkg:`json` dictionary
                            with a field ``cmd`` equal to ``'upload'``
                            or ``'predict'``
        @param      resp    response, receives the status and the
                            :epkg:`json` answer

        Every failure is logged and raised as ``falcon.HTTPBadRequest``.
        """
        add_log_data = dict(user=req.get_header('uid'), ip=req.access_route)
        if self._version is not None:
            # The version is added to the logged fields, it must not
            # replace them: add_log_data is later given to dict.update.
            add_log_data['version'] = self._version

        # To get the parameters
        # req.get_params
        js = None
        try:
            while True:
                chunk = req.stream.read(2**16)
                if len(chunk) == 0:
                    break
                if js is None:
                    js = chunk
                else:
                    js += chunk
        except AssertionError as e:
            excs = traceback.format_exc()
            es = self._shorten(str(e))
            log_data = dict(error=str(e))
            log_data.update(add_log_data)
            self.error("ML.load", log_data)
            raise falcon.HTTPBadRequest(
                'Unable to retrieve request content due to: {0}'.format(es), excs)

        if js is None:
            # Nothing was read from the stream: there is no json to parse.
            es = "empty request body"
            log_data = dict(error=es)
            log_data.update(add_log_data)
            self.error("ML.load", log_data)
            raise falcon.HTTPBadRequest(
                'Unable to retrieve request content due to: {0}'.format(es))

        args = json_loads(js)
        self.save_time()
        duration = self.duration()

        command = args.pop('cmd', None)
        if command == 'upload':
            self.save_time()
            try:
                name = self._store(args)
            except Exception as e:
                excs = traceback.format_exc()
                es = self._shorten(str(e))
                duration = self.duration()
                log_data = dict(error=str(e), duration=duration, data=args)
                log_data.update(add_log_data)
                self.error("MLB.store", log_data)
                raise falcon.HTTPBadRequest(
                    "Unable to upload model due to: {}".format(es), excs)

            duration = self.duration()
            log_data = dict(duration=duration, name=name)
            self.info("MLB.store '%s'" % name, log_data)
            resp.status = falcon.HTTP_201
            answer = {"name": name}

        elif command == 'predict':
            self.save_time()
            try:
                name, pred, version, loaded = self._predict(args)
            except Exception as e:
                excs = traceback.format_exc()
                es = self._shorten(str(e))
                name = args.get('name', None)
                # The requested format comes from the request, not from
                # the builtin function ``format``.
                form = args.get('format', None)
                duration = self.duration()
                log_data = dict(error=str(e), duration=duration, data=args)
                log_data.update(add_log_data)
                if name is not None:
                    log_data['name'] = name
                if form is not None:
                    log_data['format'] = form
                if 'input' in args:
                    log_data['input'] = args['input']
                self.error("MLB.predict", log_data)
                raise falcon.HTTPBadRequest(
                    "Unable to predict with model '{0}' due to: {1}, format='{2}'".format(name, es, form), excs)

            duration = self.duration()
            log_data = dict(duration=duration, version=version, name=name)
            if loaded:
                self.info("MLB.predict '%s' loaded again" % name, log_data)
            else:
                self.info("MLB.predict '%s'" % name, log_data)
            resp.status = falcon.HTTP_201
            answer = {"output": pred, "version": version}
        else:
            es = "Unknown command '{0}'".format(command)
            log_data = dict(msg=es)
            log_data.update(add_log_data)
            self.error("MLStorage", log_data)
            raise falcon.HTTPBadRequest(
                'Unable to retrieve request content due to: {0}'.format(es))

        try:
            js = json_dumps(answer)
        except OverflowError as e:
            raise falcon.HTTPBadRequest(
                'Unable to retrieve request content due to: {0}'.format(e))
        resp.body = js

    def _store(self, args):
        """
        Stores the model in the storage.

        @param      args    dictionary, field ``name`` is the model name,
                            field ``zip`` is the stringified zip file,
                            both fields are removed from *args*
        @return             model name

        Raises ``KeyError`` if one of the two fields is missing
        or if the name is empty.
        """
        name = args.pop('name', None)
        if name is None or name == '':
            keys = ", ".join(sorted(args.keys()))
            try:
                ms = str(args)
            except ValueError:
                ms = ""
            if len(ms) > 300:
                ms = ms[:300] + "..."
            raise KeyError(
                "Unable to find a model name in sent data, keys={0}, args={1}.".format(keys, ms))
        zipped = args.pop('zip', None)
        if zipped is None:
            raise KeyError(
                "The REST API expects to find a zip file data in field 'zip'.")
        unstring = string2bytes(zipped)
        self._storage.add(name, unstring)
        return name

    def _predict(self, args):
        """
        Calls a stored model to compute a prediction.

        @param      args    dictionary, field ``name`` is the model name,
                            field ``input`` is the stringified data,
                            field ``format`` (default ``'json'``) tells how
                            to decode it: ``'json'``, ``'img'`` or ``'bytes'``
        @return             tuple *(name, prediction, version, loaded)*,
                            the last two values are returned by the storage,
                            a :epkg:`numpy:array` prediction is converted
                            into a list

        Raises ``KeyError`` if field ``name`` or ``input`` is missing,
        ``ValueError`` if the format is not recognized.
        """
        name = args.get('name', None)
        if name is None:
            raise KeyError("Unable to find a model name in sent data.")
        form = args.get('format', 'json')
        data = args.get('input', None)
        if data is None:
            raise KeyError("The REST API expects to find field 'input' which contains stringified data the "
                           "machine learned model can process. Field 'format' indicates which "
                           "preprocessing to do before calling the model which is currently '{0}'".format(form))
        if form == 'json':
            data = json_loads(data)
        elif form == 'img':
            simg = base64.b64decode(data)
            # SECURITY NOTE(review): pickle.loads executes arbitrary code
            # when the payload is crafted, and this payload comes from the
            # client. Only expose this format to trusted callers or
            # replace it by a safe image encoding.
            data = pickle.loads(simg)
        elif form == 'bytes':
            data = string2bytes(data)
        else:
            raise ValueError("Unrecognized format '{0}'.".format(form))
        res, version, loaded = self._storage.call_predict(
            name, data, version=True, was_loaded=True)
        if isinstance(res, numpy.ndarray):
            res = res.tolist()
        return name, res, version, loaded