Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Template application for a machine learning model 

4available through a REST API. 

5""" 

6import pickle 

7import os 

8 

9 

# Declare an id for the REST API.
def restapi_version():
    """
    Returns the version identifier of this REST API application.

    :return: the version, as a string
    """
    version = "0.1.1234"
    return version

16 

17 

# Declare a loading function.
def restapi_load(files={"model": "iris2.pkl"}):  # pylint: disable=W0102
    """
    Loads the model.
    The model name is relative to this file.
    When called by a REST API, the default value is always used.

    :param files: mapping whose ``"model"`` entry is the path of the pickled
        model; a relative path is resolved against the directory containing
        this file (an absolute path is used as is)
    :return: the unpickled model object
    :raises KeyError: if *files* has no ``"model"`` entry
    :raises FileNotFoundError: if the model file cannot be found
    """
    # The default dict is only read, never mutated, so sharing it across
    # calls is harmless (hence the pylint suppression on the signature).
    here = os.path.dirname(__file__)
    model = os.path.join(here, files["model"])
    # EAFP: try to open directly instead of checking os.path.exists first,
    # which avoids a race between the check and the open. The try body is
    # kept to the open() call so errors raised while unpickling are not
    # misreported as a missing file.
    try:
        f = open(model, "rb")  # pylint: disable=R1732
    except FileNotFoundError as e:
        raise FileNotFoundError("Cannot find model '{0}' (full path is '{1}')".format(
            model, os.path.abspath(model))) from e
    with f:
        # NOTE: pickle.load can execute arbitrary code; only load model
        # files that come from a trusted source.
        return pickle.load(f)

34 

35 

# Declare a predict function.
def restapi_predict(model, X):
    """
    Computes the prediction of *model* on *X* through its
    ``predict_proba`` method.

    :param model: pipeline following :epkg:`scikit-learn` API
    :param X: inputs
    :return: output of *predict_proba*
    """
    predict_proba = model.predict_proba
    probabilities = predict_proba(X)
    return probabilities