Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Experimental implementation. 

4""" 

5from collections import OrderedDict 

6import numpy 

7 

8 

def custom_pad(arr, paddings, constant=0, verbose=False):
    """
    Implements function
    `pad <https://numpy.org/doc/stable/reference/
    generated/numpy.pad.html>`_ in python,
    only the constant version.

    The computation works on the flattened output buffer:
    a marker value is first written where every input row (last axis)
    would start, every padded region is then overwritten with *constant*,
    and finally an input row is copied wherever a marker survived.

    :param arr: array
    :param paddings: integer array (or nested sequence) of shape
        ``(arr.ndim, 2)``, ``paddings[i] = (before, after)`` for axis *i*
    :param constant: constant used to fill the padded regions,
        it is stored after being cast into ``arr.dtype``
    :param verbose: if True, the output buffer is initialized with -1
        before being filled instead of being left uninitialized
        (debugging aid)
    :return: padded array
    :raises ValueError: if *paddings* does not have one row per dimension
    :raises NotImplementedError: if a padding is negative
    """
    paddings = numpy.asarray(paddings)
    if paddings.shape[0] != len(arr.shape):
        raise ValueError(  # pragma: no cover
            "Input shape {} and paddings {} are inconsistent.".format(
                arr.shape, paddings))
    if len(arr.shape) == 0:
        # A 0-d array has no axis to pad.
        return numpy.array(arr)
    if paddings.min() < 0:
        raise NotImplementedError("Negative paddings is not implemented yet.")
    if not arr.flags['C_CONTIGUOUS']:
        arr = numpy.ascontiguousarray(arr)

    new_shape = tuple(
        a + s for a, s in zip(arr.shape, paddings.sum(axis=1)))

    # cumulative_copy[i] is the number of output elements covered by
    # axes i..ndim-1, cumulative_copy[0] is the total output size.
    cumulative_copy = [1]
    for a in reversed(new_shape):
        cumulative_copy.insert(0, a * cumulative_copy[0])
    total = cumulative_copy[0]

    input_arr = arr.ravel()
    if verbose:
        res = numpy.zeros(total, dtype=arr.dtype) - 1
    else:
        res = numpy.empty(total, dtype=arr.dtype)
    if total == 0:
        # Nothing to fill; an empty last axis would also make the marker
        # slice below use a zero step.
        return res.reshape(new_shape)

    # preparation: offset of the first input element in the output buffer
    first_index = sum(
        p * c for p, c in zip(paddings[:, 0], cumulative_copy[1:]))
    dh_input = arr.shape[-1]
    dh_copy = new_shape[-1]

    # The constant as it is actually stored in the output dtype
    # (e.g. 0.5 becomes 0 in an integer array). The marker must differ
    # from this stored value, not from the raw argument, otherwise
    # padded rows would be mistaken for input rows.
    cst = numpy.asarray(constant).astype(arr.dtype)
    no_constant = 1 if cst == 0 else 0
    res[first_index:total:dh_copy] = no_constant

    # padding: for every axis, overwrite the left and right padded slabs
    for i, sh in enumerate(new_shape):
        upper_number = total // cumulative_copy[i]
        contiguous = cumulative_copy[i + 1]
        p_left = paddings[i, 0] * contiguous
        p_right = paddings[i, 1] * contiguous
        dp = sh * contiguous - p_right
        big_index = 0
        for _ in range(upper_number):
            if p_left > 0:
                res[big_index:big_index + p_left] = cst
            if p_right > 0:
                index = big_index + dp
                res[index:index + p_right] = cst
            big_index += cumulative_copy[i]

    # copy: every surviving marker is the start of one input row
    index_input = 0
    for index_copy in range(first_index, total, dh_copy):
        if res[index_copy] == no_constant:
            res[index_copy:index_copy + dh_input] = \
                input_arr[index_input:index_input + dh_input]
            index_input += dh_input

    # final
    return res.reshape(new_shape)

84 

85 

def custom_einsum(equation, x, y, verbose=False):
    """
    Experimental implementation of operator Einsum
    when it does a matrix multiplication.
    Case: ``bsnh,btnh->bnts`` with shapes
    `(1,512,12,64)` and `(1,512,12,64)`.

    :param equation: equation in explicit form
        ``<letters of x>,<letters of y>-><letters of result>``
    :param x: first matrix
    :param y: second matrix
    :param verbose: display internal information
    :return: result of *einsum*
    :raises RuntimeError: if *x* and *y* have different dtypes
    :raises ValueError: if the equation does not match the shapes
    :raises NotImplementedError: if there is not exactly one summation
        letter or if a letter is repeated inside one operand

    This implementation does not do any transpose,
    it does a direct computation of the final result.
    It does not implement diagonal summation (square product):
    a letter repeated inside one operand raises an exception.
    """
    def _check_eq(eq, sh):
        if len(eq) != len(sh):
            raise ValueError(
                "Unable to map equation %r to shape %r." % (eq, sh))
        if len(set(eq)) != len(eq):
            # _split keeps only the last occurrence of a letter, which
            # would silently produce wrong strides for a diagonal.
            raise NotImplementedError(
                "Repeated letters (diagonal) are not implemented, "
                "equation %r." % eq)

    def _split(eq, sh):
        # letter -> (dimension, axis position)
        dx = OrderedDict((e, (v, i)) for i, (e, v) in enumerate(zip(eq, sh)))
        return dx

    def _interpret(dx, dy, eqr):
        # c_trp: letters kept in the output and present in both operands,
        # c_uni: output letters present in only one operand,
        # c_sum: letters of both operands absent from the output.
        c_uni = []
        c_trp = []
        c_sum = []
        for r in eqr:
            if r in dx:
                if r in dy:
                    if dx[r][0] != dy[r][0]:
                        raise ValueError(
                            "Dimension mismatch for letter "
                            "%r dx=%r dy=%r." % (r, dx, dy))
                    c_trp.append(r)
                else:
                    c_uni.append((r, None))
            elif r in dy:
                c_uni.append((None, r))
            else:
                raise ValueError(  # pragma: no cover
                    "Unexpected letter %r in result %r." % (r, eqr))
        for c in dx:
            if c not in eqr:
                if c not in dy:
                    raise ValueError(  # pragma: no cover
                        "Unable to guess what to do with column %r (left side)" % c)
                if dx[c][0] != dy[c][0]:
                    raise ValueError(  # pragma: no cover
                        "Dimension mismatch for letter "
                        "%r dx=%r dy=%r." % (c, dx, dy))
                c_sum.append(c)
        for c in dy:
            if c not in eqr and c not in dx:
                raise ValueError(  # pragma: no cover
                    "Unable to guess what to do with column %r (right side)" % c)
        shape = OrderedDict()
        for i, r in enumerate(eqr):
            if r in c_trp:
                shape[r] = (dx[r][0], i)
            else:
                for a, b in c_uni:
                    if a == r:
                        shape[r] = (dx[r][0], i)
                        break
                    if b == r:
                        shape[r] = (dy[r][0], i)
                        break
        if len(shape) != len(eqr):
            raise RuntimeError(  # pragma: no cover
                "Unable to compute the output shape "
                "dx=%r dy=%r eqr=%r got shape=%r." % (dx, dy, eqr, shape))
        return shape, c_trp, c_uni, c_sum

    def _inc(d):
        # letter -> (stride in elements of the C-ordered flattened
        # operand, axis position)
        t = 1
        drev = list(reversed(d.items()))
        res = []
        for c, (sh, p) in drev:
            res.append((c, (t, p)))
            t *= sh
        return OrderedDict(reversed(res))

    def prod(seq):
        p = 1
        for s in seq:
            p *= s
        return p

    def get_index(cd, shape, index, col_sum):
        # flat offset of the output multi-index *index* inside one operand
        # and the stride of the summation letter in that operand
        ind = 0
        for c, i in zip(shape, index):
            if c in cd:
                inc = cd[c][0]
                ind += inc * i
        return ind, cd[col_sum][0]

    def get_incs(cd, shape):
        # stride of every output letter in one operand, 0 if absent
        incs = []
        for c in shape:
            inc = cd[c][0] if c in cd else 0
            incs.append(inc)
        return incs

    if x.dtype != y.dtype:
        raise RuntimeError("x and y must have the same dtype.")
    eqx = equation.split(',')[0]
    eqy = equation.split(',')[-1].split('->')[0]
    eqr = equation.split('->')[-1]
    _check_eq(eqx, x.shape)
    _check_eq(eqy, y.shape)
    dx = _split(eqx, x.shape)
    dy = _split(eqy, y.shape)
    shape, __, _, c_sum = _interpret(dx, dy, eqr)
    cdx = _inc(dx)
    cdy = _inc(dy)
    xrav = x.ravel()
    yrav = y.ravel()
    full_size = prod(v[0] for v in shape.values())
    zrav = numpy.empty((full_size, ), dtype=x.dtype)

    # loop
    if len(c_sum) != 1:
        raise NotImplementedError(
            "More than one summation indices %r in equation %r." % (
                c_sum, equation))
    zeros = numpy.zeros((1, ), dtype=x.dtype)
    shape_dims = [v[0] for v in shape.values()]
    index = [0] * len(shape)
    len_index = len(index)
    loop_size = dx[c_sum[0]][0]

    i_left_loop, inc_left = get_index(cdx, shape, index, c_sum[0])
    i_right_loop, inc_right = get_index(cdy, shape, index, c_sum[0])
    left_incs = get_incs(cdx, shape)
    right_incs = get_incs(cdy, shape)

    if verbose:
        def MakeString(*args):
            return "".join(map(str, args))

        print(MakeString("equation=", equation))
        print(MakeString("c_sum=", c_sum))
        print(MakeString("full_size=", full_size))
        print(MakeString("loop_size=", loop_size))
        print(MakeString("i_left_loop=", i_left_loop))
        print(MakeString("i_right_loop=", i_right_loop))
        print(MakeString("inc_left=", inc_left))
        print(MakeString("inc_right=", inc_right))
        print(MakeString("left_incs=", left_incs))
        print(MakeString("right_incs=", right_incs))
        print(MakeString("shape=", shape))
        print(MakeString("cdx=", cdx))
        print(MakeString("cdy=", cdy))

    for i in range(0, full_size):

        i_left = i_left_loop
        i_right = i_right_loop

        # summation along the single summation letter
        add = zeros[0]
        for _ in range(loop_size):
            add += xrav[i_left] * yrav[i_right]
            i_left += inc_left
            i_right += inc_right
        zrav[i] = add

        if verbose:
            print(MakeString(
                " -- index=", index, " ii=", i,
                " i_left_loop=", i_left_loop, " i_right_loop=", i_right_loop,
                " add=", add))

        if len_index == 0:
            # Scalar output (e.g. "i,i->"): a single element, no
            # multi-index to increment.
            continue

        # increment the output multi-index (last axis fastest) and keep
        # the starting offsets in x and y in sync with it
        pos = len_index - 1
        index[pos] += 1
        i_left_loop += left_incs[pos]
        i_right_loop += right_incs[pos]
        while pos > 0 and index[pos] >= shape_dims[pos]:
            i_left_loop -= left_incs[pos] * index[pos]
            i_right_loop -= right_incs[pos] * index[pos]
            index[pos] = 0
            pos -= 1
            index[pos] += 1
            i_left_loop += left_incs[pos]
            i_right_loop += right_incs[pos]

    new_shape = tuple(shape_dims)
    return zrav.reshape(new_shape)