2021-08-11 Decompose einsum into numpy operators

Notebook Einsum decomposition explains what function numpy.einsum does and how it can be decomposed into a series of basic operations, all available in ONNX. That’s the purpose of function decompose_einsum_equation. With function export2numpy, it is possible to convert this ONNX graph back into a series of numpy operations.

<<<

import numpy
from mlprodict.testing.einsum import decompose_einsum_equation
from mlprodict.onnx_tools.onnx_export import export2numpy

# Decompose the equation into basic operations, turn them into an ONNX
# graph (names "Y", "X1", "X2" are passed to to_onnx) and print the numpy
# code generated from that graph.
equation = "bsnh,btnh->bnts"
seq_clean = decompose_einsum_equation(
    equation, strategy='numpy', clean=True)
onx = seq_clean.to_onnx("Y", "X1", "X2", dtype=numpy.float32)
print(export2numpy(onx, name="einsum", rename=True))

>>>

    import numpy
    import scipy.special as scipy_special
    import scipy.spatial.distance as scipy_distance
    from mlprodict.onnx_tools.exports.numpy_helper import (
        argmax_use_numpy_select_last_index,
        argmin_use_numpy_select_last_index,
        array_feature_extrator,
        make_slice)
    
    
    def numpy_einsum(X1, X2):
        '''
        Numpy function for ``einsum``.
    
        * producer: mlprodict
        * version: 0
        * description: 
        '''
        # initializers
    
        D = numpy.array([0, 1], dtype=numpy.int64)
    
        E = numpy.array([4], dtype=numpy.int64)
    
        F = numpy.array([-1], dtype=numpy.int64)
    
        I = numpy.array([1], dtype=numpy.int64)
    
        # nodes
    
        K = X1
        L = numpy.expand_dims(K, axis=4)
        M = numpy.transpose(L, axes=(0, 2, 1, 4, 3))
        N = X2
        O = numpy.expand_dims(N, axis=3)
        P = numpy.transpose(O, axes=(0, 2, 3, 1, 4))
        Q = numpy.array(M.shape, dtype=numpy.int64)
        R = numpy.array(P.shape, dtype=numpy.int64)
        S = numpy.take(Q, D, axis=0)
        T = numpy.take(R, D, axis=0)
        U = S.prod(axis=0, keepdims=1)
        V = T.prod(axis=0, keepdims=1)
        W = numpy.take(Q, E, axis=0)
        X = numpy.take(R, E, axis=0)
        Z = numpy.concatenate([U, F, W], 0)
        BA = numpy.concatenate([V, F, X], 0)
        BB = M.reshape(tuple(Z))
        BC = P.reshape(tuple(BA))
        BD = numpy.transpose(BC, axes=(0, 2, 1))
        BE = BB @ BD
        BF = numpy.maximum(S, T)
        BG = numpy.take(Q, [2], axis=0)
        BH = numpy.take(R, [3], axis=0)
        BI = numpy.concatenate([BF, BG, BH, I], 0)
        BJ = BE.reshape(tuple(BI))
        BK = numpy.transpose(BJ, axes=(0, 4, 1, 3, 2))
        BL = numpy.squeeze(BK, axis=1)
        BM = BL
        Y = BM
    
        return Y

In some cases, it is faster to permute a matrix before doing a matrix multiplication. There exist many equivalent equations obtained by permuting letters inside the initial equation. All lead to the same results but, once decomposed, they do different transpositions. The following code is obtained by looking for the best permutation and converting the optimized ONNX graph into numpy.

<<<

import numpy
from mlprodict.onnx_tools.onnx_export import export2numpy
from mlprodict.testing.einsum import optimize_decompose_einsum_equation

# Look for the best equivalent equation, then print it together with the
# numpy code generated from the corresponding ONNX graph.
search_options = dict(
    strategy='ml', verbose=1, runtime="python", optimize=True)
seq_opt = optimize_decompose_einsum_equation(
    "bsnh,btnh->bnts", numpy.float64, **search_options)

print("best equation:", seq_opt.equation_)
print(export2numpy(seq_opt.onnx_, name="einsum_opt", rename=True))

