Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Implements a benchmark about performance for :epkg:`onnxruntime`
4"""
5import contextlib
6from collections import OrderedDict
7from io import BytesIO, StringIO
8import numpy
9from numpy.testing import assert_almost_equal
10import pandas
11from sklearn.ensemble._forest import BaseForest
12from sklearn.tree._classes import BaseDecisionTree
13from mlprodict.onnxrt import OnnxInference
14from mlprodict.tools.asv_options_helper import (
15 get_opset_number_from_onnx, get_ir_version_from_onnx)
16from ..benchmark import BenchPerfTest
17from ..benchmark.sklearn_helper import get_nb_skl_base_estimators
class OnnxRuntimeBenchPerfTest(BenchPerfTest):
    """
    Specific test to compare computing time predictions
    with :epkg:`scikit-learn` and :epkg:`onnxruntime`.
    See example :ref:`l-example-onnxruntime-logreg`.
    The class requires the following modules to be installed:
    :epkg:`onnx`, :epkg:`onnxruntime`, :epkg:`skl2onnx`,
    :epkg:`mlprodict`.

    Subclasses must overload :meth:`_get_random_dataset`.
    """

    def __init__(self, estimator, dim=None, N_fit=100000,
                 runtimes=('python_compiled', 'onnxruntime1'),
                 onnx_options=None, dtype=numpy.float32,
                 **opts):
        """
        @param      estimator       estimator class
        @param      dim             number of features
        @param      N_fit           number of observations to fit an estimator
        @param      runtimes        runtimes to test for class :epkg:`OnnxInference`
        @param      opts            training settings; every option is given to
                                    :class:`BenchPerfTest`, only ``max_depth``
                                    is given to the estimator
        @param      onnx_options    ONNX conversion options
        @param      dtype           dtype (float32 or float64)
        @raises     RuntimeError    if *dim* is None or training fails
        @raises     ValueError      if *dtype* is neither float32 nor float64
        """
        # These libraries are optional.
        from skl2onnx import to_onnx  # pylint: disable=E0401,C0415
        from skl2onnx.common.data_types import FloatTensorType, DoubleTensorType  # pylint: disable=E0401,C0415

        if dim is None:
            raise RuntimeError(  # pragma: no cover
                "dim must be defined.")

        # Validate the dtype before training: fitting on N_fit rows
        # may be expensive and would be wasted on an unsupported dtype.
        if dtype == numpy.float64:
            tensor_type = DoubleTensorType
        elif dtype == numpy.float32:
            tensor_type = FloatTensorType
        else:
            raise ValueError(  # pragma: no cover
                "Unable to convert the model into ONNX, unsupported dtype {}.".format(dtype))

        BenchPerfTest.__init__(self, **opts)

        # Only a subset of the benchmark options are estimator parameters.
        allowed = {"max_depth"}
        opts = {k: v for k, v in opts.items() if k in allowed}
        self.dtype = dtype
        self.skl = estimator(**opts)
        X, y = self._get_random_dataset(N_fit, dim)
        try:
            self.skl.fit(X, y)
        except Exception as e:  # pragma: no cover
            raise RuntimeError("X.shape={}\nopts={}\nTraining failed for {}".format(
                X.shape, opts, self.skl)) from e

        initial_types = [('X', tensor_type([None, X.shape[1]]))]
        # Conversion messages (stdout and stderr) are kept in
        # self.logconvert instead of polluting the benchmark output.
        self.logconvert = StringIO()
        with contextlib.redirect_stdout(self.logconvert):
            with contextlib.redirect_stderr(self.logconvert):
                onx = to_onnx(self.skl, initial_types=initial_types,
                              options=onnx_options,
                              target_opset=get_opset_number_from_onnx())
                onx.ir_version = get_ir_version_from_onnx()

        self._init(onx, runtimes)

    def _get_random_dataset(self, N, dim):
        """
        Returns a random dataset, used for training and for the benchmark.
        Must be overloaded by subclasses.

        @param      N       number of observations
        @param      dim     number of features
        @return             ``(X, y)``
        """
        raise NotImplementedError(  # pragma: no cover
            "This method must be overloaded.")

    def _init(self, onx, runtimes):
        """
        Finalizes the init: loads the serialized ONNX model once per
        runtime and collects information about both models.

        @param      onx         ONNX model
        @param      runtimes    runtime names given to :epkg:`OnnxInference`
        """
        content = onx.SerializeToString()
        self.ort_onnx = onx
        self.ort = OrderedDict()
        self.outputs = OrderedDict()
        for r in runtimes:
            self.ort[r] = OnnxInference(content, runtime=r)
            self.outputs[r] = self.ort[r].output_names
        self.extract_model_info_skl()
        self.extract_model_info_onnx(ort_size=len(content))

    def extract_model_info_skl(self, **kwargs):
        """
        Populates member ``self.skl_info`` with additional
        information on the model such as the number of node for
        a decision tree.

        @param      kwargs      additional values stored as-is
        """
        self.skl_info = dict(
            skl_nb_base_estimators=get_nb_skl_base_estimators(self.skl, fitted=True))
        self.skl_info.update(kwargs)
        if isinstance(self.skl, BaseDecisionTree):
            self.skl_info["skl_dt_nodes"] = self.skl.tree_.node_count
        elif isinstance(self.skl, BaseForest):
            self.skl_info["skl_rf_nodes"] = sum(
                est.tree_.node_count for est in self.skl.estimators_)

    def extract_model_info_onnx(self, **kwargs):
        """
        Populates member ``self.onnx_info`` with additional
        information on the :epkg:`ONNX` graph.

        @param      kwargs      additional values stored as-is
                                (``ort_size`` when called from :meth:`_init`)
        """
        self.onnx_info = {
            'onnx_nodes': len(self.ort_onnx.graph.node),  # pylint: disable=E1101
            'onnx_opset': get_opset_number_from_onnx(),
        }
        self.onnx_info.update(kwargs)

    def data(self, N=None, dim=None, **kwargs):  # pylint: disable=W0221
        """
        Generates random features.

        @param      N       number of observations
        @param      dim     number of features
        @return             the first element of :meth:`_get_random_dataset`
                            (the features), as a one-element sequence
        @raises     RuntimeError    if *N* or *dim* is None
        """
        if dim is None:
            raise RuntimeError(  # pragma: no cover
                "dim must be defined.")
        if N is None:
            raise RuntimeError(  # pragma: no cover
                "N must be defined.")
        return self._get_random_dataset(N, dim)[:1]

    def model_info(self, model):
        """
        Returns additional informations about a model.

        @param      model       model to describe
        @return                 dictionary with additional descriptor
        """
        res = dict(type_name=model.__class__.__name__)
        return res

    def validate(self, results, **kwargs):
        """
        Checks that methods *predict* and *predict_proba* returns
        the same results for both :epkg:`scikit-learn` and
        :epkg:`onnxruntime`.

        @param      results     iterable of ``(idt, fct, vals)``; ``fct`` is a
                                dictionary with key ``'lib'`` and optionally
                                ``'method'``
        @param      kwargs      forwarded to ``dump_error`` on failure
        @raises     RuntimeError    if there is nothing to compare or no
                                    ``'skl'`` result to use as a baseline
        @raises     AssertionError  if a runtime disagrees with the baseline
        """
        res = {}
        baseline = None
        for idt, fct, vals in results:
            key = idt, fct.get('method', '')
            if isinstance(vals, list):
                # e.g. a list of dictionaries (probabilities per class)
                vals = pandas.DataFrame(vals).values
            lib = fct['lib']
            res.setdefault(key, {})[lib] = vals
            if lib == 'skl':
                baseline = lib

        if len(res) == 0:
            raise RuntimeError(  # pragma: no cover
                "No results to compare.")
        if baseline is None:
            # The original code called res.pop() without a key, which raised
            # TypeError instead of this error; list the libraries found instead.
            libs = sorted({lib for exp in res.values() for lib in exp})
            raise RuntimeError(  # pragma: no cover
                "Unable to guess the baseline in {}.".format(libs))

        for key, exp in res.items():
            vbase = exp[baseline]
            # NOTE(review): large outputs are not compared, presumably to keep
            # validation cheap — confirm the 10000 threshold.
            if vbase.shape[0] > 10000:
                continue
            for name, vals in exp.items():
                if name == baseline:
                    continue
                p1, p2 = vbase, vals
                if len(p1.shape) == 1 and len(p2.shape) == 2:
                    p2 = p2.ravel()
                try:
                    assert_almost_equal(p1, p2, decimal=4)
                except AssertionError as e:
                    if p1.dtype == numpy.int64 and p2.dtype == numpy.int64:
                        delta = numpy.sum(numpy.abs(p1 - p2) != 0)
                        if delta <= 2:
                            # scikit-learn computes in double precision, the
                            # converted model may not, so a couple of
                            # different labels are tolerated.
                            continue
                    msg = "ERROR: Dim {}-{} ({}-{}) - discrepencies between '{}' and '{}' for '{}'.".format(
                        vbase.shape, vals.shape, getattr(
                            p1, 'dtype', None),
                        getattr(p2, 'dtype', None), baseline, name, key)
                    self.dump_error(msg, skl=self.skl, ort=self.ort,
                                    baseline=vbase, discrepencies=vals,
                                    onnx_bytes=self.ort_onnx.SerializeToString(),
                                    results=results, **kwargs)
                    raise AssertionError(msg) from e