Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Implements a benchmark about performance for :epkg:`onnxruntime` 

4""" 

5import contextlib 

6from collections import OrderedDict 

7from io import BytesIO, StringIO 

8import numpy 

9from numpy.testing import assert_almost_equal 

10import pandas 

11from sklearn.ensemble._forest import BaseForest 

12from sklearn.tree._classes import BaseDecisionTree 

13from mlprodict.onnxrt import OnnxInference 

14from mlprodict.tools.asv_options_helper import ( 

15 get_opset_number_from_onnx, get_ir_version_from_onnx) 

16from ..benchmark import BenchPerfTest 

17from ..benchmark.sklearn_helper import get_nb_skl_base_estimators 

18 

19 

class OnnxRuntimeBenchPerfTest(BenchPerfTest):
    """
    Specific test to compare computing time predictions
    with :epkg:`scikit-learn` and :epkg:`onnxruntime`.
    See example :ref:`l-example-onnxruntime-logreg`.
    The class requires the following modules to be installed:
    :epkg:`onnx`, :epkg:`onnxruntime`, :epkg:`skl2onnx`,
    :epkg:`mlprodict`.

    Subclasses must overload :meth:`_get_random_dataset`,
    the default implementation raises :class:`NotImplementedError`.
    """

    def __init__(self, estimator, dim=None, N_fit=100000,
                 runtimes=('python_compiled', 'onnxruntime1'),
                 onnx_options=None, dtype=numpy.float32,
                 **opts):
        """
        Trains the estimator on a random dataset, converts it into
        :epkg:`ONNX` and creates one :epkg:`OnnxInference` per runtime.

        @param      estimator       estimator class
        @param      dim             number of features
        @param      N_fit           number of observations to fit an estimator
        @param      runtimes        runtimes to test for class :epkg:`OnnxInference`
        @param      opts            training settings, all of them are given to
                                    @see cl BenchPerfTest but only ``max_depth``
                                    is given to the estimator constructor
        @param      onnx_options    ONNX conversion options
        @param      dtype           dtype (float32 or float64)
        @raises     RuntimeError    if *dim* is None or training fails
        @raises     ValueError      if *dtype* is neither float32 nor float64
        """
        # These libraries are optional.
        from skl2onnx import to_onnx  # pylint: disable=E0401,C0415
        from skl2onnx.common.data_types import FloatTensorType, DoubleTensorType  # pylint: disable=E0401,C0415

        if dim is None:
            raise RuntimeError(  # pragma: no cover
                "dim must be defined.")

        # Checks the dtype before training, fitting N_fit observations
        # is wasted work if the conversion cannot happen afterwards.
        if dtype == numpy.float64:
            tensor_type = DoubleTensorType
        elif dtype == numpy.float32:
            tensor_type = FloatTensorType
        else:
            raise ValueError(  # pragma: no cover
                "Unable to convert the model into ONNX, unsupported dtype {}.".format(dtype))

        BenchPerfTest.__init__(self, **opts)

        # Only a subset of the training settings reaches the estimator.
        allowed = {"max_depth"}
        opts = {k: v for k, v in opts.items() if k in allowed}
        self.dtype = dtype
        self.skl = estimator(**opts)
        X, y = self._get_random_dataset(N_fit, dim)
        try:
            self.skl.fit(X, y)
        except Exception as e:  # pragma: no cover
            raise RuntimeError("X.shape={}\nopts={}\nTraining failed for {}".format(
                X.shape, opts, self.skl)) from e

        initial_types = [('X', tensor_type([None, X.shape[1]]))]

        # Anything the converter prints on stdout or stderr is kept
        # in self.logconvert instead of polluting the benchmark output.
        self.logconvert = StringIO()
        with contextlib.redirect_stdout(self.logconvert), \
                contextlib.redirect_stderr(self.logconvert):
            onx = to_onnx(self.skl, initial_types=initial_types,
                          options=onnx_options,
                          target_opset=get_opset_number_from_onnx())
        onx.ir_version = get_ir_version_from_onnx()

        self._init(onx, runtimes)

    def _get_random_dataset(self, N, dim):
        """
        Returns a random datasets.

        Must be overloaded by subclasses, the result is unpacked
        as ``X, y`` by the constructor.
        """
        raise NotImplementedError(  # pragma: no cover
            "This method must be overloaded.")

    def _init(self, onx, runtimes):
        """
        Finalizes the init.

        Stores the ONNX model in ``self.ort_onnx``, one
        :epkg:`OnnxInference` per runtime in ``self.ort`` and the output
        names per runtime in ``self.outputs``, then gathers model
        information (``ort_size`` is the size of the serialized model
        in bytes).
        """
        content = onx.SerializeToString()
        self.ort_onnx = onx
        self.ort = OrderedDict()
        self.outputs = OrderedDict()
        for r in runtimes:
            self.ort[r] = OnnxInference(content, runtime=r)
            self.outputs[r] = self.ort[r].output_names
        self.extract_model_info_skl()
        self.extract_model_info_onnx(ort_size=len(content))

    def extract_model_info_skl(self, **kwargs):
        """
        Populates member ``self.skl_info`` with additional
        information on the model such as the number of node for
        a decision tree.

        @param      kwargs      extra entries added to ``self.skl_info``
        """
        self.skl_info = dict(
            skl_nb_base_estimators=get_nb_skl_base_estimators(self.skl, fitted=True))
        self.skl_info.update(kwargs)
        if isinstance(self.skl, BaseDecisionTree):
            self.skl_info["skl_dt_nodes"] = self.skl.tree_.node_count
        elif isinstance(self.skl, BaseForest):
            # total number of nodes over all trees of the forest
            self.skl_info["skl_rf_nodes"] = sum(
                est.tree_.node_count for est in self.skl.estimators_)

    def extract_model_info_onnx(self, **kwargs):
        """
        Populates member ``self.onnx_info`` with additional
        information on the :epkg:`ONNX` graph.

        @param      kwargs      extra entries added to ``self.onnx_info``
        """
        self.onnx_info = {
            'onnx_nodes': len(self.ort_onnx.graph.node),  # pylint: disable=E1101
            'onnx_opset': get_opset_number_from_onnx(),
        }
        self.onnx_info.update(kwargs)

    def data(self, N=None, dim=None, **kwargs):  # pylint: disable=W0221
        """
        Generates random features.

        @param      N       number of observations
        @param      dim     number of features
        @return             the first item of the sequence returned by
                            :meth:`_get_random_dataset` (the features),
                            as a sequence of length 1
        @raises     RuntimeError if *N* or *dim* is None
        """
        if dim is None:
            raise RuntimeError(  # pragma: no cover
                "dim must be defined.")
        if N is None:
            raise RuntimeError(  # pragma: no cover
                "N must be defined.")
        return self._get_random_dataset(N, dim)[:1]

    def model_info(self, model):
        """
        Returns additional informations about a model.

        @param      model       model to describe
        @return                 dictionary with additional descriptor
        """
        return dict(type_name=model.__class__.__name__)

    def validate(self, results, **kwargs):
        """
        Checks that methods *predict* and *predict_proba* returns
        the same results for both :epkg:`scikit-learn` and
        :epkg:`onnxruntime`.

        @param      results     iterable of ``(idt, fct, vals)``, *fct* is a
                                dictionary with key ``'lib'`` and optionally
                                ``'method'``, *vals* holds the predictions
                                (a list is converted into an array)
        @param      kwargs      forwarded to ``dump_error`` on failure
        @raises     RuntimeError    if there is nothing to compare or if no
                                    result comes from library ``'skl'``
        @raises     AssertionError  if a runtime disagrees with the baseline
        """
        # res[(idt, method)][lib] = predictions
        res = {}
        baseline = None
        for idt, fct, vals in results:
            key = idt, fct.get('method', '')
            if isinstance(vals, list):
                vals = pandas.DataFrame(vals).values
            lib = fct['lib']
            res.setdefault(key, {})[lib] = vals
            if lib == 'skl':
                baseline = lib

        if len(res) == 0:
            raise RuntimeError(  # pragma: no cover
                "No results to compare.")
        if baseline is None:
            # dict.pop() needs a key, the previous code raised a TypeError
            # here; the message now lists every library which was found.
            libs = list(OrderedDict.fromkeys(
                lib for exp in res.values() for lib in exp))
            raise RuntimeError(  # pragma: no cover
                "Unable to guess the baseline in {}.".format(libs))

        for key, exp in res.items():
            vbase = exp[baseline]
            # Only outputs with at most 10000 rows are compared.
            if vbase.shape[0] > 10000:
                continue
            for name, vals in exp.items():
                if name == baseline:
                    continue
                p1, p2 = vbase, vals
                if len(p1.shape) == 1 and len(p2.shape) == 2:
                    p2 = p2.ravel()
                try:
                    assert_almost_equal(p1, p2, decimal=4)
                except AssertionError as e:
                    if p1.dtype == numpy.int64 and p2.dtype == numpy.int64:
                        delta = numpy.sum(numpy.abs(p1 - p2) != 0)
                        if delta <= 2:
                            # scikit-learn computes with doubles, not floats,
                            # so a few discrepancies with scikit-learn are
                            # likely to happen
                            continue
                    msg = "ERROR: Dim {}-{} ({}-{}) - discrepencies between '{}' and '{}' for '{}'.".format(
                        vbase.shape, vals.shape, getattr(p1, 'dtype', None),
                        getattr(p2, 'dtype', None), baseline, name, key)
                    self.dump_error(msg, skl=self.skl, ort=self.ort,
                                    baseline=vbase, discrepencies=vals,
                                    onnx_bytes=self.ort_onnx.SerializeToString(),
                                    results=results, **kwargs)
                    raise AssertionError(msg) from e