>>>

    best equation: bhts,bnts->btnh
    import numpy
    import scipy.special as scipy_special
    import scipy.spatial.distance as scipy_distance
    from mlprodict.onnx_tools.exports.numpy_helper import (
        argmax_use_numpy_select_last_index,
        argmin_use_numpy_select_last_index,
        array_feature_extrator,
        make_slice)
    
    
    def numpy_einsum_opt(X0, X1):
        '''
        Numpy function for ``einsum_opt``.
    
        * producer: mlprodict
        * version: 0
        * description: 
        '''
        # initializers
    
        D = numpy.array([0, 1], dtype=numpy.int64)
    
        E = numpy.array([4], dtype=numpy.int64)
    
        F = numpy.array([-1], dtype=numpy.int64)
    
        I = numpy.array([1], dtype=numpy.int64)
    
        # nodes
    
        K = X0
        L = numpy.expand_dims(K, axis=2)
        M = numpy.transpose(L, axes=(0, 3, 1, 2, 4))
        N = X1
        O = numpy.expand_dims(N, axis=1)
        P = numpy.transpose(O, axes=(0, 3, 1, 2, 4))
        Q = numpy.array(M.shape, dtype=numpy.int64)
        R = numpy.array(P.shape, dtype=numpy.int64)
        S = numpy.take(Q, D, axis=0)
        T = numpy.take(R, D, axis=0)
        U = S.prod(axis=0, keepdims=1)
        V = T.prod(axis=0, keepdims=1)
        W = numpy.take(Q, E, axis=0)
        X = numpy.take(R, E, axis=0)
        Z = numpy.concatenate([U, F, W], 0)
        BA = numpy.concatenate([V, F, X], 0)
        BB = M.reshape(tuple(Z))
        BC = P.reshape(tuple(BA))
        BD = numpy.transpose(BC, axes=(0, 2, 1))
        BE = BB @ BD
        BF = numpy.maximum(S, T)
        BG = numpy.take(Q, [2], axis=0)
        BH = numpy.take(R, [3], axis=0)
        BI = numpy.concatenate([BF, BG, BH, I], 0)
        BJ = BE.reshape(tuple(BI))
        BK = numpy.transpose(BJ, axes=(0, 1, 3, 4, 2))
        BL = numpy.squeeze(BK, axis=3)
        BM = BL
        Y = BM
    
        return Y
    [runpythonerror]
    0%|          | 0/121 [00:00<?, ?it/s]
