# Source code for lightmlrestapi.testing.template_dl_torch

"""
Template application exposing a machine learning model
built with :epkg:`torch` through a REST API.


:githublink:`%|py|6`
"""
import os
import numpy
import skimage.transform as skt


# Declare an id for the REST API.
def restapi_version():
    """
    Returns the version identifier published by this REST API template.

    :return: version string
    """
    version = "0.1.1238"
    return version
# Declare a loading function.
def restapi_load(files={"model": "dlmodel.torch"}):  # pylint: disable=W0102
    """
    Loads the model. The model name is relative to this file
    (an absolute path is used as is).
    When called by a REST API, the default value is always used.

    :param files: dictionary whose key ``"model"`` holds the file name of
        the serialized model; it is only read, never modified, which is why
        the shared mutable default is harmless here
    :return: the object deserialized by :func:`torch.load`
    :raises FileNotFoundError: if the model file does not exist
    """
    model = files["model"]
    here = os.path.dirname(__file__)
    # os.path.join leaves *model* unchanged when it is an absolute path.
    model = os.path.join(here, model)
    if not os.path.exists(model):
        raise FileNotFoundError("Cannot find model '{0}' (full path is '{1}')".format(
            model, os.path.abspath(model)))
    import inspect  # pylint: disable=C0415
    import torch  # pylint: disable=E0401,C0415
    # restapi_predict calls model.forward, so the file presumably stores a
    # fully pickled module rather than a state dict. torch>=2.6 defaults to
    # weights_only=True, which refuses to unpickle arbitrary classes, so full
    # unpickling is requested explicitly -- but only on torch versions whose
    # torch.load actually accepts the argument.
    # NOTE(review): full unpickling can execute arbitrary code; only load
    # model files that come from a trusted source.
    load_kwargs = {}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    loaded_model = torch.load(model, **load_kwargs)
    return loaded_model
# Declare a predict function.
def restapi_predict(model, X):
    """
    Computes the prediction of a :epkg:`torch` model on one image.

    The image is resized to shape ``(3, 224, 224)``, a leading batch axis
    is added, and the result is converted into a ``float32`` tensor before
    being given to ``model.forward``.

    :param model: object exposing a ``forward`` method, typically the
        :epkg:`torch` module returned by ``restapi_load``
    :param X: image as a :epkg:`numpy` array
    :return: output of ``model.forward`` converted into a :epkg:`numpy` array
    :raises TypeError: if *X* is not a :epkg:`numpy` array
    """
    from torch import from_numpy, no_grad  # pylint: disable=E0611,E0401,C0415
    if not isinstance(X, numpy.ndarray):
        raise TypeError("X must be an array")
    # NOTE(review): resize treats every axis as spatial; if X is channel-last
    # (H, W, 3), channels and pixels get interpolated together instead of
    # being transposed -- confirm the input layout expected by callers.
    im = skt.resize(X, (3, 224, 224))
    # Add the batch dimension expected by the model: (1, 3, 224, 224).
    im = im[numpy.newaxis, :, :, :]
    ten = from_numpy(im.astype(numpy.float32))
    # Inference only: avoid building the autograd graph (same output values).
    with no_grad():
        pred = model.forward(ten)
    return pred.detach().numpy()