Coverage for src/pymlbenchmark/external/onnxruntime_perf_regression.py: 100%
22 statements
« prev ^ index » next coverage.py v7.2.1, created at 2023-03-08 00:27 +0100
« prev ^ index » next coverage.py v7.2.1, created at 2023-03-08 00:27 +0100
1"""
2@file
3@brief Implements a benchmark for a single regression
4about performance for :epkg:`onnxruntime`.
5"""
6import numpy
7from ..datasets import random_regression
8from .onnxruntime_perf import OnnxRuntimeBenchPerfTest
class OnnxRuntimeBenchPerfTestRegression(OnnxRuntimeBenchPerfTest):
    """
    Specific test to compare computing time predictions
    with :epkg:`scikit-learn` and :epkg:`onnxruntime`
    for a regression.
    See example :ref:`l-example-onnxruntime-linreg`.
    The class requires the following modules to be installed:
    :epkg:`onnx`, :epkg:`onnxruntime`, :epkg:`skl2onnx`,
    :epkg:`mlprodict`.
    """

    def _get_random_dataset(self, N, dim):
        """
        Returns a random regression dataset.

        :param N: number of observations, forwarded to
            :func:`random_regression`
        :param dim: number of features, forwarded to
            :func:`random_regression`
        :return: whatever ``random_regression(N, dim)`` returns
        """
        return random_regression(N, dim)

    def fcts(self, dim=None, **kwargs):  # pylint: disable=W0221
        """
        Returns the functions to benchmark: method *predict*
        for :epkg:`scikit-learn` (``self.skl``) and for every
        runtime stored in ``self.ort``.

        :param dim: unused, kept for compatibility with the parent signature
        :param kwargs: unused
        :return: list of dictionaries, each with keys ``'method'``,
            ``'lib'`` (``'skl'`` or ``'onx' + runtime``) and ``'fct'``
            (a callable taking ``X``), then updated with
            ``self.skl_info`` or ``self.onnx_info`` depending on the library
        """
        def predict_skl_predict(X, model=self.skl):
            return model.predict(X.astype(self.dtype))

        def predict_onnxrt_predict(X, sess, output):
            # The only dtype conversion on the ONNX path happens here, so both
            # libraries pay for exactly one ``astype`` copy per timed call.
            return numpy.array(sess.run({'X': X.astype(self.dtype)})[output])

        res = [{'method': 'predict', 'lib': 'skl', 'fct': predict_skl_predict}]
        for runtime in self.ort:
            inst = self.ort[runtime]
            output = self.outputs[runtime][0]
            # Default arguments bind the current session/output, avoiding the
            # late-binding closure pitfall inside the loop.
            res.append({'method': 'predict', 'lib': 'onx' + runtime,
                        'fct': lambda X, sess=inst, output=output:
                            predict_onnxrt_predict(X, sess, output)})

        for fct in res:
            if fct['lib'] == 'skl':
                fct.update(self.skl_info)
            elif fct['lib'].startswith('onx'):
                fct.update(self.onnx_info)
        return res