Coverage for src/lightmlrestapi/mlapp/mlstorage.py: 82%
232 statements
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-06 07:16 +0200
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-06 07:16 +0200
1"""
2@file
3@brief Machine Learning Post request
4"""
5import os
6import sys
7import json
8import threading
9import importlib
10from datetime import datetime
11# from filelock import Timeout, FileLock
12from ..args.zip_helper import unzip_bytes
class AlreadyExistsException(Exception):
    """
    Raised when a project is added under a name
    which is already used by a stored project.
    """
class ZipStorage:
    """
    Stores and restores zipped files.

    Every project is a subfolder of the storage folder. A project folder
    contains a file ``.desc`` (a first commented line holding metadata
    in JSON format, then one stored file name per line) and the
    stored files themselves.
    """

    def __init__(self, folder):
        """
        @param      folder      folder
        """
        self._folder = folder

    def enumerate_names(self):
        """
        Enumerates the names of the stored projects, every subfolder
        which contains a file ``.desc``. Names are relative to the
        storage folder and use ``'/'`` as a separator.
        """
        for root, dirs, _ in os.walk(self._folder):
            for name in dirs:
                location = os.path.join(root, name)
                if os.path.exists(os.path.join(location, ".desc")):
                    zoo = os.path.relpath(location, self._folder)
                    yield zoo.replace("\\", "/")

    def exists(self, name):
        """
        Tells if project *name* exists.

        @param      name        name
        @return                 boolean
        """
        full = self.get_full_name(name)
        return os.path.exists(full) and os.path.exists(
            os.path.join(full, ".desc"))

    def get_full_name(self, name):
        """
        Returns the full name of a project.

        @param      name        project name
        @return                 full name
        """
        return os.path.join(self._folder, name)

    def _check_name(self, name, data=False):
        """
        A name is valid if it is a variable name
        or a filename if *data* is True.
        A name cannot be absolute nor contain a segment ``'..'``
        as it would point outside the storage folder.

        @param      name        name to check
        @param      data        if True, character ``'.'`` is allowed
        @raises                 ValueError if the name is not valid
        """
        if not isinstance(name, str) or not name:
            raise ValueError("name cannot be empty.")
        for i, c in enumerate(name):
            if "a" <= c <= "z" or "A" <= c <= "Z":
                continue
            if "0" <= c <= "9" and i > 0:
                continue
            if c in '_/':
                continue
            if c == '.' and data:
                continue
            raise ValueError(
                "A name contains a forbidden character '{0}'".format(name))
        # os.path.join drops the storage folder if the name is absolute
        # and '..' climbs out of it: both must be rejected.
        if name.startswith("/") or ".." in name.split("/"):
            raise ValueError(
                "A name cannot be absolute nor contain '..' ('{0}')".format(name))

    def verify_data(self, data):
        """
        Performs verifications to ensure the data to store
        is ok.

        @param      data        dictionary ``{ filename: bytes }``
        @return                 None or information about the data
        @raises                 raises an exception if not ok
        """
        if not isinstance(data, dict):
            raise TypeError("data must be a dictionary.")
        for k, v in data.items():
            if not isinstance(k, str):
                raise TypeError("Key must be a string.")
            self._check_name(k, data=True)
            if not isinstance(v, bytes):
                raise TypeError(
                    "Values must be bytes for key '{0}'.".format(k))
        return {}

    def _makedirs(self, subfold):
        """
        Creates a subfolder and add a file ``__init__.py``
        in every folder along the path.
        The function overwrites the file ``__init__.py`` if it exists
        to let the interpreter know there was some changes.

        @param      subfold     subfolder relative to the storage folder
        """
        fold = self._folder
        for part in subfold.replace("\\", "/").split("/"):
            fold = os.path.join(fold, part)
            init = os.path.join(fold, '__init__.py')
            if not os.path.exists(fold):
                os.mkdir(fold)
            if os.path.exists(init):
                with open(init, "r") as f:
                    content = f.read()
                # A new function is appended so that the content changes.
                content += '\ndef do_exists{0}():\n print("do exists{0}")\n'.format(
                    len(content.split('do_exists')))
            else:
                # New folder or existing folder without any __init__.py.
                content = 'def do_exists():\n print("do exists")\n'
            with open(init, "w") as f:
                f.write(content)

    def add(self, name, data):
        """
        Adds a project based on the data.
        A project which already exists cannot be added.

        @param      name        project name, should only contain
                                ascii characters + ``'/'``
        @param      data        dictionary or bytes produced by
                                function @see fn zip_dict
        @raises                 AlreadyExistsException if the project exists,
                                ValueError or TypeError if the name
                                or the data are not valid
        """
        # Verifications.
        self._check_name(name)
        if self.exists(name):
            raise AlreadyExistsException(
                "Project '{0}' already exists.".format(name))
        if isinstance(data, bytes):
            data = unzip_bytes(data)
        dump = self.verify_data(data)

        # Creates the folder and the description file.
        full = self.get_full_name(name)
        self._makedirs(name)
        desc = os.path.join(full, ".desc")
        with open(desc, "w", encoding="utf-8") as fd:
            fd.write("# ")
            if dump is not None:
                json.dump(dump, fd)
            fd.write("\n")

        # Stores the files and registers them in the description file.
        with open(desc, "a", encoding="utf-8") as fd:
            for k, v in sorted(data.items()):
                subn = "{0}/{1}".format(name, k)
                self._check_name(subn, data=True)
                fd.write("{0}\n".format(k))
                n = self.get_full_name(subn)
                # A key may contain '/': the subfolder must exist.
                os.makedirs(os.path.dirname(n), exist_ok=True)
                with open(n, "wb") as f:
                    f.write(v)

    def get(self, name):
        """
        Retrieves a project based on its name.

        @param      name        project name
        @return                 data, dictionary ``{ filename: bytes }``
        @raises                 FileNotFoundError if the project does not exist
        """
        # exists checks the folder and the file .desc.
        if not self.exists(name):
            raise FileNotFoundError(
                "Project '{0}' does not exist.".format(name))
        full = self.get_full_name(name)
        desc = os.path.join(full, ".desc")
        with open(desc, "r", encoding="utf-8") as fd:
            lines = fd.readlines()
        res = {}
        for line in lines:
            # Commented lines hold metadata, not file names.
            if line.startswith("#"):
                continue
            line = line.strip("\r\n ")
            if line:
                with open(os.path.join(full, line), "rb") as f:
                    res[line] = f.read()
        return res

    def get_metadata(self, name):
        """
        Restores the data produced by *verify_data*
        and stored on the first line of file ``.desc``.

        @param      name        project name
        @return                 metadata
        @raises                 FileNotFoundError if the project does not exist
        """
        # exists checks the folder and the file .desc.
        if not self.exists(name):
            raise FileNotFoundError(
                "Project '{0}' does not exist.".format(name))
        desc = os.path.join(self.get_full_name(name), ".desc")
        with open(desc, "r", encoding="utf-8") as f:
            first_line = f.readline().strip("# \n")
        return json.loads(first_line)
