Templates
Scikit-learn template
Classes MLStorage and MLStoragePost expect a model to be made actionable by implementing the following template:
"""
Template application for a machine learning model
available through a REST API.
:githublink:`%|py|6`
"""
import pickle
import os
# Declare an id for the REST API.
def restapi_version():
    """
    Returns the version string identifying this model in the REST API.

    :return: version as a string
    :githublink:`%|py|14`
    """
    version = "0.1.1234"
    return version
# Declare a loading function.
def restapi_load(files={"model": "iris2.pkl"}):  # pylint: disable=W0102
    """
    Loads the pickled model.

    The file name stored under key ``"model"`` is resolved relative to
    the directory containing this file. When called by the REST API,
    the default value is always used.

    :param files: mapping whose ``"model"`` entry is the model file name
        (the default dict is only read, never mutated, so sharing it is harmless)
    :return: the unpickled model
    :raises FileNotFoundError: if the resolved model file does not exist
    :githublink:`%|py|24`
    """
    name = files["model"]
    path = os.path.join(os.path.dirname(__file__), name)
    if os.path.exists(path):
        # NOTE: unpickling can execute arbitrary code; only ship trusted model files.
        with open(path, "rb") as handle:
            return pickle.load(handle)
    raise FileNotFoundError("Cannot find model '{0}' (full path is '{1}')".format(
        path, os.path.abspath(path)))
# Declare a predict function.
def restapi_predict(model, X):
    """
    Computes class probabilities for *X* with *model*.

    :param model: estimator or pipeline following the :epkg:`scikit-learn` API
        (it must expose *predict_proba*)
    :param X: inputs
    :return: output of *model.predict_proba(X)*
    :githublink:`%|py|44`
    """
    predict_proba = model.predict_proba
    return predict_proba(X)
Model with image
The second template shows how to handle images through a dummy example that computes the distance between two images.
"""
Template application for a machine learning model
available through a REST API and using images like
deep learning models.
:githublink:`%|py|7`
"""
import pickle
import os
import numpy
import skimage.transform as skt
# Declare an id for the REST API.
def restapi_version():
    """
    Returns the version string identifying this model in the REST API.

    :return: version as a string
    :githublink:`%|py|17`
    """
    version = "0.1.1235"
    return version
# Declare a loading function.
def restapi_load(files={"model": "dlimg.pkl"}):  # pylint: disable=W0102
    """
    Loads the pickled model.

    The file name stored under key ``'model'`` is resolved relative to
    the directory containing this file. When called by the REST API,
    the default value is always used.

    :param files: mapping whose ``'model'`` entry is the model file name
        (the default dict is only read, never mutated, so sharing it is harmless)
    :return: the unpickled object
    :raises FileNotFoundError: if the resolved model file does not exist
    :githublink:`%|py|27`
    """
    name = files['model']
    path = os.path.join(os.path.dirname(__file__), name)
    if os.path.exists(path):
        # NOTE: unpickling can execute arbitrary code; only ship trusted model files.
        with open(path, "rb") as handle:
            return pickle.load(handle)
    raise FileNotFoundError("Cannot find model '{0}' (full path is '{1}')".format(
        path, os.path.abspath(path)))
# Declare a predict function.
def restapi_predict(model, X):
    """
    Computes a normalized distance between the reference image *model*
    and the input image *X*.

    Both images are resized to shape ``(3, 224, 224)`` with
    :epkg:`scikit-image`. The result is the mean absolute difference
    between their pixels, divided by 255.

    :param model: reference image, presumably the object returned by
        *restapi_load* (expected to be accepted by ``skimage.transform.resize``)
    :param X: image as a :epkg:`numpy` array
    :return: mean absolute pixel difference divided by 255 (a float)
    :raises TypeError: if *X* is not a :epkg:`numpy` array
    :githublink:`%|py|47`
    """
    if not isinstance(X, numpy.ndarray):
        raise TypeError("X must be an array")
    # Bring both images to the same shape so they can be compared pixel-wise.
    im1 = skt.resize(model, (3, 224, 224))
    im2 = skt.resize(X, (3, 224, 224))
    diff = im1.ravel() - im2.ravel()
    total = numpy.abs(diff)
    # NOTE(review): with the default preserve_range=False, resize converts
    # integer images with img_as_float (e.g. uint8 -> [0, 1]) but keeps float
    # values as they are, so the division by 255 only normalizes float images
    # in [0, 255] -- confirm this is the intended scaling.
    return total.sum() / float(len(total)) / 255
Model with keras
The third template uses keras and a model trained on ImageNet.
"""
Template application for a machine learning model
based on :epkg:`keras` available through a REST API.
:githublink:`%|py|6`
"""
import os
import numpy
import skimage.transform as skt
# Declare an id for the REST API.
def restapi_version():
    """
    Returns the version string identifying this model in the REST API.

    :return: version as a string
    :githublink:`%|py|15`
    """
    version = "0.1.1237"
    return version
# Declare a loading function.
def restapi_load(files={'model': "dlmodel.keras"}):  # pylint: disable=W0102
    """
    Loads the :epkg:`keras` model.

    The file name stored under key ``"model"`` is resolved relative to
    the directory containing this file. When called by the REST API,
    the default value is always used.

    :param files: mapping whose ``"model"`` entry is the model file name
        (the default dict is only read, never mutated, so sharing it is harmless)
    :return: the model returned by ``keras.models.load_model``
    :raises FileNotFoundError: if the resolved model file does not exist
    :githublink:`%|py|25`
    """
    # keras is imported lazily so that importing this module does not require it.
    from keras.models import load_model  # pylint: disable=E0401,E0611
    name = files["model"]
    path = os.path.join(os.path.dirname(__file__), name)
    if os.path.exists(path):
        return load_model(path)
    raise FileNotFoundError("Cannot find model '{0}' (full path is '{1}')".format(
        path, os.path.abspath(path)))
# Declare a predict function.
def restapi_predict(model, X):
    """
    Runs the :epkg:`keras` model *model* on the image *X*.

    The image is resized to shape ``(3, 224, 224)``, transposed to
    ``(224, 224, 3)`` and given a leading batch axis. The model therefore
    receives a single-image batch of shape ``(1, 224, 224, 3)``.

    :param model: :epkg:`keras` model, presumably the one returned by *restapi_load*
    :param X: image as a :epkg:`numpy` array
    :return: output of *model.predict* on that one-image batch
    :raises TypeError: if *X* is not a :epkg:`numpy` array
    :githublink:`%|py|45`
    """
    if not isinstance(X, numpy.ndarray):
        raise TypeError("X must be an array")
    # NOTE(review): resizing to (3, 224, 224) treats axis 0 as the channel
    # axis; a channel-last (H, W, 3) image would have its axes interpolated
    # into the wrong roles -- confirm the expected input layout.
    im = skt.resize(X, (3, 224, 224))
    # Channel-first -> channel-last, as the comment above assumes.
    im = numpy.transpose(im, (1, 2, 0))
    # Add the batch dimension expected by model.predict.
    im = im[numpy.newaxis, :, :, :]
    return model.predict(im)