Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2""" 

3@file 

4@brief Shortcut to *ops_cpu*. 

5""" 

6from onnx.defs import onnx_opset_version 

7from ...tools.asv_options_helper import benchmark_version 

8from ._op import OpRunCustom 

9from ._op_list import __dict__ as d_op_list 

10 

11 

12_additional_ops = {} 

13 

14 

15def register_operator(cls, name=None, overwrite=True): 

16 """ 

17 Registers a new runtime operator. 

18 

19 @param cls class 

20 @param name by default ``cls.__name__``, 

21 or *name* if defined 

22 @param overwrite overwrite or raise an exception 

23 """ 

24 if name is None: 

25 name = cls.__name__ 

26 if name not in _additional_ops: 

27 _additional_ops[name] = cls 

28 elif not overwrite: 

29 raise RuntimeError( # pragma: no cover 

30 "Unable to overwrite existing operator '{}': {} " 

31 "by {}".format(name, _additional_ops[name], cls)) 

32 

33 

def get_opset_number_from_onnx(benchmark=False):
    """
    Returns the current :epkg:`onnx` opset
    based on the installed version of :epkg:`onnx`.

    @param      benchmark       if True, returns instead the last entry
                                of *benchmark_version()*, the latest
                                version usable for benchmark
    @return                     opset number
    """
    if not benchmark:
        return onnx_opset_version()
    versions = benchmark_version()
    return versions[-1]

46 

47 

def _available_op_names():
    """
    Returns the sorted names of ``_op_list`` entries without ``'_'``
    (which drops versioned names and private/dunder attributes) and
    without ``'cl'``, ``'clo'``, ``'name'`` (presumably helper names
    leaking from that module's namespace), one per line.
    """
    return "\n".join(
        _ for _ in sorted(d_op_list)
        if "_" not in _ and _ not in {'cl', 'clo', 'name'})


def load_op(onnx_node, desc=None, options=None):
    """
    Gets the operator related to the *onnx* node.

    Lookup order: ``<op_type>_<opset>`` then ``<op_type>`` among
    operators registered with *register_operator*, then the same two
    names among the built-in operators of ``_op_list``. When
    ``options['target_opset']`` differs from the installed opset, the
    versioned name is the highest ``<op_type>_<k>`` found in ``_op_list``
    with ``k <= target_opset``.

    @param      onnx_node       :epkg:`onnx` node
    @param      desc            internal representation
    @param      options         runtime options (may be None), also
                                forwarded as keyword arguments to the
                                runtime class constructor
    @return                     instance of the runtime class
    @raises     ValueError if *desc* is None
    @raises     TypeError if ``target_opset`` is not an integer
    @raises     NotImplementedError if no implementation is found
    @raises     RuntimeError if the implementation requires a higher
                opset than the requested ``target_opset``
    """
    if desc is None:
        raise ValueError("desc should not be None.")  # pragma no cover
    # Normalise once: the version-downgrade branch below reads and copies
    # *options*, which used to fail when no options were given.
    if options is None:
        options = {}
    name = onnx_node.op_type
    opset = options.get('target_opset', None)
    current_opset = get_opset_number_from_onnx()
    chosen_opset = current_opset
    if opset == current_opset:
        # Requesting the installed opset is the same as requesting nothing.
        opset = None
    if opset is not None:
        if not isinstance(opset, int):
            raise TypeError(  # pragma no cover
                "opset must be an integer not {}".format(type(opset)))
        name_opset = name + "_" + str(opset)
        # Most recent built-in implementation <name>_<op> with op <= opset.
        for op in range(opset, 0, -1):
            nop = name + "_" + str(op)
            if nop in d_op_list:
                name_opset = nop
                chosen_opset = op
                break
    else:
        name_opset = name

    if name_opset in _additional_ops:
        cl = _additional_ops[name_opset]
    elif name in _additional_ops:
        cl = _additional_ops[name]
    elif name_opset in d_op_list:
        cl = d_op_list[name_opset]
    elif name in d_op_list:
        cl = d_op_list[name]
    else:
        raise NotImplementedError(  # pragma no cover
            "Operator '{}' has no runtime yet. Available list:\n"
            "{}\n--- +\n{}".format(
                name, "\n".join(sorted(_additional_ops)),
                _available_op_names()))

    if hasattr(cl, 'version_higher_than'):
        opv = min(current_opset, chosen_opset)
        if cl.version_higher_than > opv:
            # The chosen implementation does not support
            # the opset version, we need to downgrade it.
            if options.get('target_opset', None) is not None:  # pragma: no cover
                raise RuntimeError(
                    "Supported version {} > {} (opset={}) required version, "
                    "unable to find an implementation version {} found "
                    "'{}'\n--ONNX--\n{}\n--AVAILABLE--\n{}".format(
                        cl.version_higher_than, opv, opset,
                        options['target_opset'], cl.__name__, onnx_node,
                        _available_op_names()))
            # Copy so the caller's options are left untouched.
            options = options.copy()
            options['target_opset'] = current_opset
            return load_op(onnx_node, desc=desc, options=options)

    return cl(onnx_node, desc=desc, **options)