Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2"""
3@file
4@brief Shortcut to *ops_whole*.
5"""
6import json
7from io import BytesIO
8import onnx
9from ...tools.ort_wrapper import (
10 InferenceSession, SessionOptions, RunOptions,
11 GraphOptimizationLevel, OrtFail,
12 OrtInvalidGraph, OrtInvalidArgument,
13 OrtNotImplemented, OrtRuntimeException)
14from ...tools.asv_options_helper import display_onnx
17class OnnxWholeSession:
18 """
19 Runs the prediction for a single :epkg:`ONNX`,
20 it lets the runtime handle the graph logic as well.
21 """
23 def __init__(self, onnx_data, runtime, runtime_options=None):
24 """
25 @param onnx_data :epkg:`ONNX` model or data
26 @param runtime runtime to be used,
27 mostly :epkg:`onnxruntime`
28 @param runtime_options runtime options
29 """
30 if runtime != 'onnxruntime1':
31 raise NotImplementedError( # pragma: no cover
32 "runtime '{}' is not implemented.".format(runtime))
33 if hasattr(onnx_data, 'SerializeToString'):
34 onnx_data = onnx_data.SerializeToString()
35 session_options = (
36 None if runtime_options is None
37 else runtime_options.get('session_options', None))
38 self.runtime = runtime
39 sess_options = session_options or SessionOptions()
40 self.run_options = RunOptions()
42 if session_options is None:
43 try:
44 sess_options.sessions_log_verbosity_level = 0
45 except AttributeError: # pragma: no cover
46 # onnxruntime not recent enough.
47 pass
48 try:
49 self.run_options.run_log_verbosity_level = 0
50 except AttributeError: # pragma: no cover
51 # onnxruntime not recent enough.
52 pass
53 if runtime_options is not None:
54 if runtime_options.get('disable_optimisation', False):
55 sess_options.graph_optimization_level = ( # pragma: no cover
56 GraphOptimizationLevel.ORT_ENABLE_ALL)
57 if runtime_options.get('enable_profiling', True):
58 sess_options.enable_profiling = True
59 elif 'enable_profiling' in runtime_options:
60 raise RuntimeError( # pragma: no cover
61 "session_options and enable_profiling cannot be defined at the "
62 "same time.")
63 elif 'disable_optimisation' in runtime_options:
64 raise RuntimeError( # pragma: no cover
65 "session_options and disable_optimisation cannot be defined at the "
66 "same time.")
67 try:
68 self.sess = InferenceSession(onnx_data, sess_options=sess_options)
69 except (OrtFail, OrtNotImplemented, OrtInvalidGraph,
70 OrtInvalidArgument, OrtRuntimeException, RuntimeError) as e:
71 raise RuntimeError(
72 "Unable to create InferenceSession due to '{}'\n{}.".format(
73 e, display_onnx(onnx.load(BytesIO(onnx_data))))) from e
75 def run(self, inputs):
76 """
77 Computes the predictions.
79 @param inputs dictionary *{variable, value}*
80 @return list of outputs
81 """
82 return self.sess.run(None, inputs, self.run_options)
84 @staticmethod
85 def process_profiling(js):
86 """
87 Flattens json returned by onnxruntime profiling.
89 :param js: json
90 :return: list of dictionaries
91 """
92 rows = []
93 for row in js:
94 if 'args' in row and isinstance(row['args'], dict):
95 for k, v in row['args'].items():
96 row['args_%s' % k] = v
97 del row['args']
98 rows.append(row)
99 return rows
101 def get_profiling(self):
102 """
103 Returns the profiling informations.
104 """
105 prof = self.sess.end_profiling()
106 with open(prof, 'r') as f:
107 content = f.read()
108 js = json.loads(content)
109 return OnnxWholeSession.process_profiling(js)