Coverage for src/lightmlrestapi/mlapp/mlstorage.py: 82%

232 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-06 07:16 +0200

1""" 

2@file 

3@brief Machine Learning Post request 

4""" 

5import os 

6import sys 

7import json 

8import threading 

9import importlib 

10from datetime import datetime 

11# from filelock import Timeout, FileLock 

12from ..args.zip_helper import unzip_bytes 

13 

14 

class AlreadyExistsException(Exception):
    """
    Signals an attempt to create a project whose name is
    already taken in the storage.
    """

20 

21 

class ZipStorage:
    """
    Stores and restores zipped files.

    Every project lives in a sub folder of the storage folder.
    A file ``.desc`` in that sub folder marks it as a project:
    its first line is a comment holding the metadata as JSON,
    the following lines list the stored files.
    """

    def __init__(self, folder):
        """
        @param      folder      folder
        """
        self._folder = folder

    def enumerate_names(self):
        """
        Enumerates the names of the stored projects, every sub folder
        containing a file ``.desc``. Names are relative to the storage
        folder and always use ``'/'`` as a separator.
        """
        for root, dirs, _ in os.walk(self._folder):
            for name in dirs:
                sub = os.path.join(root, name)
                if os.path.exists(os.path.join(sub, ".desc")):
                    yield os.path.relpath(sub, self._folder).replace("\\", "/")

    def exists(self, name):
        """
        Tells if project *name* exists.

        @param      name        name
        @return                 boolean
        """
        # The description file cannot exist if the folder does not.
        return os.path.exists(os.path.join(self.get_full_name(name), ".desc"))

    def get_full_name(self, name):
        """
        Returns the full name of a project.

        @param      name        project name
        @return                 full name
        """
        return os.path.join(self._folder, name)

    def _check_name(self, name, data=False):
        """
        A name is valid if it is a variable name
        or a filename if *data* is True.
        Letters, ``'_'`` and ``'/'`` are always allowed, digits are
        allowed everywhere but in first position, ``'.'`` is only
        allowed if *data* is True. A name cannot start with ``'/'``
        nor contain a segment ``'..'`` because it would point
        outside of the storage folder.

        @param      name        name to check
        @param      data        the name is a filename
        @raises                 ValueError if the name is not valid
        """
        if name is None or not isinstance(name, str) or len(name) == 0:
            raise ValueError("name cannot be empty.")
        for i, c in enumerate(name):
            if "a" <= c <= "z":
                continue
            if "A" <= c <= "Z":
                continue
            if "0" <= c <= "9" and i > 0:
                continue
            if c in '_/':
                continue
            if c == '.' and data:
                continue
            raise ValueError(
                "A name contains a forbidden character '{0}'".format(name))
        # os.path.join drops the storage folder for an absolute name
        # and '..' walks up: both would escape the storage folder.
        if name.startswith("/") or ".." in name.split("/"):
            raise ValueError(
                "A name cannot start with '/' or contain '..': '{0}'".format(name))

    def verify_data(self, data):
        """
        Performs verifications to ensure the data to store
        is ok.

        @param      data        dictionary ``{ filename: bytes }``
        @return                 None or information about the data
        @raises                 TypeError or ValueError if not ok
        """
        if not isinstance(data, dict):
            raise TypeError("data must be a dictionary.")
        for k, v in data.items():
            if not isinstance(k, str):
                raise TypeError("Key must be a string.")
            self._check_name(k, data=True)
            if not isinstance(v, bytes):
                raise TypeError(
                    "Values must be bytes for key '{0}'.".format(k))
        return {}

    def _makedirs(self, subfold):
        """
        Creates a subfolder and adds a file ``__init__.py``
        at every level. If the file ``__init__.py`` already exists,
        the function appends a new function to it
        to let the interpreter know there were some changes.

        @param      subfold     sub folder, relative to the storage folder
        """
        fold = self._folder
        for part in subfold.replace("\\", "/").split("/"):
            fold = os.path.join(fold, part)
            init = os.path.join(fold, '__init__.py')
            if not os.path.exists(fold):
                os.mkdir(fold)
            if not os.path.exists(init):
                # New folder, or existing folder which was not a package yet.
                with open(init, 'w', encoding="utf-8") as f:
                    f.write('def do_exists():\n print("do exists")\n')
                continue
            with open(init, "r", encoding="utf-8") as f:
                content = f.read()
            # The number of occurrences gives a unique suffix
            # to the function which is appended.
            count = len(content.split('do_exists'))
            content += '\ndef do_exists{0}():\n print("do exists{0}")\n'.format(
                count)
            with open(init, "w", encoding="utf-8") as f:
                f.write(content)

    def add(self, name, data):
        """
        Adds a project based on the data.
        A project which already exists cannot be added.

        @param      name        project name, should only contain
                                ascii characters + ``'/'``
        @param      data        dictionary or bytes produced by
                                function @see fn zip_dict
        @raises                 AlreadyExistsException if the project exists,
                                TypeError or ValueError if the name
                                or the data are not valid
        """
        # Verifications.
        self._check_name(name)
        if self.exists(name):
            raise AlreadyExistsException(
                "Project '{0}' already exists.".format(name))
        if isinstance(data, bytes):
            data = unzip_bytes(data)
        dump = self.verify_data(data)

        # Creates the folder and the description file.
        full = self.get_full_name(name)
        self._makedirs(name)
        desc = os.path.join(full, ".desc")
        with open(desc, "w", encoding="utf-8") as fd:
            fd.write("# ")
            if dump is not None:
                json.dump(dump, fd)
            fd.write("\n")

            # Stores every file and registers it in the description file.
            for k, v in sorted(data.items()):
                subn = "{0}/{1}".format(name, k)
                self._check_name(subn, data=True)
                fd.write("{0}\n".format(k))
                n = self.get_full_name(subn)
                # A key may point into a sub folder which does not exist yet.
                os.makedirs(os.path.dirname(n), exist_ok=True)
                with open(n, "wb") as f:
                    f.write(v)

    def get(self, name):
        """
        Retrieves a project based on its name.

        @param      name        project name
        @return                 data, dictionary ``{ filename: bytes }``
        @raises                 FileNotFoundError if the project does not exist
        """
        if not self.exists(name):
            raise FileNotFoundError(
                "Project '{0}' does not exist.".format(name))
        full = self.get_full_name(name)
        desc = os.path.join(full, ".desc")
        with open(desc, "r", encoding="utf-8") as fd:
            lines = fd.readlines()
        res = {}
        for line in lines:
            # Comments hold the metadata, not a filename.
            if line.startswith("#"):
                continue
            line = line.strip("\r\n ")
            if line:
                with open(os.path.join(full, line), "rb") as f:
                    res[line] = f.read()
        return res

    def get_metadata(self, name):
        """
        Restores the data produced by *verify_data*,
        stored as JSON in the first line of file ``.desc``.

        @param      name        project name
        @return                 metadata
        @raises                 FileNotFoundError if the project does not exist
        """
        if not self.exists(name):
            raise FileNotFoundError(
                "Project '{0}' does not exist.".format(name))
        desc = os.path.join(self.get_full_name(name), ".desc")
        with open(desc, "r", encoding="utf-8") as f:
            first_line = f.readline().strip("# \n")
        return json.loads(first_line)