class MLStorage(ZipStorage):
    """
    Stores machine learned models into folders. The storage
    expects to find at least one :epkg:`python` script following
    the specifications described at :ref:`l-mlapp-def`.
    More templates for actionable machine learned models
    can be found at :ref:`l-template-ml`.
    Loaded models are kept in a cache protected by a lock.
    """

    def __init__(self, folder, cache_size=10):
        """
        @param      folder      folder
        @param      cache_size  cache size
        """
        ZipStorage.__init__(self, folder)
        self._cache_size = cache_size
        self._cache = {}
        self._lock = threading.Lock()

    def verify_data(self, data):
        """
        Performs verifications to ensure the data to store
        is ok. The storage expects to find at least one python script
        which contains ``def restapi_version():``.

        @param      data        dictionary
        @return                 dictionary with key *main_script*,
                                the python file which describes the model
        @raises                 raises an exception if not ok
        """
        res = ZipStorage.verify_data(self, data)
        main_script = None
        for k, v in data.items():
            if k.endswith(".py"):
                content = v.decode("utf-8")
                if "def restapi_version():" in content:
                    main_script = k
                    break
        if main_script is None:
            sorted_keys = ", ".join(sorted(data.keys()))
            raise RuntimeError(
                "Unable to find a script with 'def restapi_version():' inside.. List of found keys is {0}.".format(sorted_keys))
        res.update(dict(main_script=main_script))
        return res

    def empty_cache(self):
        """
        Removes one place in the cache if the cache
        is full. The least recently accessed model is removed.
        """
        # The check and the removal happen under the same lock,
        # the lock is released even if an exception is raised.
        with self._lock:
            if not self._cache or len(self._cache) < self._cache_size:
                return
            oldest = min((v['last'], k) for k, v in self._cache.items())
            del self._cache[oldest[1]]

    def _import(self, name):
        """
        Imports the main module for one model.

        @param      name        model name
        @return                 imported module
        @raises                 FileNotFoundError if the main script is missing,
                                ImportError if the module cannot be imported
                                or does not define *restapi_load*
        """
        # importlib.util is a submodule, 'import importlib' alone
        # does not guarantee it is loaded.
        from importlib.util import find_spec

        meta = self.get_metadata(name)
        loc = self.get_full_name(name)
        script = os.path.join(loc, meta['main_script'])
        if not os.path.exists(script):
            raise FileNotFoundError(
                "Unable to find script '{0}'".format(script))

        fold, modname = os.path.split(script)
        full_modname = ".".join([name.replace("/", "."),
                                 os.path.splitext(modname)[0]])

        def import_module():
            try:
                return importlib.import_module(full_modname)
            except ImportError as e:
                with open(script, "r") as f:
                    code = f.read()
                values = dict(self_folder=self._folder, name=name, meta=str(meta),
                              loc=loc, script=script, fold=fold, modname=modname,
                              full_modname=full_modname)
                values = '\n'.join('{}={}'.format(k, v)
                                   for k, v in values.items())
                raise ImportError(
                    "Unable to compile file '{0}'\ndue to {1}\n{2}\n---\n{3}".format(script, e, code, values)) from e

        sys.path.insert(0, self._folder)
        try:
            try:
                mod = import_module()
            except ImportError:
                # Parent packages may be cached in an older state:
                # they are reloaded before a second attempt.
                specs = []
                spl = full_modname.split('.')
                for i in range(1, len(spl)):
                    parent = '.'.join(spl[:i])
                    if parent in sys.modules:
                        del sys.modules[parent]
                    importlib.invalidate_caches()
                    specs.append((parent, find_spec(parent)))
                    importlib.reload(importlib.import_module(parent))
                try:
                    mod = import_module()
                except ImportError as ee:
                    mes = "\n".join("{0}: {1}".format(a, b) for a, b in specs)
                    raise ImportError("Unable to import module '{0}', specs=\n{1}".format(
                        full_modname, mes)) from ee
        finally:
            # Always restores sys.path. The entry is removed by value
            # as the imported code may have modified sys.path.
            if self._folder in sys.path:
                sys.path.remove(self._folder)

        if not hasattr(mod, "restapi_load"):
            raise ImportError(
                "Unable to find function 'restapi_load' in module '{0}'".format(mod.__name__))
        return mod

    def load_model(self, name, was_loaded=False):
        """
        Loads a model into the cache if not loaded
        and returns it.

        @param      name        cache name
        @param      was_loaded  if True, tells if the model was loaded again
        @return                 dictionary with keys: *last*, *model*, *module*,
                                and a boolean if *was_loaded* is True
        """
        # Lookup and update under the same lock to avoid a race
        # with another thread removing the model from the cache.
        with self._lock:
            res = self._cache.get(name)
            if res is not None:
                res['last'] = datetime.now()
        if res is not None:
            return (res, False) if was_loaded else res

        self.empty_cache()

        # Imports the module, loads the model, caches both.
        with self._lock:
            mod = self._import(name)
            model = mod.restapi_load()
            res = dict(last=datetime.now(), model=model, module=mod)
            self._cache[name] = res
        return (res, True) if was_loaded else res

    def call_predict(self, name, data, version=False, was_loaded=False, loaded_model=None):
        """
        Calls method *restapi_predict* from a stored script *python*.

        @param      name            model name
        @param      data            input data
        @param      version         returns the version as well
        @param      was_loaded      if True, return if the model was loaded again
        @param      loaded_model    skip cached model if exists, should be the result of
                                    a previous call to @see me load_model
        @return                     *predictions* or *predictions, version*,
                                    followed by a boolean if *was_loaded* is True
        """
        loaded = False
        if loaded_model is None:
            res = self.load_model(name, was_loaded=was_loaded)
            if was_loaded:
                res, loaded = res
        else:
            res = loaded_model
        pred = res['module'].restapi_predict(res['model'], data)
        if version:
            version = res['module'].restapi_version()
            if was_loaded:
                return pred, version, loaded
            return pred, version
        if was_loaded:
            return pred, loaded
        return pred

    def call_version(self, name):
        """
        Calls method *restapi_version* from a stored script *python*.

        @param      name        model name
        @return                 result of *restapi_version*
        """
        res = self.load_model(name)
        return res['module'].restapi_version()