4.5 mlbest='bsnh,btnh->bnts':   0%|          | 0/121 [00:00<?, ?it/s]
4.5 mlbest='bsnh,btnh->bnts':   1%|          | 1/121 [00:00<01:22,  1.46it/s]
4.5 mlbest='bnth,bsth->btsn':   1%|          | 1/121 [00:00<01:22,  1.46it/s]
4.5 mlbest='bnth,bsth->btsn':   4%|▍         | 5/121 [00:00<00:14,  7.96it/s]
4.5 mlbest='bnth,bsth->btsn':   7%|▋         | 9/121 [00:00<00:07, 14.01it/s]
4.5 mlbest='bnht,bsht->bhsn':   7%|▋         | 9/121 [00:00<00:07, 14.01it/s]
4.5 mlbest='bnht,bsht->bhsn':  11%|█         | 13/121 [00:01<00:05, 19.34it/s]
4.5 mlbest='bhtn,bstn->btsh':  11%|█         | 13/121 [00:01<00:05, 19.34it/s]
4.5 mlbest='bhtn,bstn->btsh':  14%|█▍        | 17/121 [00:01<00:05, 19.86it/s]
4.5 mlbest='bhts,bnts->btnh':  14%|█▍        | 17/121 [00:01<00:05, 19.86it/s]
4.5 mlbest='bhts,bnts->btnh':  17%|█▋        | 21/121 [00:01<00:04, 23.79it/s]
4.5 mlbest='bhts,bnts->btnh':  21%|██        | 25/121 [00:01<00:03, 27.18it/s]
4.5 mlbest='bhts,bnts->btnh':  24%|██▍       | 29/121 [00:01<00:03, 29.87it/s]
4.5 mlbest='bhts,bnts->btnh':  27%|██▋       | 33/121 [00:01<00:02, 31.38it/s]
4.5 mlbest='bhts,bnts->btnh':  31%|███       | 37/121 [00:01<00:02, 32.98it/s]
4.5 mlbest='bhts,bnts->btnh':  34%|███▍      | 41/121 [00:01<00:02, 34.24it/s]
4.5 mlbest='bhts,bnts->btnh':  37%|███▋      | 45/121 [00:01<00:02, 34.65it/s]
4.5 mlbest='bhts,bnts->btnh':  40%|████      | 49/121 [00:02<00:02, 35.29it/s]
4.5 mlbest='bhts,bnts->btnh':  44%|████▍     | 53/121 [00:02<00:01, 35.95it/s]
4.5 mlbest='bhts,bnts->btnh':  47%|████▋     | 57/121 [00:02<00:01, 36.46it/s]
4.5 mlbest='bhts,bnts->btnh':  50%|█████     | 61/121 [00:02<00:01, 36.02it/s]
4.5 mlbest='bhts,bnts->btnh':  54%|█████▎    | 65/121 [00:02<00:01, 36.35it/s]
4.5 mlbest='bhts,bnts->btnh':  57%|█████▋    | 69/121 [00:02<00:01, 36.62it/s]
4.5 mlbest='bhts,bnts->btnh':  60%|██████    | 73/121 [00:02<00:01, 36.19it/s]
4.5 mlbest='bhts,bnts->btnh':  64%|██████▎   | 77/121 [00:02<00:01, 36.38it/s]
4.5 mlbest='bhts,bnts->btnh':  67%|██████▋   | 81/121 [00:02<00:01, 36.64it/s]
4.5 mlbest='bhts,bnts->btnh':  70%|███████   | 85/121 [00:03<00:00, 36.86it/s]
4.5 mlbest='bhts,bnts->btnh':  74%|███████▎  | 89/121 [00:03<00:00, 36.22it/s]
4.5 mlbest='bhts,bnts->btnh':  77%|███████▋  | 93/121 [00:03<00:00, 36.47it/s]
4.5 mlbest='bhts,bnts->btnh':  80%|████████  | 97/121 [00:03<00:00, 36.68it/s]
4.5 mlbest='bhts,bnts->btnh':  83%|████████▎ | 101/121 [00:03<00:00, 36.20it/s]
4.5 mlbest='bhts,bnts->btnh':  87%|████████▋ | 105/121 [00:03<00:00, 36.40it/s]
4.5 mlbest='bhts,bnts->btnh':  90%|█████████ | 109/121 [00:03<00:00, 36.60it/s]
4.5 mlbest='bhts,bnts->btnh':  93%|█████████▎| 113/121 [00:03<00:00, 36.82it/s]
4.5 mlbest='bhts,bnts->btnh':  97%|█████████▋| 117/121 [00:03<00:00, 36.17it/s]
4.5 mlbest='bhts,bnts->btnh': 100%|██████████| 121/121 [00:04<00:00, 36.42it/s]
4.5 mlbest='bhts,bnts->btnh': 100%|██████████| 121/121 [00:04<00:00, 29.94it/s]

The optimization was done for onnxruntime, which does not guarantee that the result will be faster than numpy.einsum. Let’s check…

<<<

import pprint
import numpy
from mlprodict.onnx_tools.exports.numpy_helper import (
    argmin_use_numpy_select_last_index,
    make_slice)
from cpyquickhelper.numbers.speed_measure import measure_time


