Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2"""
3@file
4@brief Shortcut to *ops_onnxruntime*.
5"""
6import numpy
7import onnx.defs
8from onnx.helper import make_tensor
9from onnx.onnx_cpp2py_export.shape_inference import InferenceError # pylint: disable=E0401,E0611
10from skl2onnx.common.data_types import (
11 DictionaryType, FloatTensorType, Int64TensorType, StringTensorType)
12import skl2onnx.algebra.onnx_ops as alg
13try:
14 import skl2onnx.algebra.custom_ops as alg2
15except ImportError: # pragma: no cover
16 # older version of skl2onnx
17 alg2 = alg
18from ...tools.ort_wrapper import (
19 InferenceSession, SessionOptions, RunOptions,
20 GraphOptimizationLevel, OrtInvalidArgument,
21 OrtNotImplemented, OrtInvalidGraph, OrtFail)
22from ...onnx_tools.onnx2py_helper import guess_proto_dtype
23from ...onnx_tools.optim.graph_schema_helper import (
24 get_defined_inputs, get_defined_outputs, proto2vars)
# Lookup table from operator name to an ONNX schema. The iterable may yield
# several schemas sharing one name; as with any dict built from pairs,
# the schema yielded last for a given name is the one which is kept.
_schemas = dict(
    (schema.name, schema)
    for schema in onnx.defs.get_all_schemas_with_history())
