Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Template application for a machine learning model 

4based on :epkg:`keras` available through a REST API. 

5""" 

6import os 

7import numpy 

8import skimage.transform as skt 

9 

10 

# Declare an id for the REST API.
def restapi_version():
    """
    Returns the version identifier of this REST API application.

    :return: version as a string
    """
    version = "0.1.1237"
    return version

17 

18 

# Declare a loading function.
def restapi_load(files={'model': "dlmodel.keras"}):  # pylint: disable=W0102
    """
    Loads the :epkg:`keras` model stored next to this module.

    :param files: mapping whose ``'model'`` entry is the model file name,
        resolved relative to the directory containing this file;
        when called by the REST API, the default value is always used
    :return: the object returned by ``keras.models.load_model``
    :raises FileNotFoundError: if the resolved model path does not exist
    """
    # Imported inside the function so that importing this module alone
    # does not require keras.
    from keras.models import load_model  # pylint: disable=E0401,E0611
    name = files["model"]
    location = os.path.join(os.path.dirname(__file__), name)
    if os.path.exists(location):
        return load_model(location)
    raise FileNotFoundError(
        "Cannot find model '{0}' (full path is '{1}')".format(
            location, os.path.abspath(location)))

35 

36 

# Declare a predict function.
def restapi_predict(model, X):
    """
    Computes the prediction of *model* for a single image.

    :param model: object exposing a ``predict`` method, presumably the
        :epkg:`keras` model returned by ``restapi_load``
    :param X: image as a :epkg:`numpy` array; the resize below assumes a
        channel-first layout ``(3, height, width)`` -- TODO confirm with callers,
        a channel-last ``(height, width, 3)`` image would be distorted
    :return: output of ``model.predict`` on a batch containing one
        image of shape ``(224, 224, 3)``
    :raises TypeError: if *X* is not a :epkg:`numpy` array
    """
    if not isinstance(X, numpy.ndarray):
        raise TypeError("X must be an array")
    # skimage resizes every axis to the requested shape, so the result is
    # exactly (3, 224, 224); axis 0 is assumed to hold the channels.
    im = skt.resize(X, (3, 224, 224))
    # Move axis 0 (assumed channels) last -> (224, 224, 3).
    im = numpy.transpose(im, (1, 2, 0))
    # Add a leading batch axis -> (1, 224, 224, 3).
    im = im[numpy.newaxis, :, :, :]
    return model.predict(im)