Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2""" 

3@file 

4@brief Shortcut to *ops_whole*. 

5""" 

6import json 

7from io import BytesIO 

8import onnx 

9from ...tools.ort_wrapper import ( 

10 InferenceSession, SessionOptions, RunOptions, 

11 GraphOptimizationLevel, OrtFail, 

12 OrtInvalidGraph, OrtInvalidArgument, 

13 OrtNotImplemented, OrtRuntimeException) 

14from ...tools.asv_options_helper import display_onnx 

15 

16 

class OnnxWholeSession:
    """
    Runs the prediction for a single :epkg:`ONNX`,
    it lets the runtime handle the graph logic as well.
    """

    def __init__(self, onnx_data, runtime, runtime_options=None):
        """
        @param      onnx_data           :epkg:`ONNX` model or data, an object
                                        exposing *SerializeToString* is
                                        serialized first
        @param      runtime             runtime to be used,
                                        only ``'onnxruntime1'`` is accepted
        @param      runtime_options     runtime options, a dictionary which
                                        may contain *session_options*,
                                        *disable_optimisation*,
                                        *enable_profiling*

        Raises *NotImplementedError* if *runtime* is not supported,
        *RuntimeError* if *session_options* is combined with
        *enable_profiling* or *disable_optimisation*, or if the
        inference session cannot be created.
        """
        if runtime != 'onnxruntime1':
            raise NotImplementedError(  # pragma: no cover
                "runtime '{}' is not implemented.".format(runtime))
        if hasattr(onnx_data, 'SerializeToString'):
            onnx_data = onnx_data.SerializeToString()
        session_options = (
            None if runtime_options is None
            else runtime_options.get('session_options', None))
        self.runtime = runtime
        sess_options = session_options or SessionOptions()
        self.run_options = RunOptions()

        if session_options is None:
            # The session options were created here: they can be configured
            # from the other entries of runtime_options.
            # NOTE(review): confirm both attribute names below against the
            # installed onnxruntime, an AttributeError is silently ignored.
            try:
                sess_options.sessions_log_verbosity_level = 0
            except AttributeError:  # pragma: no cover
                # onnxruntime not recent enough.
                pass
            try:
                self.run_options.run_log_verbosity_level = 0
            except AttributeError:  # pragma: no cover
                # onnxruntime not recent enough.
                pass
            if runtime_options is not None:
                if runtime_options.get('disable_optimisation', False):
                    # Disabling the optimisation means ORT_DISABLE_ALL,
                    # ORT_ENABLE_ALL would turn every optimisation on.
                    sess_options.graph_optimization_level = (  # pragma: no cover
                        GraphOptimizationLevel.ORT_DISABLE_ALL)
                # NOTE(review): profiling is enabled as soon as
                # runtime_options is given without 'enable_profiling',
                # confirm this default is intended.
                if runtime_options.get('enable_profiling', True):
                    sess_options.enable_profiling = True
        elif 'enable_profiling' in runtime_options:
            # The caller owns session_options, the shortcuts would
            # silently overwrite its settings.
            raise RuntimeError(  # pragma: no cover
                "session_options and enable_profiling cannot be defined at the "
                "same time.")
        elif 'disable_optimisation' in runtime_options:
            raise RuntimeError(  # pragma: no cover
                "session_options and disable_optimisation cannot be defined at the "
                "same time.")
        try:
            self.sess = InferenceSession(onnx_data, sess_options=sess_options)
        except (OrtFail, OrtNotImplemented, OrtInvalidGraph,
                OrtInvalidArgument, OrtRuntimeException, RuntimeError) as e:
            raise RuntimeError(
                "Unable to create InferenceSession due to '{}'\n{}.".format(
                    e, OnnxWholeSession._describe_model(onnx_data))) from e

    @staticmethod
    def _describe_model(onnx_data):
        """
        Builds the model description appended to the error message
        raised when the inference session cannot be created.

        @param      onnx_data       data given to *InferenceSession*
        @return                     result of *display_onnx* or a
                                    placeholder string if the model
                                    cannot be loaded or displayed
        """
        try:
            return display_onnx(onnx.load(BytesIO(onnx_data)))
        except Exception as exc:  # pylint: disable=W0703
            # Best effort: a failure while building the diagnostic must not
            # hide the exception which explains why the session failed.
            return "<unable to display the model: {!r}>".format(exc)

    def run(self, inputs):
        """
        Computes the predictions.

        @param      inputs      dictionary *{variable, value}*
        @return                 list of outputs
        """
        return self.sess.run(None, inputs, self.run_options)

    @staticmethod
    def process_profiling(js):
        """
        Flattens json returned by onnxruntime profiling.
        Every row whose ``'args'`` value is a dictionary gets one
        ``'args_<key>'`` entry per item and loses ``'args'``.
        The rows received as input are not modified.

        :param js: json (iterable of dictionaries)
        :return: list of dictionaries
        """
        rows = []
        for row in js:
            args = row.get('args', None)
            if not isinstance(args, dict):
                # Nothing to flatten, the row is returned as is.
                rows.append(row)
                continue
            # Works on a copy so that the caller's data stays untouched.
            flat = {k: v for k, v in row.items() if k != 'args'}
            for k, v in args.items():
                flat['args_%s' % k] = v
            rows.append(flat)
        return rows

    def get_profiling(self):
        """
        Returns the profiling informations.
        It ends the profiling, reads the file whose name is returned
        by *end_profiling* and flattens its content with
        *process_profiling*.

        @return                 list of dictionaries
        """
        prof = self.sess.end_profiling()
        # Explicit encoding: json files should not depend on the locale.
        with open(prof, 'r', encoding='utf-8') as f:
            js = json.load(f)
        return OnnxWholeSession.process_profiling(js)