Coverage for src/lightmlrestapi/testing/template_dl_light.py: 92%

25 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-06 07:16 +0200

1""" 

2@file 

3@brief Template application for a machine learning model 

4available through a REST API and using images like 

5deep learning models. 

6""" 

7import pickle 

8import os 

9import numpy 

10import skimage.transform as skt 

11 

12 

# Declare an id for the REST API.
def restapi_version():
    """
    Returns the version identifier exposed by this REST API template.

    :return: version string
    """
    version = "0.1.1235"
    return version

19 

20 

# Declare a loading function.
def restapi_load(files={"model": "dlimg.pkl"}):  # pylint: disable=W0102
    """
    Unpickles and returns the model stored in ``files['model']``.

    A relative name is resolved against the directory containing this
    file; an absolute name is used unchanged (``os.path.join`` semantics).
    When the REST API calls this function, it relies on the default value.

    :param files: dictionary whose ``'model'`` entry names the pickle file
    :return: the unpickled object
    :raises FileNotFoundError: if the resolved path does not exist
    """
    model_path = os.path.join(os.path.dirname(__file__), files['model'])
    if os.path.exists(model_path):
        # NOTE: pickle executes arbitrary code on load; the model file
        # must come from a trusted source.
        with open(model_path, "rb") as handle:
            return pickle.load(handle)
    raise FileNotFoundError("Cannot find model '{0}' (full path is '{1}')".format(
        model_path, os.path.abspath(model_path)))

37 

38 

# Declare a predict function.
def restapi_predict(model, X):
    """
    Compares image *X* with the reference image *model*.

    Both images are resized to the fixed shape ``(3, 224, 224)`` with
    :epkg:`scikit-image` (``skimage.transform.resize``). The function then
    returns the mean absolute difference between their pixels, divided
    by 255.

    :param model: reference image; presumably the object returned by
        :func:`restapi_load` (confirm against the REST API caller)
    :param X: image as a :epkg:`numpy` array
    :return: scalar, mean absolute pixel difference divided by 255
    :raises TypeError: if *X* is not a :class:`numpy.ndarray`
    """
    if not isinstance(X, numpy.ndarray):
        raise TypeError("X must be an array")
    # Both images are brought to the same shape so they can be compared
    # pixel by pixel whatever their original sizes (presumably channel-first).
    # NOTE(review): resize is called without preserve_range, so skimage may
    # rescale integer images (e.g. uint8) to [0, 1]. The division by 255
    # below would then normalize twice. Kept as-is for backward
    # compatibility; confirm intended.
    target_shape = (3, 224, 224)
    reference = skt.resize(model, target_shape)
    image = skt.resize(X, target_shape)
    abs_diff = numpy.abs(reference.ravel() - image.ravel())
    # Mean absolute difference, then scaled by 1/255.
    return abs_diff.sum() / float(len(abs_diff)) / 255