Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Command line about validation of prediction runtime. 

4""" 

5import os 

6import pickle 

7from logging import getLogger 

8import warnings 

9from pandas import read_csv 

10from skl2onnx.common.data_types import FloatTensorType, DoubleTensorType 

11from ..onnx_conv import to_onnx 

12from ..onnxrt import OnnxInference 

13from ..onnx_tools.optim import onnx_optimisations 

14from ..onnxrt.validate.validate_difference import measure_relative_difference 

15from ..onnx_conv import guess_schema_from_data, guess_schema_from_model 

16 

17 

def convert_validate(pkl, data=None, schema=None,
                     method="predict", name='Y',
                     target_opset=None,
                     outonnx="model.onnx",
                     runtime='python', metric="l1med",
                     use_double=None, noshape=False,
                     optim='onnx', rewrite_ops=True,
                     options=None, fLOG=print, verbose=1,
                     register=True):
    """
    Converts a model stored in *pkl* file and measures the differences
    between the model and the ONNX predictions.

    :param pkl: pickle file (loaded with :func:`pickle.load`, it must
        come from a trusted source)
    :param data: data file, loaded with pandas,
        converted to a single array, the data is used to guess
        the schema if *schema* not specified; if the file does not
        exist, a warning is raised and only the conversion is done
    :param schema: initial type of the model, a list of
        ``(input_name, tensor_type)``
    :param method: method to call, several methods can be given
        as a comma separated list
    :param name: output name, several names can be given as a comma
        separated list, there must be as many names as methods
    :param target_opset: target opset
    :param outonnx: produced ONNX model
    :param runtime: runtime to use to compute predictions,
        'python', 'python_compiled',
        'onnxruntime1' or 'onnxruntime2'
    :param metric: the metric 'l1med' is given by function
        :func:`measure_relative_difference
        <mlprodict.onnxrt.validate.validate_difference.measure_relative_difference>`
    :param noshape: run the conversion with no shape information
    :param use_double: use double for the runtime if possible,
        two possible options, ``"float64"`` or ``'switch'``,
        the first option produces an ONNX file with doubles,
        the second option loads an ONNX file (float or double)
        and replaces matrices in ONNX with the matrices coming from
        the model, this second way is just for testing purposes
    :param optim: applies optimisations on the first ONNX graph,
        use 'onnx' to reduce the number of node Identity and
        redundant subgraphs
    :param rewrite_ops: rewrites some converters from skl2onnx
    :param options: additional options for conversion,
        dictionary as a string
    :param verbose: verbose level
    :param register: registers additional converters implemented by this package
    :param fLOG: logging function
    :return: a dictionary with the results, key ``onnx`` holds the
        serialized ONNX model; when data is available, keys ``skl_pred``,
        ``ort_pred`` and ``metrics`` hold one entry per method
    :raises ValueError: invalid *use_double*, unknown *metric* or
        a different number of methods and output names
    :raises FileNotFoundError: *pkl* does not exist
    :raises KeyError: an output name is not produced by the ONNX graph

    .. cmdref::
        :title: Converts and compares an ONNX file
        :cmd: -m mlprodict convert_validate --help
        :lid: l-cmd-convert_validate

        The command converts and validates a :epkg:`scikit-learn` model.
        An example to check the prediction of a logistic regression.

        ::

            import os
            import pickle
            import pandas
            from sklearn.datasets import load_iris
            from sklearn.model_selection import train_test_split
            from sklearn.linear_model import LogisticRegression
            from mlprodict.__main__ import main
            from mlprodict.cli import convert_validate

            iris = load_iris()
            X, y = iris.data, iris.target
            X_train, X_test, y_train, _ = train_test_split(X, y, random_state=11)
            clr = LogisticRegression()
            clr.fit(X_train, y_train)

            pandas.DataFrame(X_test).to_csv("data.csv", index=False)
            with open("model.pkl", "wb") as f:
                pickle.dump(clr, f)

        And the command line to check the predictions
        using a command line.

        ::

            convert_validate --pkl model.pkl --data data.csv
                             --method predict,predict_proba
                             --name output_label,output_probability
                             --verbose 1
    """
    # --- argument checks, done before any expensive work --------------
    if fLOG is None:
        verbose = 0  # pragma: no cover
    if use_double not in (None, 'float64', 'switch'):
        raise ValueError(  # pragma: no cover
            "use_double must be either None, 'float64' or 'switch'")
    if optim == '':
        optim = None  # pragma: no cover
    if target_opset == '':
        target_opset = None  # pragma: no cover
    if metric != 'l1med':
        raise ValueError(  # pragma: no cover
            "Unknown metric '{}'".format(metric))

    # Methods and output names may be comma separated lists; str.split
    # returns a one-element list when there is no comma.
    methods = method.split(',')
    names = name.split(',')
    if len(names) != len(methods):
        raise ValueError(
            "Number of methods and outputs do not match: {}, {}".format(
                names, methods))

    if verbose == 0:
        logger = getLogger('skl2onnx')
        logger.disabled = True
    if not os.path.exists(pkl):
        raise FileNotFoundError(  # pragma: no cover
            "Unable to find model '{}'.".format(pkl))
    if os.path.exists(outonnx):
        warnings.warn("File '{}' will be overwritten.".format(outonnx))

    # --- model ---------------------------------------------------------
    if verbose > 0:
        fLOG("[convert_validate] load model '{}'".format(pkl))
    # NOTE: pickle.load can execute arbitrary code, *pkl* must be trusted.
    with open(pkl, "rb") as f:
        model = pickle.load(f)

    tensor_type = DoubleTensorType if use_double == 'float64' else FloatTensorType
    if options in (None, ''):
        options = None
    else:
        from ..onnxrt.validate.validate_scenarios import (
            interpret_options_from_string)
        options = interpret_options_from_string(options)
        if verbose > 0:
            fLOG("[convert_validate] options={}".format(repr(options)))

    if register:
        from ..onnx_conv import (
            register_converters, register_rewritten_operators)
        register_converters()
        register_rewritten_operators()

    # --- data and schema -----------------------------------------------
    if data is None or not os.path.exists(data):
        if data not in (None, ''):
            # The user asked for a validation but the file is missing,
            # do not skip it silently.
            warnings.warn(
                "Unable to find data '{}', predictions are not "
                "compared.".format(data))
        if schema is None:
            schema = guess_schema_from_model(model, tensor_type)
        if verbose > 0:
            fLOG("[convert_validate] model schema={}".format(schema))
        df = None
    else:
        if verbose > 0:
            fLOG("[convert_validate] load data '{}'".format(data))
        df = read_csv(data)
        if verbose > 0:
            fLOG("[convert_validate] convert data into matrix")
        if schema is None:
            schema = guess_schema_from_data(df, tensor_type)
        if schema is None:
            schema = [  # pragma: no cover
                ('X', tensor_type([None, df.shape[1]]))]
        if len(schema) == 1:
            # a single input receives the whole matrix
            df = df.values  # pylint: disable=E1101
        if verbose > 0:
            fLOG("[convert_validate] data schema={}".format(schema))

    # --- conversion ----------------------------------------------------
    if noshape:
        if verbose > 0:
            fLOG(  # pragma: no cover
                "[convert_validate] convert the model with no shape information")
        # keep input names and types, drop every dimension
        schema = [(in_name, in_type.__class__([None, None]))
                  for in_name, in_type in schema]
        onx = to_onnx(
            model, initial_types=schema, rewrite_ops=rewrite_ops,
            target_opset=target_opset, options=options)
    else:
        if verbose > 0:
            fLOG("[convert_validate] convert the model with shapes")
        onx = to_onnx(
            model, initial_types=schema, target_opset=target_opset,
            rewrite_ops=rewrite_ops, options=options)

    if optim is not None:
        if verbose > 0:
            fLOG("[convert_validate] run optimisations '{}'".format(optim))
        onx = onnx_optimisations(onx, optim=optim)
    if verbose > 0:
        fLOG("[convert_validate] saves to '{}'".format(outonnx))
    memory = onx.SerializeToString()
    with open(outonnx, 'wb') as f:
        f.write(memory)

    # --- runtime -------------------------------------------------------
    if verbose > 0:
        fLOG("[convert_validate] creates OnnxInference session")
    sess = OnnxInference(onx, runtime=runtime)
    if use_double == "switch":
        if verbose > 0:
            fLOG("[convert_validate] switch to double")
        sess.switch_initializers_dtype(model)

    if verbose > 0:
        fLOG("[convert_validate] compute prediction from model")

    if df is None:
        # no test on data
        return dict(onnx=memory)

    if verbose > 0:
        fLOG("[convert_validate] compute predictions from ONNX with name '{}'"
             "".format(name))

    # Feed the data under the input name declared in the schema
    # ('X' when guessed from data) instead of always assuming 'X'.
    # NOTE(review): a multi-input schema still receives the DataFrame
    # under 'X' as before — confirm how the runtime should be fed then.
    input_name = schema[0][0] if len(schema) == 1 else 'X'
    ort_preds = sess.run(
        {input_name: df}, verbose=max(verbose - 1, 0), fLOG=fLOG)

    metrics = []
    out_skl_preds = []
    out_ort_preds = []
    for method_, name_ in zip(methods, names):
        if verbose > 0:
            fLOG("[convert_validate] compute predictions with method '{}'".format(
                method_))
        meth = getattr(model, method_)
        skl_pred = meth(df)
        # store the model predictions (not the input data)
        out_skl_preds.append(skl_pred)

        if name_ not in ort_preds:
            raise KeyError(
                "Unable to find output name '{}' in {}".format(
                    name_, list(sorted(ort_preds))))

        ort_pred = ort_preds[name_]
        out_ort_preds.append(ort_pred)
        diff = measure_relative_difference(skl_pred, ort_pred)
        if verbose > 0:
            fLOG("[convert_validate] {}={}".format(metric, diff))
        metrics.append(diff)

    return dict(skl_pred=out_skl_preds, ort_pred=out_ort_preds,
                metrics=metrics, onnx=memory)