Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Template application for a machine learning model 

4based on :epkg:`torch` available through a REST API. 

5""" 

6import os 

7import numpy 

8import skimage.transform as skt 

9 

10 

# Declare an id for the REST API.
_RESTAPI_VERSION = "0.1.1238"


def restapi_version():
    """
    Returns the version identifier of this template application,
    which the REST API exposes.
    """
    return _RESTAPI_VERSION

17 

18 

# Declare a loading function.
def restapi_load(files=None):
    """
    Loads the model.
    The model name is relative to this file.
    When called by a REST API, the default value is always used.

    :param files: dictionary whose key ``"model"`` holds the model filename,
        relative to the directory containing this file; when None,
        ``{"model": "dlmodel.torch"}`` is used
    :return: the object returned by :epkg:`torch` ``torch.load``
    :raises FileNotFoundError: if the model file does not exist
    """
    if files is None:
        # Build a fresh dict on every call instead of sharing one mutable
        # default object across calls.
        files = {"model": "dlmodel.torch"}
    model = files["model"]
    here = os.path.dirname(__file__)
    model = os.path.join(here, model)
    if not os.path.exists(model):
        raise FileNotFoundError("Cannot find model '{0}' (full path is '{1}')".format(
            model, os.path.abspath(model)))
    # Imported here rather than at module level, so torch is only required
    # once a model file has actually been found.
    import torch  # pylint: disable=E0401,C0415
    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code; only load model files from a trusted source.
    loaded_model = torch.load(model)
    return loaded_model

35 

# Declare a predict function.


def restapi_predict(model, X):
    """
    Computes the prediction of *model* for one image *X*.

    :param model: object exposing a ``forward`` method that accepts a
        :epkg:`torch` tensor (presumably a ``torch.nn.Module``)
    :param X: image as a :epkg:`numpy` array
    :return: output of ``model.forward`` converted to a :epkg:`numpy` array
    :raises TypeError: if *X* is not a :epkg:`numpy` array
    """
    if not isinstance(X, numpy.ndarray):
        raise TypeError("X must be an array")
    from torch import from_numpy, no_grad  # pylint: disable=E0611,E0401,C0415
    # Resize to shape (3, 224, 224).
    # NOTE(review): the target shape is channel-first; if X arrives as
    # (height, width, channels) every axis is interpolated independently,
    # which would not be a simple spatial resize. Confirm the expected layout of X.
    im = skt.resize(X, (3, 224, 224))
    # Add a leading batch axis -> (1, 3, 224, 224).
    im = im[numpy.newaxis, :, :, :]
    ten = from_numpy(im.astype(numpy.float32))
    # Inference only: the result is detached below, so skip building the
    # autograd graph.
    with no_grad():
        pred = model.forward(ten)
    return pred.detach().numpy()