Coverage for src/pymlbenchmark/external/onnxruntime_perf_binclass.py: 88%
48 statements
« prev ^ index » next coverage.py v7.2.1, created at 2023-03-08 00:27 +0100
1"""
2@file
3@brief Implements a benchmark for a binary classification
4about performance for :epkg:`OnnxInference`.
5"""
6import numpy
7from ..datasets import random_binary_classification
8from .onnxruntime_perf import OnnxRuntimeBenchPerfTest
class OnnxRuntimeBenchPerfTestBinaryClassification(OnnxRuntimeBenchPerfTest):
    """
    Specific test to compare computing time predictions
    with :epkg:`scikit-learn` and :epkg:`onnxruntime`
    for a binary classification.
    See example :ref:`l-example-onnxruntime-logreg`.
    The class requires the following modules to be installed:
    :epkg:`onnx`, :epkg:`onnxruntime`, :epkg:`skl2onnx`,
    :epkg:`mlprodict`.
    """

    def _get_random_dataset(self, N, dim):
        """
        Returns a random dataset for a binary classification.

        :param N: number of observations
        :param dim: number of features
        :return: whatever :func:`random_binary_classification` returns
        """
        return random_binary_classification(N, dim)

    @staticmethod
    def _zipmap_to_array(res):
        """
        Converts a sequence of dictionaries ``{label: probability}``
        (one dictionary per observation) into a :epkg:`numpy` array
        of shape ``(len(res), number of labels)`` and dtype float32.
        Columns follow the key order of the first row, which is the
        layout ``pandas.DataFrame(res).values`` would produce, without
        paying for the DataFrame conversion.
        Numpy arrays, empty sequences and sequences whose rows are not
        dictionaries are returned unchanged.

        :param res: one output of a runtime session
        :return: numpy array or *res* itself
        """
        # Fast path: already an array, nothing to convert.
        if isinstance(res, numpy.ndarray) or len(res) == 0:
            return res
        first = res[0]
        # The rows (not the container) carry ``items`` when the output
        # is a list of dictionaries; checking the container never matched.
        if not hasattr(first, 'items'):
            return res
        # Map each label to a column position instead of using the label
        # itself as an index, so non-integer or non-zero-based labels work.
        columns = {key: j for j, (key, _) in enumerate(first.items())}
        out = numpy.empty((len(res), len(columns)), dtype=numpy.float32)
        for i, row in enumerate(res):
            for key, value in row.items():
                out[i, columns[key]] = value
        return out

    def fcts(self, dim=None, **kwargs):  # pylint: disable=W0221
        """
        Returns a few functions, tests methods
        *predict*, *predict_proba* for both
        :epkg:`scikit-learn` and :epkg:`OnnxInference`
        multiplied by the number of runtime to test.

        :param dim: not used by this implementation
        :param kwargs: not used by this implementation
        :return: list of dictionaries with keys ``'method'``, ``'lib'``
            (``'skl'`` or ``'onx' + runtime``) and ``'fct'`` (a callable
            taking *X*), each updated with ``self.skl_info`` or
            ``self.onnx_info`` depending on ``'lib'``
        """
        to_array = self._zipmap_to_array

        def predict_skl_predict(X, model=self.skl):
            return model.predict(X)

        def predict_skl_predict_proba(X, model=self.skl):
            return model.predict_proba(X.astype(self.dtype))

        def predict_onnxrt_predict(X, sess, output):
            # the result of sess.run is indexed by the output
            # (presumably a dictionary keyed by output name)
            return numpy.array(sess.run({'X': X.astype(self.dtype)})[output])

        def predict_onnxrt_predict_proba(X, sess, output):
            res = sess.run({'X': X.astype(self.dtype)})[output]
            # do not use DataFrame to convert the output into array,
            # it takes too much time
            return to_array(res)

        tests = [{'method': 'predict', 'lib': 'skl',
                  'fct': predict_skl_predict}]
        for runtime in self.ort:
            inst = self.ort[runtime]
            output = self.outputs[runtime][0]
            # default arguments bind the current session/output,
            # avoiding the late-binding closure pitfall in the loop
            tests.append({'method': 'predict', 'lib': 'onx' + runtime,
                          'fct': lambda X, sess=inst, output=output:
                          predict_onnxrt_predict(X, sess, output)})

        if hasattr(self.skl, '_check_proba'):
            # some estimators expose _check_proba, which raises
            # AttributeError when probabilities are not available
            # (presumably e.g. SVC(probability=False))
            try:
                self.skl._check_proba()  # pylint: disable=W0212
                prob = True
            except AttributeError:
                prob = False
        else:
            prob = hasattr(self.skl, 'predict_proba')

        if prob:
            tests.append({'method': 'predict_proba', 'lib': 'skl',
                          'fct': predict_skl_predict_proba})
            for runtime in self.ort:
                inst = self.ort[runtime]
                # the second output is assumed to hold probabilities
                output = self.outputs[runtime][1]
                tests.append({'method': 'predict_proba',
                              'lib': 'onx' + runtime,
                              'fct': lambda X, sess=inst, output=output:
                              predict_onnxrt_predict_proba(X, sess, output)})

        for fct in tests:
            if fct['lib'] == 'skl':
                fct.update(self.skl_info)
            elif fct['lib'].startswith('onx'):
                fct.update(self.onnx_info)
        return tests