
# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import numpy
from onnx.defs import onnx_opset_version
from ._op import OpRun
from ..shape_object import ShapeObject
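# Python runtime for the ONNX RNN operator: CommonRNN implements the
# computation, RNN_7 and RNN_14 bind the attribute defaults of opsets 7
# and 14 (opset 14 adds the layout attribute, which this runtime only
# supports with its default value 0).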

class CommonRNN(OpRun):

    def __init__(self, onnx_node, expected_attributes=None, desc=None,
                 **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

        # The direction attribute may come in as bytes when read from an
        # ONNX model, while the class default is str.
        if isinstance(self.direction, bytes):
            self.direction = self.direction.decode('ascii')
        if self.direction in ("forward", "reverse"):
            self.num_directions = 1
        elif self.direction == "bidirectional":
            self.num_directions = 2
        else:
            raise RuntimeError(  # pragma: no cover
                "Unknown direction '{}'.".format(self.direction))

        if len(self.activation_alpha) != self.num_directions:
            raise RuntimeError(  # pragma: no cover
                "activation_alpha must have the same size as num_directions={}".format(
                    self.num_directions))
        if len(self.activation_beta) != self.num_directions:
            raise RuntimeError(  # pragma: no cover
                "activation_beta must have the same size as num_directions={}".format(
                    self.num_directions))

        self.f1 = self.choose_act(self.activations[0],
                                  self.activation_alpha[0],
                                  self.activation_beta[0])
        if len(self.activations) > 1:
            self.f2 = self.choose_act(self.activations[1],
                                      self.activation_alpha[1],
                                      self.activation_beta[1])
        self.nb_outputs = len(onnx_node.output)
        if getattr(self, 'layout', 0) != 0:
            raise NotImplementedError(
                "The runtime is not implemented when layout=%r != 0." % self.layout)

    def choose_act(self, name, alpha, beta):
        # Activation names read from an ONNX model arrive as bytes
        # (b'Tanh') while the class defaults are str ('tanh'), so
        # normalise before comparing.
        if isinstance(name, bytes):
            name = name.decode('ascii')
        if name.lower() == "tanh":
            return self._f_tanh
        if name.lower() == "affine":
            return lambda x: x * alpha + beta
        raise RuntimeError(  # pragma: no cover
            "Unknown activation function '{}'.".format(name))

    def _f_tanh(self, x):
        return numpy.tanh(x)
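    # One step of the recurrence below follows the ONNX RNN definition:
    #
    #     H_t = f1(X_t . W^T + H_{t-1} . R^T + Wb + Rb)
    #
    # B stores the input bias Wb and the recurrence bias Rb concatenated
    # along the last axis, which is why the two halves of
    # ``numpy.split(B, 2)`` are summed in ``_step``.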

    def _step(self, X, R, B, W, H_0):
        h_list = []
        H_t = H_0
        for x in numpy.split(X, X.shape[0], axis=0):
            H = self.f1(numpy.dot(x, numpy.transpose(W)) +
                        numpy.dot(H_t, numpy.transpose(R)) +
                        numpy.add(*numpy.split(B, 2)))
            h_list.append(H)
            H_t = H
        concatenated = numpy.concatenate(h_list)
        # _run only calls this method when num_directions == 1, so the
        # axis inserted here is the num_directions axis of Y.
        if self.num_directions == 1:
            output = numpy.expand_dims(concatenated, 1)
        return output, h_list[-1]
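    # Expected input shapes, following the ONNX RNN specification:
    #   X: (seq_length, batch_size, input_size)
    #   W: (num_directions, hidden_size, input_size)
    #   R: (num_directions, hidden_size, hidden_size)
    #   B: (num_directions, 2 * hidden_size), i.e. Wb and Rb concatenated
    #   initial_h: (num_directions, batch_size, hidden_size)
    # Only num_directions == 1 is implemented by this runtime.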

    def _run(self, X, W, R, B=None, sequence_lens=None, initial_h=None):  # pylint: disable=W0221
        self.num_directions = W.shape[0]

        if self.num_directions == 1:
            R = numpy.squeeze(R, axis=0)
            W = numpy.squeeze(W, axis=0)
            if B is not None:
                B = numpy.squeeze(B, axis=0)
            if sequence_lens is not None:
                sequence_lens = numpy.squeeze(sequence_lens, axis=0)
            if initial_h is not None:
                initial_h = numpy.squeeze(initial_h, axis=0)

            hidden_size = R.shape[-1]
            batch_size = X.shape[1]

            b = (B if B is not None else
                 numpy.zeros(2 * hidden_size, dtype=numpy.float32))
            h_0 = (initial_h if initial_h is not None else
                   numpy.zeros((batch_size, hidden_size), dtype=numpy.float32))

            B = b
            H_0 = h_0
        else:
            raise NotImplementedError()  # pragma: no cover

        Y, Y_h = self._step(X, R, B, W, H_0)
        return (Y, ) if self.nb_outputs == 1 else (Y, Y_h)

    def _infer_shapes(self, X, W, R, B=None, sequence_lens=None, initial_h=None):  # pylint: disable=W0221
        num_directions = W.shape[0]

        if num_directions == 1:
            hidden_size = R[-1]
            batch_size = X[1]
            y_shape = ShapeObject((X[0], num_directions, batch_size, hidden_size),
                                  dtype=X.dtype)
        else:
            raise NotImplementedError()  # pragma: no cover
        if self.nb_outputs == 1:
            return (y_shape, )
        y_h_shape = ShapeObject((num_directions, batch_size, hidden_size),
                                dtype=X.dtype)
        return (y_shape, y_h_shape)

    def _infer_types(self, X, W, R, B=None, sequence_lens=None, initial_h=None):  # pylint: disable=W0221
        return (X, X)

class RNN_7(CommonRNN):

    atts = {
        'activation_alpha': [0.],
        'activation_beta': [0.],
        'activations': ['tanh', 'tanh'],
        'clip': [],
        'direction': 'forward',
        'hidden_size': None,
    }

    def __init__(self, onnx_node, desc=None, **options):
        CommonRNN.__init__(self, onnx_node, desc=desc,
                           expected_attributes=RNN_7.atts,
                           **options)

class RNN_14(CommonRNN):

    atts = {
        'activation_alpha': [0.],
        'activation_beta': [0.],
        'activations': ['tanh', 'tanh'],
        'clip': [],
        'direction': 'forward',
        'hidden_size': None,
        'layout': 0,
    }

    def __init__(self, onnx_node, desc=None, **options):
        CommonRNN.__init__(self, onnx_node, desc=desc,
                           expected_attributes=RNN_14.atts,
                           **options)

# The exported RNN class matches the installed onnx opset: opset 14 added
# the layout attribute.
if onnx_opset_version() >= 14:
    RNN = RNN_14
else:  # pragma: no cover
    RNN = RNN_7
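# A minimal sketch of how this runtime could be exercised directly, kept as
# a comment because the relative imports above prevent running this file as
# a script. The node is built with onnx.helper and the shapes follow the
# ONNX RNN specification:
#
#     import numpy
#     from onnx import helper
#
#     node = helper.make_node(
#         'RNN', inputs=['X', 'W', 'R'], outputs=['Y', 'Y_h'], hidden_size=4)
#     op = RNN(node)
#     X = numpy.random.randn(5, 3, 2).astype(numpy.float32)  # seq, batch, input
#     W = numpy.random.randn(1, 4, 2).astype(numpy.float32)  # dir, hidden, input
#     R = numpy.random.randn(1, 4, 4).astype(numpy.float32)  # dir, hidden, hidden
#     Y, Y_h = op.run(X, W, R)   # Y: (5, 1, 3, 4), Y_h: (1, 3, 4)
#
# Whether RNN(node) can be constructed without a 'desc' argument depends on
# OpRun, so treat this as an assumption rather than the documented API.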