Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.defs import onnx_opset_version 

9from ._op import OpRun 

10from ..shape_object import ShapeObject 

11 

12 

13def _check_dtype(val): 

14 a = val.dtype 

15 if not isinstance(a, numpy.dtype) and a not in { 

16 numpy.int8, numpy.uint8, numpy.float16, numpy.float32, 

17 numpy.float64, numpy.int32, numpy.int64, numpy.int16, 

18 numpy.uint16, numpy.uint32, numpy.bool_, numpy.str_, 

19 numpy.uint64, bool, str, }: 

20 raise TypeError( # pragma: no cover 

21 "Type ({}, {}) is not a numpy type (operator 'Constant')".format( 

22 a, type(a))) 

23 

24 

class Constant_9(OpRun):
    """
    Runtime for operator *Constant* which only reads attribute *value*.
    The constant is stored in ``self.cst`` once the node is built.
    """

    # Expected attributes and their default values.
    atts = dict(value=numpy.array([0], dtype=numpy.float32))

    def __init__(self, onnx_node, desc=None, **options):
        """
        :param onnx_node: node forwarded to ``OpRun.__init__``
        :param desc: forwarded to ``OpRun.__init__``
        :param options: forwarded to ``OpRun.__init__``
        """
        super().__init__(
            onnx_node, desc=desc, expected_attributes=Constant_9.atts,
            **options)
        self.cst = self.value
        _check_dtype(self.cst)

    def _run(self):  # pylint: disable=W0221
        """Returns the constant as a tuple of one element."""
        return (self.cst,)

    def _infer_shapes(self):  # pylint: disable=W0221
        """Returns one ``ShapeObject`` built from the constant shape and type."""
        constant = self.cst
        return (ShapeObject(constant.shape, constant.dtype),)

    def _infer_types(self):  # pylint: disable=W0221
        """Returns the type of the constant as a tuple of one element."""
        constant = self.cst
        return (constant.dtype,)

    def _infer_sizes(self, *args, **kwargs):
        """
        Calls ``self.run`` with the same arguments and returns its results
        preceded by a dictionary telling no temporary buffer is counted
        (``temp=0``).
        """
        outputs = self.run(*args, **kwargs)
        return ({'temp': 0},) + outputs

50 

51 

class Constant_11(OpRun):
    """
    Runtime for operator *Constant* which reads attribute *sparse_value*
    when it is set and attribute *value* otherwise.
    The constant is stored in ``self.cst`` once the node is built.
    """

    # Expected attributes and their default values.
    atts = dict(value=numpy.array([0], dtype=numpy.float32),
                sparse_value=None)

    def __init__(self, onnx_node, desc=None, **options):
        """
        :param onnx_node: node forwarded to ``OpRun.__init__``
        :param desc: forwarded to ``OpRun.__init__``
        :param options: forwarded to ``OpRun.__init__``
        """
        super().__init__(
            onnx_node, desc=desc, expected_attributes=Constant_11.atts,
            **options)
        # The sparse constant has priority over the dense one.
        sparse = getattr(self, 'sparse_value', None)
        self.cst = self.value if sparse is None else sparse
        _check_dtype(self.cst)

    def _run(self):  # pylint: disable=W0221
        """Returns the constant as a tuple of one element."""
        return (self.cst,)

    def _infer_shapes(self):  # pylint: disable=W0221
        """Returns one ``ShapeObject`` built from the constant shape and type."""
        constant = self.cst
        return (ShapeObject(constant.shape, constant.dtype),)

    def _infer_types(self):  # pylint: disable=W0221
        """Returns the type of the constant as a tuple of one element."""
        constant = self.cst
        return (constant.dtype,)

    def _infer_sizes(self, *args, **kwargs):
        """
        Calls ``self.run`` with the same arguments and returns its results
        preceded by a dictionary telling no temporary buffer is counted
        (``temp=0``).
        """
        outputs = self.run(*args, **kwargs)
        return ({'temp': 0},) + outputs

81 

82 

class Constant_12(OpRun):
    """
    Runtime for operator *Constant* which picks the constant among
    attributes *sparse_value*, *value_float*, *value_floats*, *value_int*,
    *value_ints*, *value_string*, *value_strings*, *value*
    (first one which is set, in that order).
    The constant is stored in ``self.cst`` once the node is built.
    """

    # Expected attributes and their default values.
    atts = dict(
        value=numpy.array([0], dtype=numpy.float32),
        sparse_value=None,
        value_float=None,
        value_floats=None,
        value_int=None,
        value_ints=None,
        value_string=None,
        value_strings=None,
    )

    def __init__(self, onnx_node, desc=None, **options):
        """
        :param onnx_node: node forwarded to ``OpRun.__init__``
        :param desc: forwarded to ``OpRun.__init__``
        :param options: forwarded to ``OpRun.__init__``
        :raises AttributeError: if none of the attributes holding
            the constant is set
        """
        super().__init__(
            onnx_node, desc=desc, expected_attributes=Constant_12.atts,
            **options)
        # Attributes by decreasing priority with the numpy type the value
        # is cast to (None means the value is kept unchanged).
        candidates = (
            ('sparse_value', None),
            ('value_float', numpy.float32),
            ('value_floats', numpy.float32),
            ('value_int', numpy.int64),
            ('value_ints', numpy.int64),
            ('value_string', None),
            ('value_strings', None),
            ('value', None),
        )
        for name, cast_to in candidates:
            found = getattr(self, name, None)
            if found is None:
                continue
            self.cst = found if cast_to is None else found.astype(cast_to)
            break
        else:
            raise AttributeError(
                "No constant is defined for operator 'Constant'.")
        _check_dtype(self.cst)

    def _run(self):  # pylint: disable=W0221
        """Returns the constant as a tuple of one element."""
        return (self.cst,)

    def _infer_shapes(self):  # pylint: disable=W0221
        """Returns one ``ShapeObject`` built from the constant shape and type."""
        constant = self.cst
        return (ShapeObject(constant.shape, constant.dtype),)

    def _infer_types(self):  # pylint: disable=W0221
        """Returns the type of the constant as a tuple of one element."""
        constant = self.cst
        return (constant.dtype,)

    def _infer_sizes(self, *args, **kwargs):
        """
        Calls ``self.run`` with the same arguments and returns its results
        preceded by a dictionary telling no temporary buffer is counted
        (``temp=0``).
        """
        outputs = self.run(*args, **kwargs)
        return ({'temp': 0},) + outputs

134 

135 

# Selects the implementation of operator Constant according to the opset
# returned by onnx_opset_version(): Constant_12 from opset 12,
# Constant_11 for opset 11, Constant_9 below.
Constant = (
    Constant_12 if onnx_opset_version() >= 12
    else Constant_11 if onnx_opset_version() >= 11  # pragma: no cover
    else Constant_9)  # pragma: no cover