Coverage for src/pymlbenchmark/external/onnxruntime_perf_binclass.py: 88%

48 statements  

« prev     ^ index     » next       coverage.py v7.2.1, created at 2023-03-08 00:27 +0100

1""" 

2@file 

3@brief Implements a benchmark for a binary classification 

4about performance for :epkg:`OnnxInference`. 

5""" 

6import numpy 

7from ..datasets import random_binary_classification 

8from .onnxruntime_perf import OnnxRuntimeBenchPerfTest 

9 

10 

class OnnxRuntimeBenchPerfTestBinaryClassification(OnnxRuntimeBenchPerfTest):
    """
    Specific test to compare computing time predictions
    with :epkg:`scikit-learn` and :epkg:`onnxruntime`
    for a binary classification.
    See example :ref:`l-example-onnxruntime-logreg`.
    The class requires the following modules to be installed:
    :epkg:`onnx`, :epkg:`onnxruntime`, :epkg:`skl2onnx`,
    :epkg:`mlprodict`.
    """

    def _get_random_dataset(self, N, dim):
        """
        Returns a random dataset for a binary classification problem.

        :param N: forwarded to :func:`random_binary_classification`
            (presumably the number of rows)
        :param dim: forwarded to :func:`random_binary_classification`
            (presumably the number of features)
        :return: whatever :func:`random_binary_classification` returns
        """
        return random_binary_classification(N, dim)

    @staticmethod
    def _proba_to_array(res):
        """
        Converts probabilities given as a sequence of dictionaries
        (one ``{label: probability}`` dictionary per row, the shape of
        a *ZipMap* output) into a :epkg:`numpy` float32 matrix.
        Any other kind of output is returned unchanged.

        :param res: probability output returned by the runtime
        :return: a matrix of shape ``(len(res), len(res[0]))`` when
            *res* is a non-empty list/tuple of mappings, *res* otherwise
        """
        # The mapping is each row, not the container: a list of dicts
        # has no ``items`` attribute, so the check must inspect res[0].
        # The emptiness test also protects the ``res[0]`` access.
        if (not isinstance(res, (list, tuple)) or len(res) == 0 or
                not hasattr(res[0], 'items')):
            return res
        # do not use DataFrame to convert the output into array,
        # it takes too much time
        out = numpy.empty((len(res), len(res[0])), dtype=numpy.float32)
        for i, row in enumerate(res):
            for k, v in row.items():
                # Keys are used directly as column indices, which assumes
                # labels are 0, 1, ... — NOTE(review): confirm for models
                # trained on non-integer labels.
                out[i, k] = v
        return out

    def fcts(self, dim=None, **kwargs):  # pylint: disable=W0221
        """
        Returns a few functions, tests methods
        *predict*, *predict_proba* for both
        :epkg:`scikit-learn` and :epkg:`OnnxInference`
        multiplied by the number of runtime to test.

        :param dim: unused, kept for signature compatibility
        :param kwargs: unused, kept for signature compatibility
        :return: list of dictionaries with keys ``'method'``, ``'lib'``
            (``'skl'`` or ``'onx' + runtime``) and ``'fct'`` (a callable
            taking *X*), updated with ``self.skl_info`` or
            ``self.onnx_info`` depending on the library
        """
        def predict_skl_predict(X, model=self.skl):
            return model.predict(X)

        def predict_skl_predict_proba(X, model=self.skl):
            return model.predict_proba(X.astype(self.dtype))

        def predict_onnxrt_predict(X, sess, output):
            return numpy.array(sess.run({'X': X.astype(self.dtype)})[output])

        to_array = self._proba_to_array

        def predict_onnxrt_predict_proba(X, sess, output):
            res = sess.run({'X': X.astype(self.dtype)})[output]
            return to_array(res)

        fcts = [{'method': 'predict', 'lib': 'skl', 'fct': predict_skl_predict}]
        for runtime in self.ort:
            inst = self.ort[runtime]
            output = self.outputs[runtime][0]
            # default arguments bind the current session/output, avoiding
            # the late-binding closure pitfall inside the loop
            fcts.append({'method': 'predict', 'lib': 'onx' + runtime,
                         'fct': lambda X, sess=inst, output=output:
                         predict_onnxrt_predict(X, sess, output)})

        # Some models expose predict_proba but raise AttributeError from
        # _check_proba when probabilities are disabled.
        if hasattr(self.skl, '_check_proba'):
            try:
                self.skl._check_proba()  # pylint: disable=W0212
                prob = True
            except AttributeError:
                prob = False
        else:
            prob = hasattr(self.skl, 'predict_proba')

        if prob:
            fcts.append({'method': 'predict_proba', 'lib': 'skl',
                         'fct': predict_skl_predict_proba})
            for runtime in self.ort:
                inst = self.ort[runtime]
                output = self.outputs[runtime][1]
                fcts.append({'method': 'predict_proba', 'lib': 'onx' + runtime,
                             'fct': lambda X, sess=inst, output=output:
                             predict_onnxrt_predict_proba(X, sess, output)})

        for fct in fcts:
            if fct['lib'] == 'skl':
                fct.update(self.skl_info)
            elif fct['lib'].startswith('onx'):
                fct.update(self.onnx_info)
        return fcts