Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7from onnx.defs import onnx_opset_version 

8from ._op import OpRun 

9from ..shape_object import DimensionObject, ShapeObject 

10 

11 

class CommonSplit(OpRun):
    """
    Runtime for operator *Split*.

    Shared implementation of every opset-specific *Split* runtime.
    Subclasses decide where the chunk sizes come from (an attribute or
    a run-time input) and forward them to :meth:`common_run` and
    :meth:`common_infer_shapes`.
    """

    def __init__(self, onnx_node, desc=None,
                 expected_attributes=None, **options):
        # Default the ``split`` option to None so it is always defined,
        # even for subclasses whose expected attributes do not list it
        # (NOTE(review): assumes OpRun.__init__ turns options into
        # attributes, as Split_2 reads ``self.split``).
        if 'split' not in options:
            options['split'] = None
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)
        # One output per chunk.
        self.nb_outputs = len(onnx_node.output)

    def common_run(self, mat, split):  # pylint: disable=W0221
        """
        Splits *mat* along ``self.axis``.

        :param mat: array to split (must expose ``shape`` and accept a
            tuple of slices as index)
        :param split: sizes of the chunks along ``self.axis``, or None to
            cut the axis into ``self.nb_outputs`` chunks of
            ``dim // nb_outputs`` elements, the last chunk receiving the
            remainder
        :return: tuple of slices of *mat*, one per chunk
        """
        if split is None:
            dim = mat.shape[self.axis]
            div = dim // self.nb_outputs
            split = [div] * self.nb_outputs
            # Integer division may leave a remainder: the last chunk
            # absorbs it so that the chunks cover the whole axis.
            split[-1] += dim - sum(split)
        # Take everything on every axis; only ``self.axis`` is narrowed below.
        sli = [slice(None)] * len(mat.shape)
        res = []
        pos = 0
        for spl in split:
            sli[self.axis] = slice(pos, pos + spl)
            pos += spl
            res.append(mat[tuple(sli)])
        return tuple(res)

    def common_infer_shapes(self, data, split):  # pylint: disable=W0221
        """
        Infers the output shapes of *Split*.

        :param data: shape of the input (object with ``dtype``, ``copy``
            and item assignment, e.g. ShapeObject)
        :param split: chunk sizes along ``self.axis``, or None
        :return: tuple of shapes, one per output; every shape is unknown
            (``ShapeObject(None)``) when *split* is None
        """
        if split is None:
            return tuple(ShapeObject(None, dtype=data.dtype)
                         for _ in range(self.nb_outputs))
        res = []
        for spl in split:
            # Same shape as the input except along the split axis.
            shape = data.copy()
            shape[self.axis] = DimensionObject(spl)
            res.append(shape)
        return tuple(res)

    def _infer_types(self, data, split):  # pylint: disable=W0221
        """
        Every output has the type of the input: returns *data* once per
        chunk (``self.nb_outputs`` times when *split* is None).
        """
        if split is None:
            return tuple(data for _ in range(self.nb_outputs))
        return tuple(data for _ in split)

    def _infer_sizes(self, *args, **kwargs):  # pylint: disable=W0221
        """
        Runs the operator and prepends ``dict(temp=0)`` to its outputs
        (presumably meaning no temporary buffer is needed).
        """
        res = self.run(*args, **kwargs)
        return (dict(temp=0), ) + res

61 

62 

class Split_2(CommonSplit):
    """
    Runtime for operator *Split* where the chunk sizes come from the
    optional ``split`` attribute (None when absent: equal chunks).
    """

    atts = {'axis': 0, 'split': None}

    def __init__(self, onnx_node, desc=None, **options):
        CommonSplit.__init__(self, onnx_node, desc=desc,
                             expected_attributes=Split_2.atts, **options)

    def _run(self, mat):  # pylint: disable=W0221
        # Chunk sizes are taken from the attribute.
        return self.common_run(mat, self.split)

    def _infer_shapes(self, data):  # pylint: disable=W0221
        return self.common_infer_shapes(data, self.split)

    def _infer_types(self, data):  # pylint: disable=W0221
        # Reuse the shared rule instead of duplicating it, fed with the
        # attribute value.
        return CommonSplit._infer_types(self, data, self.split)

84 

85 

class Split_11(Split_2):
    """
    Runtime for operator *Split*, selected for opsets 11 and 12.

    Adds nothing to :class:`Split_2`: the chunk sizes are still read
    from the ``split`` attribute.
    """

91 

92 

class Split_13(CommonSplit):
    """
    Runtime for operator *Split*, opset 13 flavour: ``split`` is no
    longer an attribute but an optional second input given at run time.
    """

    atts = {'axis': 0}

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc,
                         expected_attributes=Split_13.atts, **options)

    def _run(self, mat, split=None):  # pylint: disable=W0221
        # Chunk sizes arrive as an input, not through ``self.split``.
        return self.common_run(mat, split)

    def _infer_shapes(self, data, split=None):  # pylint: disable=W0221
        # ``split`` is ignored here: every output shape is reported as
        # unknown, with the input dtype.
        unknown_shapes = [ShapeObject(None, dtype=data.dtype)
                          for _ in range(self.nb_outputs)]
        return tuple(unknown_shapes)

    def _infer_types(self, data, split=None):  # pylint: disable=W0221
        # Each output shares the input type.
        return (data, ) * self.nb_outputs

113 

114 

# Public alias ``Split``: the runtime matching ``onnx_opset_version()``.
# Like an if/elif chain, the second comparison is only evaluated when the
# first one fails. A single statement needs no ``# pragma: no cover`` on
# the alternatives that are never taken.
Split = (Split_13 if onnx_opset_version() >= 13
         else Split_11 if onnx_opset_version() >= 11
         else Split_2)