class OpRunOnnxRuntime:
    """
    Unique operator which calls :epkg:`onnxruntime`
    to compute predictions for one operator.
    """

    def __init__(self, onnx_node, desc=None, variables=None,
                 dtype=None, **options):
        """
        @param      onnx_node   :epkg:`onnx` node
        @param      desc        internal representation, if it contains
                                key ``'atts'``, every attribute value is
                                copied into *options*
        @param      variables   registered variables created by previous
                                operators
        @param      dtype       float computation type
        @param      options     runtime options

        Raises *ValueError* if an attribute in ``desc['atts']`` is not
        a dictionary holding key ``'value'``.
        """
        self._provider = 'onnxruntime'
        self.onnx_node = onnx_node
        self.desc = desc
        self._schema = _schemas.get(onnx_node.op_type, None)
        if desc is not None and 'atts' in desc:
            for a, b in desc['atts'].items():
                if not isinstance(b, dict) or 'value' not in b:
                    raise ValueError(  # pragma: no cover
                        "Unexpected value {}.".format(b))
                options[a] = b['value']

        self.options = options
        self.dtype = dtype
        self._init(variables)

    def _name_mapping(self, inputs):
        """
        Makes input names unique. A name already seen is replaced by
        ``<name>_<i>`` with the first index *i* which is not used yet.

        @param      inputs      list of input names (duplicates allowed)
        @return                 *(mapping, new_inputs)*, *mapping* maps
                                every new name to the original name,
                                *new_inputs* is the list of unique names
                                in the same order as *inputs*
        """
        mapping = {}
        new_inputs = []
        for name in inputs:
            if name in mapping:
                i = 0
                new_name = "{}_{}".format(name, i)
                while new_name in mapping:
                    i += 1  # pragma: no cover
                    new_name = "{}_{}".format(name, i)  # pragma: no cover
                mapping[new_name] = name
                new_inputs.append(new_name)
            else:
                new_inputs.append(name)
                mapping[name] = name
        return mapping, new_inputs

    def _guess_proto_type(self, dtype):
        """
        Delegates to *guess_proto_dtype* to convert *dtype*
        into the type expected by *make_tensor*.
        """
        return guess_proto_dtype(dtype)

    def _init(self, variables=None):
        """
        Initializes the node: picks the class wrapping the operator,
        converts it into an :epkg:`ONNX` model and loads that model
        into an *InferenceSession*.

        :param variables: registered variables created by previous operators

        The current implementation for operator *Scan*
        only works for matrices.
        """
        # A custom class registered in option 'nodes' takes precedence
        # over the classes defined by skl2onnx.
        custom_nodes = self.options.get('nodes', None)
        if (custom_nodes is not None and
                self.onnx_node.op_type in custom_nodes):
            self.alg_class = custom_nodes[self.onnx_node.op_type]
        else:
            try:
                self.alg_class = getattr(
                    alg2, 'Onnx' + self.onnx_node.op_type)
            except AttributeError:
                self.alg_class = getattr(
                    alg, 'Onnx' + self.onnx_node.op_type)

        inputs = list(self.onnx_node.input)
        self.mapping, self.inputs = self._name_mapping(inputs)
        self.outputs = list(self.onnx_node.output)

        # Runtime options are removed, what remains is given
        # to the operator class as attributes.
        options = self.options.copy()
        options.pop('nodes', None)
        target_opset = options.pop('target_opset', None)
        domain = options.pop('domain', None)
        disable_optimisation = options.pop('disable_optimisation', False)
        # None means "no session options were given". Any falsy value is
        # normalised to None so that the default logging configuration
        # below is applied and disable_optimisation can be used alone.
        session_options = options.pop('session_options', None) or None
        ir_version = options.pop('ir_version', None)

        # target_opset should be >= 9 for the main domain ('').
        # We assume it was the case when the graph was created,
        # no check is done (target_opset may be None).

        if self.onnx_node.op_type == 'ZipMap':
            self.inst_ = self.alg_class(
                *self.inputs, output_names=self.outputs,
                op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            name = (self.outputs[0] if len(self.outputs) == 1
                    else self.inst_.expected_outputs[0][0])
            # Keys of the dictionary are integers or strings
            # depending on which class labels were given.
            otype = (Int64TensorType if 'classlabels_int64s' in options
                     else StringTensorType)
            outvar = [(name, DictionaryType(otype([1]), FloatTensorType([1])))]
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outvar)
            forced = True
        elif self.onnx_node.op_type == 'ConstantOfShape':
            # numpy arrays are converted into tensors, a snapshot of
            # the items is used as values are replaced while looping.
            for k, v in list(options.items()):
                if isinstance(v, numpy.ndarray):
                    options[k] = make_tensor(
                        k, self._guess_proto_type(v.dtype),
                        v.shape, v.tolist())

            self.inst_ = self.alg_class(
                *self.inputs, output_names=self.outputs,
                op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            try:
                self.onnx_ = self.inst_.to_onnx(
                    inputs, target_opset=target_opset, domain=domain)
                if "dim_value: 0" in str(self.onnx_):
                    raise RuntimeError(  # pragma: no cover
                        "Probable issue as one dimension is null.\n--\n{}".format(
                            self.onnx_))
            except AttributeError as e:  # pragma: no cover
                # older version of skl2onnx
                self.onnx_ = self.inst_.to_onnx(inputs)
                if "dim_value: 0" in str(self.onnx_):
                    raise RuntimeError(
                        "Probable issue as one dimension is null.\n--\n{}".format(
                            self.onnx_)) from e
            forced = False
        elif self.onnx_node.op_type == 'Scan':
            self.inst_ = self.alg_class(
                *self.inputs, output_names=self.outputs,
                op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            outputs = get_defined_outputs(
                self.outputs, self.onnx_node, inputs, variables,
                dtype=self.dtype)
            # Every input and output is declared as a matrix
            # with two unknown dimensions.
            inputs = [(name, cl.__class__([None, None]))
                      for (name, cl) in inputs]
            outputs = [(name, cl.__class__([None, None]))
                       for (name, cl) in outputs]
            self.onnx_ = self.inst_.to_onnx(
                inputs, outputs=outputs,
                target_opset=target_opset, domain=domain)
            if "dim_value: 0" in str(self.onnx_):
                raise RuntimeError(  # pragma: no cover
                    "Probable issue as one dimension is null.\n--\n{}".format(
                        self.onnx_))
            forced = True
        else:
            self.inst_ = self.alg_class(
                *self.inputs, output_names=self.outputs,
                op_version=target_opset, domain=domain, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype,
                schema=self.alg_class.expected_inputs)

            try:
                self.onnx_ = self.inst_.to_onnx(
                    inputs, target_opset=target_opset, domain=domain)
                if "dim_value: 0" in str(self.onnx_):
                    # This error is caught below, it triggers
                    # the second attempt with forced output types.
                    raise RuntimeError(  # pragma: no cover
                        "Probable issue as one dimension is null.\n--\n{}\n---\n{}".format(
                            self.onnx_, inputs))
                forced = False
            except (RuntimeError, ValueError, InferenceError) as eo:
                # Let's try again by forcing output types.
                forced = True
                outputs = get_defined_outputs(
                    self.outputs, self.onnx_node, inputs, variables,
                    dtype=self.dtype, schema=self.alg_class.expected_outputs,
                    schema_inputs=self.alg_class.expected_inputs)
                try:
                    self.onnx_ = self.inst_.to_onnx(
                        inputs, outputs=outputs,
                        target_opset=target_opset, domain=domain)
                except NotImplementedError as e:
                    raise NotImplementedError(
                        "Unable to instantiate node {} inputs={} "
                        "self.inputs={} outputs={} variables={} "
                        "dtype={} e={} eo={}".format(
                            self.alg_class, inputs, self.inputs,
                            outputs, variables, self.dtype, e, eo)) from e
                if "dim_value: 0" in str(self.onnx_):
                    # The cause is the first failure (eo): the name bound
                    # by the inner except clause does not exist here.
                    raise RuntimeError(  # pragma: no cover
                        "Probable issue as one dimension is null.\n--\n{}".format(
                            self.onnx_)) from eo

        if len(self.onnx_.graph.output) != len(self.outputs):  # pragma: no cover
            # Something is wrong, falls back to default plan.
            forced = True
            outputs = get_defined_outputs(
                self.outputs, self.onnx_node, inputs, variables,
                dtype=self.dtype, schema=self.alg_class.expected_outputs)
            self.onnx_ = self.inst_.to_onnx(
                inputs, outputs=outputs,
                target_opset=target_opset, domain=domain)
            if "dim_value: 0" in str(self.onnx_):
                raise RuntimeError(  # pragma: no cover
                    "Probable issue as one dimension is null.\n--\n{}".format(
                        self.onnx_))
        else:
            lo = list(self.onnx_.graph.output)
            outputs = proto2vars(lo)

        sess_options = session_options or SessionOptions()
        self.run_options = RunOptions()

        if session_options is None:
            # Default configuration: severity level 3 is requested for
            # the session logs and the run logs.
            try:
                sess_options.session_log_severity_level = 3
                # sess_options.sessions_log_verbosity_level = 0
            except AttributeError:
                # onnxruntime not recent enough.
                pass
            try:
                self.run_options.run_log_severity_level = 3
                # self.run_options.run_log_verbosity_level = 0
            except AttributeError:
                # onnxruntime not recent enough.
                pass
            if disable_optimisation:
                sess_options.graph_optimization_level = (  # pragma: no cover
                    GraphOptimizationLevel.ORT_DISABLE_ALL)
        elif disable_optimisation:
            raise RuntimeError(  # pragma: no cover
                "session_options and disable_optimisation cannot be defined "
                "at the same time.")

        if ir_version is not None:
            self.onnx_.ir_version = ir_version
        try:
            self.sess_ = InferenceSession(
                self.onnx_.SerializeToString(), sess_options=sess_options)
        except (RuntimeError, OrtNotImplemented, OrtInvalidGraph, OrtFail) as e:
            raise RuntimeError(
                "Unable to load node '{}' (output type was {}) inputs={} "
                "self.inputs={} self.onnx_node.input={} "
                "variables={} mapping={} "
                "expected_inputs={}\n{}".format(
                    self.onnx_node.op_type,
                    "guessed" if forced else "inferred",
                    inputs, self.inputs, self.onnx_node.input,
                    variables, self.mapping,
                    self.alg_class.expected_inputs,
                    self.onnx_)) from e
        self.typed_outputs_ = outputs

    def run(self, *args, **kwargs):
        """
        Runs the inference session, positional arguments are paired
        with the input names in the same order.
        Should be overwritten.

        @param      args        input values
        @param      kwargs      unused
        @return                 tuple of all outputs

        Raises *RuntimeError* with details about the inputs
        if the session fails.
        """
        inputs = dict(zip(self.inputs, args))

        try:
            res = self.sess_.run(None, inputs, self.run_options)
        except (RuntimeError, OrtInvalidArgument) as e:  # pragma: no cover
            dtypes = {k: v.dtype for k, v in inputs.items()}
            shapes = {k: v.shape for k, v in inputs.items()}
            exp = [_.name for _ in self.sess_.get_inputs()]
            exp_types = [_.type for _ in self.sess_.get_inputs()]
            raise RuntimeError(
                "Predictions failed. List of inputs: {}, class={}"
                "\ndtypes={}\nshapes={}\nexpected={}\nexpected={}\n"
                "exception={}\n--ONNX--\n{}".format(
                    sorted(inputs), self.alg_class, dtypes,
                    shapes, exp, exp_types, e, self.onnx_)) from e
        return tuple(res)

    def need_context(self):
        """
        Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them (operator Loop).
        The default answer is `False`.
        """
        return False