Hot-keys on this page:
r, m, x, p: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1"""
2@file
3@brief Template application for a machine learning model
4based on :epkg:`torch` available through a REST API.
5"""
6import os
7import numpy
8import skimage.transform as skt
11# Declare an id for the REST API.
def restapi_version():
    """
    Returns the version identifier exposed through the REST API.
    """
    version = "0.1.1238"
    return version
19# Declare a loading function.
def restapi_load(files={"model": "dlmodel.torch"}):  # pylint: disable=W0102
    """
    Loads the model.
    The model name is relative to this file.
    When called by a REST API, the default value is always used.

    :param files: mapping whose ``"model"`` entry is the model file name,
        resolved relative to the directory containing this module
        (the default dictionary is only read, never mutated)
    :return: the object deserialized by :func:`torch.load`
    :raises FileNotFoundError: if the resolved model file does not exist
    """
    model = files["model"]
    here = os.path.dirname(__file__)
    model = os.path.join(here, model)
    if not os.path.exists(model):
        raise FileNotFoundError("Cannot find model '{0}' (full path is '{1}')".format(
            model, os.path.abspath(model)))
    import inspect  # pylint: disable=C0415
    import torch  # pylint: disable=E0401,C0415
    # Recent torch releases default torch.load to weights_only=True, which
    # refuses to unpickle a full module object (the predict function calls
    # model.forward, so a full module is what this file holds). The file is
    # shipped next to this module and therefore trusted, so opt back into
    # full unpickling; torch versions without the keyword are called as before.
    if "weights_only" in inspect.signature(torch.load).parameters:
        loaded_model = torch.load(model, weights_only=False)
    else:
        loaded_model = torch.load(model)
    return loaded_model
36# Declare a predict function.
def restapi_predict(model, X):
    """
    Computes the prediction of *model* for one image.

    :param model: :epkg:`torch` model; only its ``forward`` method is called
    :param X: image as a :epkg:`numpy` array
    :return: output of ``model.forward`` as a :epkg:`numpy` array
    :raises TypeError: if *X* is not a :epkg:`numpy` array
    """
    # Validate the input before paying for the torch import.
    if not isinstance(X, numpy.ndarray):
        raise TypeError("X must be an array")
    from torch import from_numpy, no_grad  # pylint: disable=E0611,E0401,C0415
    # NOTE(review): resizing to (3, 224, 224) is a pure spatial resize only if
    # X is already channel-first (3, H, W); a channel-last (H, W, 3) image
    # would be interpolated across its channel axis. Confirm the layout
    # callers send.
    im = skt.resize(X, (3, 224, 224))
    # Add a leading batch dimension of size 1.
    im = im[numpy.newaxis, :, :, :]
    ten = from_numpy(im.astype(numpy.float32))
    # Inference only: no need to build an autograd graph.
    with no_grad():
        pred = model.forward(ten)
    return pred.detach().numpy()