.. blogpost::
    :title: Decompose einsum into numpy operators
    :keywords: einsum
    :date: 2021-08-11
    :categories: onnx

    Notebook :ref:`einsumdecompositionrst` explains what function
    :epkg:`numpy:einsum` does and how it can be decomposed into
    a series of basic operations, all available in ONNX.
    That is the purpose of function :func:`decompose_einsum_equation`.
    With function :func:`export2numpy`, it is possible to convert
    this ONNX graph back into a series of numpy operations.

    .. runpython::
        :showcode:
        :process:

        import numpy
        from mlprodict.testing.einsum import decompose_einsum_equation
        from mlprodict.onnx_tools.onnx_export import export2numpy

        seq_clean = decompose_einsum_equation(
            "bsnh,btnh->bnts", strategy='numpy', clean=True)
        onx = seq_clean.to_onnx("Y", "X1", "X2", dtype=numpy.float32)
        code = export2numpy(onx, name="einsum", rename=True)
        print(code)

    In some cases, it is faster to permute a matrix before doing
    a matrix multiplication. Many equivalent equations can be obtained
    by permuting the letters of the initial equation.
    They all lead to the same result but, once decomposed, they perform
    different transpositions. The following code is obtained by
    looking for the best permutation and converting the optimized
    ONNX graph into *numpy*.

    .. runpython::
        :showcode:
        :process:

        import numpy
        from mlprodict.onnx_tools.onnx_export import export2numpy
        from mlprodict.testing.einsum import optimize_decompose_einsum_equation

        seq_opt = optimize_decompose_einsum_equation(
            "bsnh,btnh->bnts", numpy.float64, strategy='ml', verbose=1,
            runtime="python", optimize=True)

        print("best equation:", seq_opt.equation_)
        code = export2numpy(seq_opt.onnx_, name="einsum_opt", rename=True)
        print(code)

    The optimization was done for :epkg:`onnxruntime`, which does not
    guarantee that the result will be faster than :epkg:`numpy:einsum`.
    Let's check...

    .. runpython::
        :showcode:
        :process:

        import pprint
        import numpy
        from mlprodict.onnx_tools.exports.numpy_helper import (
            argmin_use_numpy_select_last_index,
            make_slice)
        from cpyquickhelper.numbers.speed_measure import measure_time

        def numpy_einsum(X1, X2):
            '''
            Numpy function for ``einsum``.

            * producer: mlprodict
            * version: 0
            * description:
            '''
            # initializers

            B = numpy.array([4], dtype=numpy.int64)
            C = numpy.array([3], dtype=numpy.int64)
            D = numpy.array([0, 1], dtype=numpy.int64)
            E = numpy.array([4], dtype=numpy.int64)
            F = numpy.array([-1], dtype=numpy.int64)
            G = numpy.array([2], dtype=numpy.int64)
            H = numpy.array([3], dtype=numpy.int64)
            I = numpy.array([1], dtype=numpy.int64)
            J = numpy.array([1], dtype=numpy.int64)

            # nodes

            K = X1
            L = numpy.expand_dims(K, axis=tuple(B))
            M = numpy.transpose(L, axes=(0, 2, 1, 4, 3))
            N = X2
            O = numpy.expand_dims(N, axis=tuple(C))
            P = numpy.transpose(O, axes=(0, 2, 3, 1, 4))
            Q = numpy.array(M.shape, dtype=numpy.int64)
            R = numpy.array(P.shape, dtype=numpy.int64)
            S = numpy.take(Q, D, axis=0)
            T = numpy.take(R, D, axis=0)
            U = S.prod(axis=0, keepdims=1)
            V = T.prod(axis=0, keepdims=1)
            W = numpy.take(Q, E, axis=0)
            X = numpy.take(R, E, axis=0)
            Z = numpy.concatenate([U, F, W], 0)
            BA = numpy.concatenate([V, F, X], 0)
            BB = M.reshape(tuple(Z))
            BC = P.reshape(tuple(BA))
            BD = numpy.transpose(BC, axes=(0, 2, 1))
            BE = BB @ BD
            BF = numpy.maximum(S, T)
            BG = numpy.take(Q, G, axis=0)
            BH = numpy.take(R, H, axis=0)
            BI = numpy.concatenate([BF, BG, BH, I], 0)
            BJ = BE.reshape(tuple(BI))
            BK = numpy.transpose(BJ, axes=(0, 4, 1, 3, 2))
            BL = numpy.squeeze(BK, axis=tuple(J))
            BM = BL
            Y = BM

            return Y

        def numpy_einsum_opt(X0, X1):
            '''
            Numpy function for ``einsum``.

            * producer: mlprodict
            * version: 0
            * description:
            '''
            # initializers

            B = numpy.array([2], dtype=numpy.int64)
            C = numpy.array([1], dtype=numpy.int64)
            D = numpy.array([0, 1], dtype=numpy.int64)
            E = numpy.array([4], dtype=numpy.int64)
            F = numpy.array([-1], dtype=numpy.int64)
            G = numpy.array([2], dtype=numpy.int64)
            H = numpy.array([3], dtype=numpy.int64)
            I = numpy.array([1], dtype=numpy.int64)
            J = numpy.array([3], dtype=numpy.int64)

            # nodes

            K = X0
            L = numpy.expand_dims(K, axis=tuple(B))
            M = numpy.transpose(L, axes=(0, 3, 1, 2, 4))
            N = X1
            O = numpy.expand_dims(N, axis=tuple(C))
            P = numpy.transpose(O, axes=(0, 3, 1, 2, 4))
            Q = numpy.array(M.shape, dtype=numpy.int64)
            R = numpy.array(P.shape, dtype=numpy.int64)
            S = numpy.take(Q, D, axis=0)
            T = numpy.take(R, D, axis=0)
            U = S.prod(axis=0, keepdims=1)
            V = T.prod(axis=0, keepdims=1)
            W = numpy.take(Q, E, axis=0)
            X = numpy.take(R, E, axis=0)
            Z = numpy.concatenate([U, F, W], 0)
            BA = numpy.concatenate([V, F, X], 0)
            BB = M.reshape(tuple(Z))
            BC = P.reshape(tuple(BA))
            BD = numpy.transpose(BC, axes=(0, 2, 1))
            BE = BB @ BD
            BF = numpy.maximum(S, T)
            BG = numpy.take(Q, G, axis=0)
            BH = numpy.take(R, H, axis=0)
            BI = numpy.concatenate([BF, BG, BH, I], 0)
            BJ = BE.reshape(tuple(BI))
            BK = numpy.transpose(BJ, axes=(0, 1, 3, 4, 2))
            BL = numpy.squeeze(BK, axis=tuple(J))
            BM = BL
            Y = BM

            return Y

        # Check on a small example that the three implementations agree.
        N = 2
        m1 = numpy.random.randn(N, N, N, N)
        m2 = numpy.random.randn(N, N, N, N)

        print("Discrepancies?")
        print(numpy.einsum("bsnh,btnh->bnts", m1, m2))
        print(numpy_einsum(m1, m2))
        print(numpy_einsum_opt(m1, m2))

        # Measure the three implementations on larger inputs.
        N = 20
        m1 = numpy.random.randn(N, N, N, N)
        m2 = numpy.random.randn(N, N, N, N)

        print('numpy.einsum')
        res = measure_time(
            lambda: numpy.einsum("bsnh,btnh->bnts", m1, m2),
            repeat=10, number=20, div_by_number=True,
            context={'numpy': numpy, 'm1': m1, 'm2': m2})
        pprint.pprint(res)

        print('numpy.einsum decomposed')
        res = measure_time(
            lambda: numpy_einsum(m1, m2),
            repeat=10, number=20, div_by_number=True,
            context={'numpy': numpy, 'm1': m1, 'm2': m2,
                     'numpy_einsum': numpy_einsum})
        pprint.pprint(res)

        print('numpy.einsum decomposed and optimized')
        res = measure_time(
            lambda: numpy_einsum_opt(m1, m2),
            repeat=10, number=20, div_by_number=True,
            context={'numpy': numpy, 'm1': m1, 'm2': m2,
                     'numpy_einsum_opt': numpy_einsum_opt})
        pprint.pprint(res)

    The optimized decomposition is not faster than the first decomposition,
    but the decomposition is faster than the numpy implementation of einsum.
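The generated functions are long because they mirror the ONNX graph node by node. As a rough illustration of the underlying idea, the same equation can be written by hand with two transpositions and one batched matrix product. This is only a sketch: the helper name ``einsum_by_hand`` and the input shapes are assumptions made for the example, not something produced by mlprodict.

```python
import numpy


def einsum_by_hand(x1, x2):
    """Hand-written equivalent of numpy.einsum("bsnh,btnh->bnts", x1, x2).

    Hypothetical helper, written for illustration only.
    """
    # Move the shared batch axes (b, n) to the front of both operands.
    a = numpy.transpose(x1, (0, 2, 1, 3))  # b, n, s, h
    b = numpy.transpose(x2, (0, 2, 1, 3))  # b, n, t, h
    # Contract over h with a batched matrix product:
    # (b, n, t, h) @ (b, n, h, s) -> (b, n, t, s)
    return b @ numpy.transpose(a, (0, 1, 3, 2))


rng = numpy.random.default_rng(0)
x1 = rng.standard_normal((2, 3, 4, 5))  # b=2, s=3, n=4, h=5
x2 = rng.standard_normal((2, 6, 4, 5))  # b=2, t=6, n=4, h=5

expected = numpy.einsum("bsnh,btnh->bnts", x1, x2)
got = einsum_by_hand(x1, x2)
max_abs_diff = float(numpy.abs(expected - got).max())
print(got.shape, max_abs_diff)
```

Using a different size for every axis makes it easier to catch a wrong permutation, since a mistake then shows up as a shape mismatch rather than a silent numerical error.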