Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Measures time processing for ONNX models. 

4""" 

5import pickle 

6import os 

7import sklearn 

8from ...tools.ort_wrapper import InferenceSession, OrtFail 

9from .. import OnnxInference 

10from .validate_helper import default_time_kwargs, measure_time, _multiply_time_kwargs 

11from .validate_benchmark import make_n_rows 

12 

13 

class SimplifiedOnnxInference:
    """
    Lightweight replacement for OnnxInference backed directly by an
    onnxruntime InferenceSession. Only the two members used by the
    benchmark replay are provided: ``input_names`` and ``run``.

    :param ort: model handed to InferenceSession (the ONNX bytes when
        called from :func:`enumerate_benchmark_replay`)
    """

    def __init__(self, ort):
        session = InferenceSession(ort)
        self.sess = session

    @property
    def input_names(self):
        """Names of the session inputs, in the order given by get_inputs()."""
        inputs = self.sess.get_inputs()
        return [node_arg.name for node_arg in inputs]

    def run(self, input):  # pylint: disable=W0622
        """
        Executes the session and returns every output.

        :param input: dictionary mapping input names to values
        :return: whatever InferenceSession.run returns for all outputs
        """
        # None as output names means "compute all outputs".
        requested_outputs = None
        return self.sess.run(requested_outputs, input)

28 

29 

def _list_benchmark_files(folder):
    """Returns the sorted names of files in *folder* ending with ``.pkl`` or ``.pickle``."""
    return sorted(name for name in os.listdir(folder)
                  if name.endswith((".pkl", ".pickle")))


def _load_runtimes(onx, runtime):
    """
    Builds one inference object per runtime name.

    :param onx: ONNX model (bytes)
    :param runtime: list of runtime names; ``'onnxruntime'`` uses
        :class:`SimplifiedOnnxInference`, any other name is given to
        :class:`OnnxInference` as its *runtime*
    :return: tuple ``(oinfs, error)`` where *oinfs* maps each runtime name
        to its inference object (None if creation failed) and *error* is
        the message of the last failure or None
    """
    oinfs = {}
    error = None
    for rt in runtime:
        try:
            if rt == 'onnxruntime':
                oinfs[rt] = SimplifiedOnnxInference(onx)
            else:
                oinfs[rt] = OnnxInference(onx, runtime=rt)
        except (OrtFail, RuntimeError) as e:  # pragma: no cover
            error = str(e)
            oinfs[rt] = None
    return oinfs, error


def enumerate_benchmark_replay(folder, runtime='python', time_kwargs=None,
                               skip_long_test=True, time_kwargs_fact=None,
                               time_limit=4, verbose=1, fLOG=None):
    """
    Replays a benchmark stored with function
    :func:`enumerate_validated_operator_opsets
    <mlprodict.onnxrt.validate.validate.enumerate_validated_operator_opsets>`
    or command line :ref:`validate_runtime <l-cmd-validate_runtime>`.
    Enumerates the results.

    @param      folder              folder where to find pickled files, all files must have
                                    *pkl* or *pickle* extension; they are processed in
                                    alphabetical order and names containing ``ERROR`` are skipped
    @param      runtime             runtime or runtimes (list or comma-separated string)
    @param      time_kwargs         to define a more precise way to measure a model,
                                    ``None`` or ``''`` means :func:`default_time_kwargs`
    @param      skip_long_test      skips tests for high values of N if they seem too long
                                    (currently not used by this implementation)
    @param      time_kwargs_fact    see :func:`_multiply_time_kwargs <mlprodict.onnxrt.validate.validate_helper._multiply_time_kwargs>`
    @param      time_limit          to skip the rest of the test after this limit (in second)
                                    (currently not used by this implementation)
    @param      verbose             if >= 1, uses :epkg:`tqdm`, >= 2 logs every file,
                                    >= 3 every batch size, >= 4 every measure
    @param      fLOG                logging function
    @return                         iterator on results, one dictionary per replayed file
    @raises     FileNotFoundError   if *folder* contains no *pkl* or *pickle* file
    """
    files = _list_benchmark_files(folder)
    if not files:
        raise FileNotFoundError(
            "Unable to find any file in folder '{}'.".format(folder))

    if time_kwargs in (None, ''):
        time_kwargs = default_time_kwargs()

    if isinstance(runtime, str):
        # Tolerate "python, onnxruntime" as well as "python,onnxruntime".
        runtime = [rt.strip() for rt in runtime.split(",")]

    loop = files
    if verbose >= 1:
        try:
            from tqdm import tqdm
            loop = tqdm(files)
        except ImportError:  # pragma: no cover
            pass

    for pkl in loop:
        if "ERROR" in pkl:
            # The file name flags a failed run, there is nothing to replay.
            if verbose >= 2 and fLOG is not None:  # pragma: no cover
                fLOG(  # pragma: no cover
                    "[enumerate_benchmark_replay] skip '{}'.".format(pkl))
            continue  # pragma: no cover
        if verbose >= 2 and fLOG is not None:
            fLOG("[enumerate_benchmark_replay] process '{}'.".format(pkl))

        # NOTE: pickle.load can execute arbitrary code stored in the file,
        # only replay benchmark folders coming from a trusted source.
        with open(os.path.join(folder, pkl), 'rb') as f:
            obj = pickle.load(f)

        row = {}
        X_test = obj['X_test']
        ort_test = obj['Xort_test']
        onx = obj['onnx_bytes']
        model = obj['skl_model']
        tkw = _multiply_time_kwargs(time_kwargs, time_kwargs_fact, model)
        row['folder'] = folder
        row['filename'] = pkl
        row['n_features'] = X_test.shape[1]

        obs_op = obj['obs_op']
        for key in ['assume_finite', 'conv_options',
                    'init_types', 'idtype', 'method_name', 'n_features',
                    'name', 'optim', 'opset', 'predict_kwargs',
                    'output_index', 'problem', 'scenario']:
            row[key] = obs_op[key]

        oinfs, error = _load_runtimes(onx, runtime)
        if error is not None:
            row['ERROR'] = error

        # The scikit-learn method does not depend on the number of rows.
        meth = getattr(model, row['method_name'])

        for k, v in sorted(tkw.items()):
            if verbose >= 3 and fLOG is not None:
                fLOG(  # pragma: no cover
                    "[enumerate_benchmark_replay] process n_rows={} - {}".format(k, v))
            xt = make_n_rows(X_test, k)
            number = v['number']
            repeat = v['repeat']

            with sklearn.config_context(assume_finite=row['assume_finite']):
                skl = measure_time(lambda x: meth(x), xt,
                                   number=number, repeat=repeat,
                                   div_by_number=True)
            if verbose >= 4 and fLOG is not None:
                fLOG(  # pragma: no cover
                    "[enumerate_benchmark_replay] skl={}".format(skl))
            row['%d-skl-details' % k] = skl
            row['%d-skl' % k] = skl['average']

            xto = make_n_rows(ort_test, k)
            for rt in runtime:
                oinf = oinfs[rt]
                if oinf is None:
                    continue  # pragma: no cover
                # input_names may be recomputed on every access, read it once.
                input_names = oinf.input_names
                if len(input_names) != 1:
                    raise NotImplementedError(  # pragma: no cover
                        "This function only allows one input not {}".format(
                            len(input_names)))
                name = input_names[0]
                # measure_time runs the lambda immediately, so binding
                # oinf/name late is safe here.
                ort = measure_time(lambda x: oinf.run({name: x}), xto,
                                   number=number, repeat=repeat,
                                   div_by_number=True)
                if verbose >= 4 and fLOG is not None:
                    fLOG(  # pragma: no cover
                        "[enumerate_benchmark_replay] {}={}".format(rt, ort))
                row['%d-%s-detail' % (k, rt)] = ort
                row['%d-%s' % (k, rt)] = ort['average']
        yield row

72 if "ERROR" in pkl: 

73 # An error. 

74 if verbose >= 2 and fLOG is not None: # pragma: no cover 

75 fLOG( # pragma: no cover 

76 "[enumerate_benchmark_replay] skip '{}'.".format(pkl)) 

77 continue # pragma: no cover 

78 if verbose >= 2 and fLOG is not None: 

79 fLOG("[enumerate_benchmark_replay] process '{}'.".format(pkl)) 

80 row = {} 

81 with open(os.path.join(folder, pkl), 'rb') as f: 

82 obj = pickle.load(f) 

83 X_test = obj['X_test'] 

84 ort_test = obj['Xort_test'] 

85 onx = obj['onnx_bytes'] 

86 model = obj['skl_model'] 

87 tkw = _multiply_time_kwargs(time_kwargs, time_kwargs_fact, model) 

88 row['folder'] = folder 

89 row['filename'] = pkl 

90 row['n_features'] = X_test.shape[1] 

91 

92 for key in ['assume_finite', 'conv_options', 

93 'init_types', 'idtype', 'method_name', 'n_features', 

94 'name', 'optim', 'opset', 'predict_kwargs', 

95 'output_index', 'problem', 'scenario']: 

96 row[key] = obj['obs_op'][key] 

97 

98 # 'bench-batch', 

99 # 'bench-skl', 

100 

101 oinfs = {} 

102 for rt in runtime: 

103 if rt == 'onnxruntime': 

104 try: 

105 oinfs[rt] = SimplifiedOnnxInference(onx) 

106 except (OrtFail, RuntimeError) as e: # pragma: no cover 

107 row['ERROR'] = str(e) 

108 oinfs[rt] = None 

109 else: 

110 try: 

111 oinfs[rt] = OnnxInference(onx, runtime=rt) 

112 except (OrtFail, RuntimeError) as e: # pragma: no cover 

113 row['ERROR'] = str(e) 

114 oinfs[rt] = None 

115 

116 for k, v in sorted(tkw.items()): 

117 if verbose >= 3 and fLOG is not None: 

118 fLOG( # pragma: no cover 

119 "[enumerate_benchmark_replay] process n_rows={} - {}".format(k, v)) 

120 xt = make_n_rows(X_test, k) 

121 number = v['number'] 

122 repeat = v['repeat'] 

123 

124 meth = getattr(model, row['method_name']) 

125 with sklearn.config_context(assume_finite=row['assume_finite']): 

126 skl = measure_time(lambda x: meth(x), xt, 

127 number=number, repeat=repeat, 

128 div_by_number=True) 

129 if verbose >= 4 and fLOG is not None: 

130 fLOG( # pragma: no cover 

131 "[enumerate_benchmark_replay] skl={}".format(skl)) 

132 row['%d-skl-details' % k] = skl 

133 row['%d-skl' % k] = skl['average'] 

134 

135 xto = make_n_rows(ort_test, k) 

136 for rt in runtime: 

137 oinf = oinfs[rt] 

138 if oinf is None: 

139 continue # pragma: no cover 

140 if len(oinf.input_names) != 1: 

141 raise NotImplementedError( # pragma: no cover 

142 "This function only allows one input not {}".format( 

143 len(oinf.input_names))) 

144 name = oinf.input_names[0] 

145 ort = measure_time(lambda x: oinf.run({name: x}), xto, 

146 number=number, repeat=repeat, 

147 div_by_number=True) 

148 if verbose >= 4 and fLOG is not None: 

149 fLOG( # pragma: no cover 

150 "[enumerate_benchmark_replay] {}={}".format(rt, ort)) 

151 row['%d-%s-detail' % (k, rt)] = ort 

152 row['%d-%s' % (k, rt)] = ort['average'] 

153 yield row