def numpy_einsum(X1, X2):
    '''
    Numpy function for ``einsum``.

    Returns the same values as
    ``numpy.einsum("bsnh,btnh->bnts", X1, X2)`` but only relies on
    elementary numpy operations: expand_dims, transpose, reshape,
    a batched matrix product and squeeze.

    * producer: mlprodict
    * version: 0
    * description:

    :param X1: first operand, 4D array
    :param X2: second operand, 4D array
    :return: 4D array
    '''
    def _ints(*values):
        # every initializer is a 1D int64 array
        return numpy.array(values, dtype=numpy.int64)

    # initializers

    left_new_axis = _ints(4)
    right_new_axis = _ints(3)
    batch_axes = _ints(0, 1)
    summed_axis = _ints(4)
    inferred_size = _ints(-1)
    left_kept_axis = _ints(2)
    right_kept_axis = _ints(3)
    unit_size = _ints(1)
    dropped_axis = _ints(1)

    # nodes

    # Each operand receives one extra axis of size 1, then both are
    # permuted into 5D arrays.
    left = numpy.expand_dims(X1, axis=tuple(left_new_axis)).transpose(
        (0, 2, 1, 4, 3))
    right = numpy.expand_dims(X2, axis=tuple(right_new_axis)).transpose(
        (0, 2, 3, 1, 4))
    left_shape = numpy.array(left.shape, dtype=numpy.int64)
    right_shape = numpy.array(right.shape, dtype=numpy.int64)

    # The two leading axes are merged into a single one and the last axis
    # is kept, -1 lets reshape infer the size of the middle axis.
    left_batch = numpy.take(left_shape, batch_axes, axis=0)
    right_batch = numpy.take(right_shape, batch_axes, axis=0)
    left_3d_shape = numpy.concatenate(
        (numpy.prod(left_batch, axis=0, keepdims=True),
         inferred_size,
         numpy.take(left_shape, summed_axis, axis=0)), axis=0)
    right_3d_shape = numpy.concatenate(
        (numpy.prod(right_batch, axis=0, keepdims=True),
         inferred_size,
         numpy.take(right_shape, summed_axis, axis=0)), axis=0)

    # Batched matrix product, it sums over the last axis of both operands.
    left_3d = numpy.reshape(left, tuple(left_3d_shape))
    right_3d = numpy.reshape(right, tuple(right_3d_shape))
    product = numpy.matmul(
        left_3d, numpy.transpose(right_3d, axes=(0, 2, 1)))

    # The merged axes are split again, a trailing axis of size 1 is added,
    # moved to position 1 by the transposition and finally removed.
    final_shape = numpy.concatenate(
        (numpy.maximum(left_batch, right_batch),
         numpy.take(left_shape, left_kept_axis, axis=0),
         numpy.take(right_shape, right_kept_axis, axis=0),
         unit_size), axis=0)
    unfolded = numpy.reshape(product, tuple(final_shape))
    permuted = numpy.transpose(unfolded, axes=(0, 4, 1, 3, 2))
    return numpy.squeeze(permuted, axis=tuple(dropped_axis))


def numpy_einsum_opt(X0, X1):
    '''
    Numpy function for ``einsum_opt``.

    Returns the same values as
    ``numpy.einsum("bsnh,btnh->bnts", X0, X1)``. The sequence of
    elementary numpy operations (expand_dims, transpose, reshape,
    batched matrix product, squeeze) is the decomposition of the
    equivalent equation ``bhts,bnts->btnh``.

    * producer: mlprodict
    * version: 0
    * description:

    :param X0: first operand, 4D array
    :param X1: second operand, 4D array
    :return: 4D array
    '''
    def _ints(*values):
        # every initializer is a 1D int64 array
        return numpy.array(values, dtype=numpy.int64)

    # initializers

    left_new_axis = _ints(2)
    right_new_axis = _ints(1)
    batch_axes = _ints(0, 1)
    summed_axis = _ints(4)
    inferred_size = _ints(-1)
    left_kept_axis = _ints(2)
    right_kept_axis = _ints(3)
    unit_size = _ints(1)
    dropped_axis = _ints(3)

    # nodes

    # Each operand receives one extra axis of size 1, then both are
    # permuted into 5D arrays with the same permutation.
    left = numpy.expand_dims(X0, axis=tuple(left_new_axis)).transpose(
        (0, 3, 1, 2, 4))
    right = numpy.expand_dims(X1, axis=tuple(right_new_axis)).transpose(
        (0, 3, 1, 2, 4))
    left_shape = numpy.array(left.shape, dtype=numpy.int64)
    right_shape = numpy.array(right.shape, dtype=numpy.int64)

    # The two leading axes are merged into a single one and the last axis
    # is kept, -1 lets reshape infer the size of the middle axis.
    left_batch = numpy.take(left_shape, batch_axes, axis=0)
    right_batch = numpy.take(right_shape, batch_axes, axis=0)
    left_3d_shape = numpy.concatenate(
        (numpy.prod(left_batch, axis=0, keepdims=True),
         inferred_size,
         numpy.take(left_shape, summed_axis, axis=0)), axis=0)
    right_3d_shape = numpy.concatenate(
        (numpy.prod(right_batch, axis=0, keepdims=True),
         inferred_size,
         numpy.take(right_shape, summed_axis, axis=0)), axis=0)

    # Batched matrix product, it sums over the last axis of both operands.
    left_3d = numpy.reshape(left, tuple(left_3d_shape))
    right_3d = numpy.reshape(right, tuple(right_3d_shape))
    product = numpy.matmul(
        left_3d, numpy.transpose(right_3d, axes=(0, 2, 1)))

    # The merged axes are split again, a trailing axis of size 1 is added,
    # moved to position 3 by the transposition and finally removed.
    final_shape = numpy.concatenate(
        (numpy.maximum(left_batch, right_batch),
         numpy.take(left_shape, left_kept_axis, axis=0),
         numpy.take(right_shape, right_kept_axis, axis=0),
         unit_size), axis=0)
    unfolded = numpy.reshape(product, tuple(final_shape))
    permuted = numpy.transpose(unfolded, axes=(0, 1, 3, 4, 2))
    return numpy.squeeze(permuted, axis=tuple(dropped_axis))