216 

217 

class MLStorage(ZipStorage):
    """
    Stores machine learned models into folders. The storage
    expects to find at least one :epkg:`python` script following
    the specifications described at :ref:`l-mlapp-def`.
    More templates for actionable machine learned models
    are available at :ref:`l-template-ml`.

    Loaded models are kept in a cache protected by a lock.
    """

    def __init__(self, folder, cache_size=10):
        """
        @param      folder      folder
        @param      cache_size  cache size
        """
        ZipStorage.__init__(self, folder)
        self._cache_size = cache_size
        self._cache = {}
        # Not reentrant: a method holding it must not call
        # another method which acquires it.
        self._lock = threading.Lock()

    def verify_data(self, data):
        """
        Performs verifications to ensure the data to store
        is ok. The storage expects to find at least one python
        script which contains ``def restapi_version():``.

        @param      data        dictionary
        @return                 metadata with the python file which
                                describes the model (key *main_script*)
        @raises                 RuntimeError if no such script was found
        """
        res = ZipStorage.verify_data(self, data)
        main_script = None
        for k, v in data.items():
            if k.endswith(".py"):
                content = v.decode("utf-8")
                if "def restapi_version():" in content:
                    main_script = k
                    break
        if main_script is None:
            sorted_keys = ", ".join(sorted(data.keys()))
            raise RuntimeError(
                "Unable to find a script with 'def restapi_version():' inside.. List of found keys is {0}.".format(sorted_keys))
        res.update(dict(main_script=main_script))
        return res

    def empty_cache(self):
        """
        Removes one place in the cache if the cache
        is full. The model with the oldest last access is removed.
        """
        with self._lock:
            if not self._cache or len(self._cache) < self._cache_size:
                return
            # Ties on the date are broken by the name.
            _, oldest = min((v['last'], k) for k, v in self._cache.items())
            del self._cache[oldest]

    def _import(self, name):
        """
        Imports the main module for one model.
        The storage folder is added to ``sys.path`` during the import
        and removed afterwards, whether the import succeeds or not.

        @param      name    model name
        @return             imported module
        @raises             FileNotFoundError if the main script is missing,
                            ImportError if the module cannot be imported
                            or does not define *restapi_load*
        """
        # importlib.util is a submodule, *import importlib* alone
        # does not guarantee it is loaded.
        from importlib.util import find_spec

        meta = self.get_metadata(name)
        loc = self.get_full_name(name)
        script = os.path.join(loc, meta['main_script'])
        if not os.path.exists(script):
            raise FileNotFoundError(
                "Unable to find script '{0}'".format(script))

        fold, modname = os.path.split(script)
        full_modname = ".".join([name.replace("/", "."),
                                 os.path.splitext(modname)[0]])

        def import_module():
            try:
                return importlib.import_module(full_modname)
            except ImportError as e:
                # The script was checked to be utf-8 by verify_data.
                with open(script, "r", encoding="utf-8") as f:
                    code = f.read()
                values = dict(self_folder=self._folder, name=name, meta=str(meta),
                              loc=loc, script=script, fold=fold, modname=modname,
                              full_modname=full_modname)
                values = '\n'.join('{}={}'.format(k, v)
                                   for k, v in values.items())
                raise ImportError(
                    "Unable to compile file '{0}'\ndue to {1}\n{2}\n---\n{3}".format(script, e, code, values)) from e

        sys.path.insert(0, self._folder)
        try:
            try:
                mod = import_module()
            except ImportError:
                # Reloads the parent packages: they may have been imported
                # before the model was added.
                specs = []
                spl = full_modname.split('.')
                for i in range(len(spl) - 1):
                    parent = '.'.join(spl[:i + 1])
                    if parent in sys.modules:
                        del sys.modules[parent]
                    importlib.invalidate_caches()
                    specs.append((parent, find_spec(parent)))
                    importlib.reload(importlib.import_module(parent))
                try:
                    mod = import_module()
                except ImportError as ee:
                    mes = "\n".join("{0}: {1}".format(a, b) for a, b in specs)
                    raise ImportError("Unable to import module '{0}', specs=\n{1}".format(
                        full_modname, mes)) from ee
        finally:
            # The imported code may have changed sys.path:
            # removes the folder by value, not by position.
            try:
                sys.path.remove(self._folder)
            except ValueError:
                pass

        if not hasattr(mod, "restapi_load"):
            raise ImportError(
                "Unable to find function 'restapi_load' in module '{0}'".format(mod.__name__))
        return mod

    def load_model(self, name, was_loaded=False):
        """
        Loads a model into the cache if not loaded
        and returns it.

        @param      name        cache name
        @param      was_loaded  if True, tells if the model was loaded again
        @return                 dictionary with keys: *last*, *model*, *module*,
                                or this dictionary and a boolean
                                if *was_loaded* is True
        """
        # Lookup and update happen under the same lock: the model
        # cannot be removed from the cache in between.
        with self._lock:
            res = self._cache.get(name)
            if res is not None:
                res['last'] = datetime.now()
        if res is not None:
            return (res, False) if was_loaded else res

        # Acquires the lock itself, must be called without holding it.
        self.empty_cache()

        with self._lock:
            # Imports the module.
            mod = self._import(name)
            # Loads the model.
            model = mod.restapi_load()
            res = dict(last=datetime.now(), model=model, module=mod)
            self._cache[name] = res
        return (res, True) if was_loaded else res

    def call_predict(self, name, data, version=False, was_loaded=False, loaded_model=None):
        """
        Calls method *restapi_predict* from a stored script *python*.

        @param      name            model name
        @param      data            input data
        @param      version         returns the version as well
        @param      was_loaded      if True, return if the model was loaded again
        @param      loaded_model    skip cached model if exists, should be the result of
                                    a previous call to @see me load_model
        @return                     *predictions*, followed by the *version*
                                    if *version* is True, followed by a boolean
                                    if *was_loaded* is True
        """
        if loaded_model is None:
            res = self.load_model(name, was_loaded=was_loaded)
            if was_loaded:
                res, loaded = res
        else:
            res, loaded = loaded_model, False
        pred = res['module'].restapi_predict(res['model'], data)
        if version:
            version = res['module'].restapi_version()
            if was_loaded:
                return pred, version, loaded
            return pred, version
        if was_loaded:
            return pred, loaded
        return pred

    def call_version(self, name):
        """
        Calls method *restapi_version* from a stored script *python*.

        @param      name        model name
        @return                 what *restapi_version* returns
        """
        res = self.load_model(name)
        return res['module'].restapi_version()

418 """ 

419 res = self.load_model(name) 

420 return res['module'].restapi_version()