Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# -*- encoding: utf-8 -*-
2"""
3@file
4@brief Shortcut to *ops_cpu*.
5"""
6from onnx.defs import onnx_opset_version
7from ...tools.asv_options_helper import benchmark_version
8from ._op import OpRunCustom
9from ._op_list import __dict__ as d_op_list
# Registry of runtime operators added through register_operator,
# keyed by operator name; load_op looks names up here before d_op_list.
_additional_ops = {}


def register_operator(cls, name=None, overwrite=True):
    """
    Registers a new runtime operator.

    @param      cls         class
    @param      name        by default ``cls.__name__``,
                            or *name* if defined
    @param      overwrite   if True, replace any operator already
                            registered under the same name; if False,
                            raise an exception instead
    @raises     RuntimeError if *name* is already registered and
                            *overwrite* is False
    """
    if name is None:
        name = cls.__name__
    if name in _additional_ops and not overwrite:
        raise RuntimeError(  # pragma: no cover
            "Unable to overwrite existing operator '{}': {} "
            "by {}".format(name, _additional_ops[name], cls))
    # Previously an existing entry was silently kept even when
    # overwrite=True; now the new class replaces it as documented.
    _additional_ops[name] = cls
def get_opset_number_from_onnx(benchmark=False):
    """
    Returns the current :epkg:`onnx` opset
    based on the installed version of :epkg:`onnx`.

    @param      benchmark   if True, return instead the last opset
                            listed by ``benchmark_version()``, i.e. the
                            latest version usable for benchmarks
    @return                 opset number
    """
    if not benchmark:
        return onnx_opset_version()
    usable_versions = benchmark_version()
    return usable_versions[-1]
def _available_operators():
    """
    Returns the sorted names of *d_op_list* entries that contain no
    ``'_'`` (so no version suffix and no dunder names), one per line,
    for use in error messages.
    """
    # NOTE(review): 'cl', 'clo' and 'name' presumably are helper variables
    # left in the _op_list module namespace, not operators — confirm.
    return "\n".join(
        _ for _ in sorted(d_op_list)
        if "_" not in _ and _ not in {'cl', 'clo', 'name'})


def load_op(onnx_node, desc=None, options=None):
    """
    Gets the operator related to the *onnx* node.

    @param      onnx_node   :epkg:`onnx` node
    @param      desc        internal representation, must not be None
    @param      options     runtime options (may be None), forwarded as
                            keyword arguments to the runtime class;
                            key ``'target_opset'`` selects the operator
                            version
    @return                 instance of the runtime class, built with
                            *onnx_node*, *desc* and *options*
    @raises     ValueError if *desc* is None
    @raises     TypeError if ``options['target_opset']`` is not an int
    @raises     NotImplementedError if no runtime exists for the operator
    @raises     RuntimeError if the selected implementation requires a
                            higher opset than the requested one
    """
    if desc is None:
        raise ValueError("desc should not be None.")  # pragma: no cover
    # Normalise once so that every later access to options is safe; the
    # downgrade path below used to fail on `'target_opset' in None`.
    if options is None:
        options = {}
    name = onnx_node.op_type
    opset = options.get('target_opset', None)
    current_opset = get_opset_number_from_onnx()
    chosen_opset = current_opset
    if opset == current_opset:
        # Requesting the installed opset is handled like requesting none.
        opset = None
    if opset is not None:
        if not isinstance(opset, int):
            raise TypeError(  # pragma: no cover
                "opset must be an integer not {}".format(type(opset)))
        # Pick the highest versioned implementation (Name_<v>) in
        # d_op_list whose version does not exceed the requested opset.
        name_opset = name + "_" + str(opset)
        for op in range(opset, 0, -1):
            nop = name + "_" + str(op)
            if nop in d_op_list:
                name_opset = nop
                chosen_opset = op
                break
    else:
        name_opset = name

    # Operators added with register_operator win over built-in ones;
    # within each registry the versioned name wins over the plain one.
    if name_opset in _additional_ops:
        cl = _additional_ops[name_opset]
    elif name in _additional_ops:
        cl = _additional_ops[name]
    elif name_opset in d_op_list:
        cl = d_op_list[name_opset]
    elif name in d_op_list:
        cl = d_op_list[name]
    else:
        raise NotImplementedError(  # pragma: no cover
            "Operator '{}' has no runtime yet. Available list:\n"
            "{}\n--- +\n{}".format(
                name, "\n".join(sorted(_additional_ops)),
                _available_operators()))

    if hasattr(cl, 'version_higher_than'):
        opv = min(current_opset, chosen_opset)
        if cl.version_higher_than > opv:
            # The chosen implementation does not support
            # the opset version, we need to downgrade it.
            if ('target_opset' in options and
                    options['target_opset'] is not None):  # pragma: no cover
                raise RuntimeError(
                    "Supported version {} > {} (opset={}) required version, "
                    "unable to find an implementation version {} found "
                    "'{}'\n--ONNX--\n{}\n--AVAILABLE--\n{}".format(
                        cl.version_higher_than, opv, opset,
                        options['target_opset'], cl.__name__, onnx_node,
                        _available_operators()))
            # Retry with the installed opset; copy so the caller's
            # dictionary is left untouched.
            options = options.copy()
            options['target_opset'] = current_opset
            return load_op(onnx_node, desc=desc, options=options)

    return cl(onnx_node, desc=desc, **options)