# Visual check on tiny tensors: the three implementations are printed one
# after the other and should display the same values.
N = 2
m1 = numpy.random.randn(N, N, N, N)
m2 = numpy.random.randn(N, N, N, N)

print("Discrepencies?")
for compute in (
        lambda a, b: numpy.einsum("bsnh,btnh->bnts", a, b),
        numpy_einsum,
        numpy_einsum_opt):
    print(compute(m1, m2))

# Benchmark on bigger tensors, every implementation is measured with the
# same settings.
N = 20
m1 = numpy.random.randn(N, N, N, N)
m2 = numpy.random.randn(N, N, N, N)

common = {'numpy': numpy, 'm1': m1, 'm2': m2}
runs = [
    ('numpy.einsum',
     lambda: numpy.einsum("bsnh,btnh->bnts", m1, m2),
     {}),
    ('numpy.einsum decomposed',
     lambda: numpy_einsum(m1, m2),
     {'numpy_einsum': numpy_einsum}),
    ('numpy.einsum decomposed and optimized',
     lambda: numpy_einsum_opt(m1, m2),
     {'numpy_einsum_opt': numpy_einsum_opt}),
]
for title, fct, extra in runs:
    print(title)
    res = measure_time(
        fct, repeat=10, number=20, div_by_number=True,
        context=dict(common, **extra))
    pprint.pprint(res)

>>>

    Discrepencies?
    [[[[ 0.596 -1.144]
       [ 0.249 -0.75 ]]
    
      [[-0.689  1.266]
       [ 0.329 -0.672]]]
    
    
     [[[-1.94   0.75 ]
       [-0.844  0.467]]
    
      [[-0.783  6.539]
       [ 1.329 -4.247]]]]
    [[[[ 0.596 -1.144]
       [ 0.249 -0.75 ]]
    
      [[-0.689  1.266]
       [ 0.329 -0.672]]]
    
    
     [[[-1.94   0.75 ]
       [-0.844  0.467]]
    
      [[-0.783  6.539]
       [ 1.329 -4.247]]]]
    [[[[ 0.596 -1.144]
       [ 0.249 -0.75 ]]
    
      [[-0.689  1.266]
       [ 0.329 -0.672]]]
    
    
     [[[-1.94   0.75 ]
       [-0.844  0.467]]
    
      [[-0.783  6.539]
       [ 1.329 -4.247]]]]
    numpy.einsum
    {'average': 0.010348163943272084,
     'context_size': 232,
     'deviation': 2.535631389049523e-05,
     'max_exec': 0.010421611747005954,
     'min_exec': 0.010331268195295706,
     'number': 20,
     'repeat': 10,
     'ttime': 0.10348163943272085}
    numpy.einsum decomposed
    {'average': 0.009003338220645671,
     'context_size': 232,
     'deviation': 1.2839754626705566e-05,
     'max_exec': 0.009013596147997304,
     'min_exec': 0.00896662615123205,
     'number': 20,
     'repeat': 10,
     'ttime': 0.09003338220645672}
    numpy.einsum decomposed and optimized
    {'average': 0.009032777614193039,
     'context_size': 232,
     'deviation': 6.272080934281115e-05,
     'max_exec': 0.009220432047732175,
     'min_exec': 0.009001977799925953,
     'number': 20,
     'repeat': 10,
     'ttime': 0.09032777614193038}

The optimization is not faster than the first decomposition but the decomposition is faster than the numpy implementation.