Einsum decomposition


This notebook shows a way to decompose einsum into a subset of operations (expand_dims, squeeze, transpose, extended matrix multiplication).
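Before diving in, here is a minimal sketch (not the library's code) of what such a decomposition means: a simple equation like ij,jk->ik can be rewritten with only expand_dims (numpy.newaxis), an element-wise multiplication and a reduction.

import numpy

# Hedged intuition, not mlprodict's implementation: "ij,jk->ik" expressed
# with expand_dims, element-wise multiplication and a summation.
x = numpy.random.randn(3, 4)
y = numpy.random.randn(4, 5)
r1 = numpy.einsum("ij,jk->ik", x, y)
r2 = (x[:, :, numpy.newaxis] * y[numpy.newaxis, :, :]).sum(axis=1)
assert numpy.allclose(r1, r2)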

from jyquickhelper import add_notebook_menu
add_notebook_menu()
%load_ext mlprodict

Operator explanation with equation bac,cd,def->ebc

The operator einsum takes an equation and some inputs. Every letter involved in the equation corresponds to a loop over one axis. Let’s see how it works on one example.

import numpy

m1 = numpy.arange(0, 8).astype(numpy.float32).reshape((2, 2, 2)) + 10
m2 = numpy.arange(0, 4).astype(numpy.float32).reshape((2, 2)) + 100
m3 = numpy.arange(0, 8).astype(numpy.float32).reshape((2, 2, 2)) + 1000

equation = "bac,cd,def->ebc"
truth = numpy.einsum(equation, m1, m2, m3)
truth
array([[[ 8866198.,  9864696.],
        [12090270., 13152928.]],
       [[ 8883886.,  9884376.],
        [12114390., 13179168.]]], dtype=float32)

This summation is equivalent to:

res = numpy.zeros((2, 2, 2))
for a in range(0, 2):
    for b in range(0, 2):
        for c in range(0, 2):
            for d in range(0, 2):
                for e in range(0, 2):
                    for f in range(0, 2):
                        res[e, b, c] += m1[b, a, c] * m2[c, d] * m3[d, e, f]
res
array([[[ 8866198.,  9864696.],
        [12090270., 13152928.]],
       [[ 8883886.,  9884376.],
        [12114390., 13179168.]]])

Theoretically, this summation has a cost of O(N^6) in this case. However, this naive computation is usually much slower than following a path of matrix multiplications: the heaviest matrix multiplication along that path only costs O(N^4) here. But to do that, the equation needs to be decomposed into a sequence of matrix multiplications.
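To make the gain concrete, here is a hedged sketch (not mlprodict's implementation) reproducing the result above with pairwise reductions instead of six nested loops; it reuses m1, m2, m3 and truth from the cells above, and the heaviest step is a single batched matrix multiplication.

# Hedged sketch: same result as the loops above, using pairwise reductions.
t1 = m1.sum(axis=1)                       # sum over a        -> shape (b, c)
t2 = t1[:, :, numpy.newaxis] * m2         # broadcast along c -> shape (b, c, d)
m3s = m3.sum(axis=2)                      # sum over f        -> shape (d, e)
res2 = t2 @ m3s                           # sum over d        -> shape (b, c, e)
res2 = numpy.transpose(res2, (2, 0, 1))   # reorder axes      -> shape (e, b, c)
assert numpy.allclose(res2, truth)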

Decomposition of bac,cd,def->ebc

import numpy
from mlprodict.testing.einsum import (
    decompose_einsum_equation, apply_einsum_sequence)
m1 = numpy.arange(0, 8).astype(numpy.float32).reshape((2, 2, 2)) + 10
m2 = numpy.arange(0, 4).astype(numpy.float32).reshape((2, 2)) + 100
m3 = numpy.arange(0, 8).astype(numpy.float32).reshape((2, 2, 2)) + 1000
seq = decompose_einsum_equation("bac,cd,def->ebc")
from jyquickhelper import RenderJsDot
RenderJsDot(seq.to_dot(size=7))

Then the result can be obtained as follows:

apply_einsum_sequence(seq, m1, m2, m3)
array([[[ 8866198.,  9864696.],
        [12090270., 13152928.]],
       [[ 8883886.,  9884376.],
        [12114390., 13179168.]]], dtype=float32)

operator matmul

This operator can represent either an element-wise multiplication or a matrix multiplication, but it applies only to arrays with the same number of dimensions. It can be broken down into one of those two operations.

seq_clean = decompose_einsum_equation("bac,cd,def->ebc", strategy='numpy', clean=True)
RenderJsDot(seq_clean.to_dot(size=7))

Operator transpose_mm is a regular transposition; it takes two inputs but only transposes the first one before returning it. Operator batch_dot is a matrix multiplication. It is left that way on purpose as it may be implemented with function dot or gemm. The operator distinguishes between three kinds of axes: batch axes, kept axes, and sum(mation) axes. It then reshapes both inputs into 3D tensors (batch axis, row axis, column axis) so that function numpy.dot can be used.
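As an illustration, here is a hedged sketch of the batch_dot idea on the equation bsnh,btnh->bnts used later in this notebook: b and n are batch axes, s and t are kept axes, h is the summation axis; numpy.matmul stands in for dot or gemm.

import numpy

# Hedged sketch of batch_dot, not the library's implementation.
# "bsnh,btnh->bnts": batch axes b, n; kept axes s, t; summation axis h.
b, s, t, n, h = 2, 3, 4, 5, 6
x = numpy.random.randn(b, s, n, h)
y = numpy.random.randn(b, t, n, h)
x3 = x.transpose(0, 2, 1, 3).reshape(b * n, s, h)  # 3D: (batch, row, sum)
y3 = y.transpose(0, 2, 3, 1).reshape(b * n, h, t)  # 3D: (batch, sum, col)
z3 = x3 @ y3                                       # one batched matrix product
z = z3.reshape(b, n, s, t).transpose(0, 1, 3, 2)   # back to (b, n, t, s)
assert numpy.allclose(z, numpy.einsum("bsnh,btnh->bnts", x, y))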

ONNX

The previous graph can be converted into ONNX.

onx = seq_clean.to_onnx("Y", "X1", "X2", "X3", dtype=numpy.float32)
# with open("einsum.onnx", "wb") as f:
#      f.write(onx.SerializeToString())
%onnxview onx
from onnxruntime import InferenceSession
sess = InferenceSession(onx.SerializeToString())
sess.run(None, {'X1': m1.astype(numpy.float32),
                'X2': m2.astype(numpy.float32),
                'X3': m3.astype(numpy.float32)})[0]
array([[[ 8866198.,  9864696.],
        [12090270., 13152928.]],
       [[ 8883886.,  9884376.],
        [12114390., 13179168.]]], dtype=float32)

onnxruntime

import onnx
from onnx import helper, numpy_helper
from onnxruntime import InferenceSession


def make_model1(equation):
    model = helper.make_model(
        opset_imports=[helper.make_operatorsetid('', 13)],
        graph=helper.make_graph(
            name='einsum_test',
            inputs=[helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, None),
                    helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, None),
                    helper.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, None)],
            outputs=[helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, None)],
            nodes=[
                helper.make_node("Einsum", ["X", "Y", "Z"], ["A"], equation=equation)
            ]
        )
    )
    return model


model = make_model1("bac,cd,def->ebc")
sess = InferenceSession(model.SerializeToString())
sess.run(None, {'X': m1.astype(numpy.float32),
                'Y': m2.astype(numpy.float32),
                'Z': m3.astype(numpy.float32)})[0]
array([[[ 8866198.,  9864696.],
        [12090270., 13152928.]],
       [[ 8883886.,  9884376.],
        [12114390., 13179168.]]], dtype=float32)

Benchmark

The benchmark clearly shows that the summation done with the naive algorithm is the slowest.

from mlprodict.onnxrt.validate.validate_helper import measure_time
from tqdm import tqdm
from pandas import DataFrame


def raw_product(m1, m2, m3):
    N = m1.shape[0]
    res = numpy.zeros((N, N, N))
    for a in range(0, N):
        for b in range(0, N):
            for c in range(0, N):
                for d in range(0, N):
                    for e in range(0, N):
                        for f in range(0, N):
                            res[e, b, c] += m1[b, a, c] * m2[c, d] * m3[d, e, f]
    return res


def benchmark0(equation):
    sess = None
    sess2 = None
    seq = None
    seq2 = None

    results = []
    for N in tqdm([2, 3, 4, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60]):
        m1 = numpy.random.randn(N, N, N)
        m2 = numpy.random.randn(N, N)
        m3 = numpy.random.randn(N, N, N)

        if seq is None:
            seq = decompose_einsum_equation(equation, clean=True)
        if seq2 is None:
            seq2 = decompose_einsum_equation(equation, clean=True, strategy='numpy')
        if sess is None:
            model = make_model1(equation)
            sess = InferenceSession(model.SerializeToString())
        if sess2 is None:
            onx = seq2.to_onnx("Y", "X1", "X2", "X3", dtype=numpy.float32)
            sess2 = InferenceSession(onx.SerializeToString())

        res = measure_time(lambda x: numpy.einsum(equation, *x, optimize=True),
                           [m1, m2, m3],
                           repeat=10, number=10)

        res['name'] = "numpy.einsum"
        res["N"] = N
        results.append(res)

        if N <= 4:
            res = measure_time(lambda x: raw_product(*x),
                               [m1, m2, m3],
                               repeat=10, number=10)
            res['name'] = "raw_product"
            res["N"] = N
            results.append(res)

        res = measure_time(lambda x: apply_einsum_sequence(seq, *x),
                           [m1, m2, m3],
                           repeat=10, number=10)

        res['name'] = "custom_einsum"
        res["N"] = N
        results.append(res)

        res = measure_time(lambda x: apply_einsum_sequence(seq, *x, matmul_impl="pyf"),
                           [m1, m2, m3],
                           repeat=10, number=10)
        res['name'] = "dec-matmul"
        res["N"] = N
        results.append(res)

        res = measure_time(lambda x: apply_einsum_sequence(seq2, *x, matmul_impl="pyf"),
                           [m1, m2, m3],
                           repeat=10, number=10)
        res['name'] = "dec-batch_dot"
        res["N"] = N
        results.append(res)

        res = measure_time(lambda x: sess.run(None, {'X': x[0], 'Y': x[1], 'Z': x[2]}),
                           [m1.astype(numpy.float32), m2.astype(numpy.float32),
                            m3.astype(numpy.float32)],
                           repeat=10, number=10)
        res['name'] = "ort-einsum"
        res["N"] = N
        results.append(res)

        res = measure_time(lambda x: sess2.run(None, {'X1': x[0], 'X2': x[1], 'X3': x[2]}),
                           [m1.astype(numpy.float32), m2.astype(numpy.float32),
                            m3.astype(numpy.float32)],
                           repeat=10, number=10)
        res['name'] = "ort-matmul"
        res["N"] = N
        results.append(res)
    return DataFrame(results)

df = benchmark0("bac,cd,def->ebc")
df.tail()
100%|██████████| 14/14 [00:20<00:00,  1.47s/it]
      average  deviation  min_exec  max_exec  repeat  number     total           name   N
82   0.065132   0.001338  0.063801  0.068927      10      10  0.651318  custom_einsum  60
83   0.051615   0.001206  0.049987  0.053465      10      10  0.516154     dec-matmul  60
84   0.062689   0.003658  0.058949  0.073073      10      10  0.626888  dec-batch_dot  60
85   0.009917   0.000274  0.009737  0.010686      10      10  0.099166     ort-einsum  60
86   0.015518   0.001107  0.014413  0.018179      10      10  0.155178     ort-matmul  60
import matplotlib.pyplot as plt

piv = df.pivot("N", "name", "average")
piv2 = piv.copy()
baseline = piv["numpy.einsum"]
for c in piv2.columns:
    piv2[c] /= baseline

fig, ax = plt.subplots(1, 2, figsize=(14, 6))
piv.plot(logy=True, logx=True, ax=ax[0])
ax[0].set_title("Benchmark einsum function\nbac,cd,def->ebc")
piv2.plot(logy=True, logx=True, ax=ax[1])
ax[1].set_title("Benchmark einsum function\n(ratio, baseline=numpy)");
[figure: benchmark of einsum implementations for bac,cd,def->ebc; absolute timings (left), ratio to numpy.einsum (right)]

Version dec-matmul implements the decomposition of a simplified einsum as a sequence of transpose, reshape, (batch_)dot or mul operations. The same decomposition converted into ONNX and executed with onnxruntime appears as version ort-matmul. Both versions are faster than the optimized numpy version.

Another example with bsnh,btnh->bnts

Another case, more frequent in deep learning.

Decomposition of bsnh,btnh->bnts

seq2 = decompose_einsum_equation("bsnh,btnh->bnts", strategy='numpy', clean=True)
RenderJsDot(seq2.to_dot(size=7))

ONNX version

onx2 = seq2.to_onnx("Y", "X1", "X2", dtype=numpy.float32)
%onnxview onx2

Benchmark

def make_model2(equation):
    model = helper.make_model(
        opset_imports=[helper.make_operatorsetid('', 13)],
        graph=helper.make_graph(
            name='einsum_test',
            inputs=[helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, None),
                    helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, None)],
            outputs=[helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, None)],
            nodes=[
                helper.make_node("Einsum", ["X", "Y"], ["A"], equation=equation)
            ]
        )
    )
    return model


def benchmark(equation, second_input_size=4):
    sess = None
    sess2 = None
    seq = None
    seq2 = None


    results = []
    for N in tqdm([2, 3, 4, 10, 20, 30, 40]):
        m1 = numpy.random.randn(10, N, N, N)
        m2 = numpy.random.randn(10 * N ** (second_input_size-1)).reshape((10, ) + (N, ) * (second_input_size-1))


        if seq is None:
            seq = decompose_einsum_equation(equation, clean=True)
        if seq2 is None:
            seq2 = decompose_einsum_equation(equation, clean=True, strategy='numpy')
        if sess is None:
            model = make_model2(equation)
            sess = InferenceSession(model.SerializeToString())
        if sess2 is None:
            onx = seq2.to_onnx("Y", "X1", "X2", dtype=numpy.float32)
            sess2 = InferenceSession(onx.SerializeToString())

        res = measure_time(lambda x: numpy.einsum(equation, *x, optimize=True),
                           [m1, m2],
                           repeat=10, number=10)

        res['name'] = "numpy.einsum"
        res["N"] = N
        results.append(res)

        res = measure_time(lambda x: apply_einsum_sequence(seq, *x),
                           [m1, m2],
                           repeat=10, number=10)
        res['name'] = "custom_einsum"
        res["N"] = N
        results.append(res)

        res = measure_time(lambda x: apply_einsum_sequence(seq, *x, matmul_impl="pyf"),
                           [m1, m2],
                           repeat=10, number=10)
        res['name'] = "dec-matmul"
        res["N"] = N
        results.append(res)

        res = measure_time(lambda x: apply_einsum_sequence(seq2, *x, matmul_impl="pyf"),
                           [m1, m2],
                           repeat=10, number=10)
        res['name'] = "dec-batch_dot"
        res["N"] = N
        results.append(res)

        res = measure_time(lambda x: sess.run(None, {'X': x[0], 'Y': x[1]}),
                           [m1.astype(numpy.float32), m2.astype(numpy.float32)],
                           repeat=10, number=10)
        res['name'] = "ort-einsum"
        res["N"] = N
        results.append(res)

        res = measure_time(lambda x: sess2.run(None, {'X1': x[0], 'X2': x[1]}),
                           [m1.astype(numpy.float32), m2.astype(numpy.float32)],
                           repeat=10, number=10)
        res['name'] = "ort-matmul"
        res["N"] = N
        results.append(res)
    return DataFrame(results)


df = benchmark("bsnh,btnh->bnts")
df.tail()
100%|██████████| 7/7 [00:13<00:00,  1.93s/it]
     average  deviation  min_exec  max_exec  repeat  number     total           name   N
37  0.229418   0.020792  0.217997  0.291032      10      10  2.294175  custom_einsum  40
38  0.160575   0.005435  0.150772  0.167411      10      10  1.605746     dec-matmul  40
39  0.112844   0.011305  0.102173  0.141890      10      10  1.128436  dec-batch_dot  40
40  0.051181   0.003533  0.047244  0.057054      10      10  0.511815     ort-einsum  40
41  0.078827   0.008735  0.067893  0.099156      10      10  0.788271     ort-matmul  40
piv = df.pivot("N", "name", "average")
piv2 = piv.copy()
baseline = piv["numpy.einsum"]
for c in piv2.columns:
    piv2[c] /= baseline

fig, ax = plt.subplots(1, 2, figsize=(14, 6))
piv.plot(logy=True, logx=True, ax=ax[0])
ax[0].set_title("Benchmark einsum function\nbsnh,btnh->bnts")
piv2.plot(logy=True, logx=True, ax=ax[1])
ax[1].set_title("Benchmark einsum function\n(ratio, baseline=numpy)");
[figure: benchmark of einsum implementations for bsnh,btnh->bnts; absolute timings (left), ratio to numpy.einsum (right)]

Permutation

Einsum’s algorithm starts by aligning all matrices involved in the computation to the same number of dimensions, with axes in the same order. But which order is best? That is the question.
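A quick hedged check that exploring permutations is sound: renaming the letters consistently across inputs and output does not change the computed values, only the order in which the decomposition transposes and multiplies.

import numpy

# Hedged check with the relabeling b->a, s->b, n->c, h->d, t->e:
# the two equations compute exactly the same array.
x = numpy.random.randn(2, 3, 4, 5)
y = numpy.random.randn(2, 6, 4, 5)
r1 = numpy.einsum("bsnh,btnh->bnts", x, y)
r2 = numpy.einsum("abcd,aecd->aceb", x, y)
assert numpy.allclose(r1, r2)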

equation = "bsnh,btnh->bnts"
letters = list(sorted(set(c for c in equation if "a" <= c <= "z")))
letters
['b', 'h', 'n', 's', 't']
from itertools import permutations


def benchmark_perm(equation, number=5, second_input_size=4, repeat=3, N=15):

    def n_operator(seq, name):
        n = 0
        for op in seq:
            if op.name == name:
                n += 1
        return n


    def n_onnx_op(onx, name):
        n = 0
        for op in onx.graph.node:
            if op.op_type == name:
                n += 1
        return n


    def get_kind(seq):
        n = 0
        for op in seq:
            if op.name == 'batch_dot':
                return op.get_dot_kind()
        return None


    m1 = numpy.random.randn(N, N, N, N)
    m2 = numpy.random.randn(N ** second_input_size).reshape((N, ) * second_input_size)

    results = []
    for perm in tqdm(list(permutations(letters))):
        replace = {d: c for c, d in zip(letters, perm)}
        eq = equation
        for k, v in replace.items():
            eq = eq.replace(k, v.upper())
        eq = eq.lower()

        seq = decompose_einsum_equation(eq, clean=True)
        seq2 = decompose_einsum_equation(eq, clean=True, strategy='numpy')
        model = make_model2(eq)
        sess = InferenceSession(model.SerializeToString())
        onx = seq2.to_onnx("Y", "X1", "X2", dtype=numpy.float32)
        sess2 = InferenceSession(onx.SerializeToString())

        n_tra = n_operator(seq2, 'transpose')
        n_tra_onnx = n_onnx_op(onx, 'Transpose')
        n_gemm_onnx = n_onnx_op(onx, 'Gemm')
        kind = get_kind(seq2)

        res = measure_time(lambda x: numpy.einsum(eq, *x, optimize=True),
                           [m1, m2],
                           repeat=repeat, number=number)

        res['name'] = "numpy.einsum"
        res["N"] = N
        res["eq"] = eq
        results.append(res)

        res = measure_time(lambda x: apply_einsum_sequence(seq, *x),
                           [m1, m2],
                           repeat=repeat, number=number)
        res['name'] = "custom_einsum"
        res["N"] = N
        res["eq"] = eq
        res['transpose'] = n_tra
        res['kind'] = kind
        results.append(res)

        res = measure_time(lambda x: apply_einsum_sequence(seq, *x, matmul_impl="pyf"),
                           [m1, m2],
                           repeat=repeat, number=number)
        res['name'] = "dec-matmul"
        res["N"] = N
        res["eq"] = eq
        res['transpose'] = n_tra
        res['kind'] = kind
        results.append(res)

        res = measure_time(lambda x: apply_einsum_sequence(seq2, *x, matmul_impl="pyf"),
                           [m1, m2],
                           repeat=repeat, number=number)
        res['name'] = "dec-batch_dot"
        res["N"] = N
        res["eq"] = eq
        res['transpose'] = n_tra
        res['kind'] = kind
        results.append(res)

        res = measure_time(lambda x: sess.run(None, {'X': x[0], 'Y': x[1]}),
                           [m1.astype(numpy.float32), m2.astype(numpy.float32)],
                           repeat=repeat, number=number)
        res['name'] = "ort-einsum"
        res["N"] = N
        res["eq"] = eq
        res['transpose'] = n_tra_onnx
        res['gemm'] = n_gemm_onnx
        results.append(res)

        res = measure_time(lambda x: sess2.run(None, {'X1': x[0], 'X2': x[1]}),
                           [m1.astype(numpy.float32), m2.astype(numpy.float32)],
                           repeat=repeat, number=number)
        res['name'] = "ort-matmul"
        res["N"] = N
        res["eq"] = eq
        res['transpose'] = n_tra_onnx
        res['gemm'] = n_gemm_onnx
        results.append(res)
    return DataFrame(results)


df = benchmark_perm("bsnh,btnh->bnts", number=4)
df.tail()
100%|██████████| 120/120 [00:11<00:00, 10.23it/s]
      average  deviation  min_exec  max_exec  repeat  number     total           name   N               eq  transpose  kind  gemm
715  0.006162   0.000038  0.006128  0.006216       3       4  0.018485  custom_einsum  15  thns,tbns->tnbh        3.0    NN   NaN
716  0.002343   0.000046  0.002294  0.002405       3       4  0.007029     dec-matmul  15  thns,tbns->tnbh        3.0    NN   NaN
717  0.001645   0.000035  0.001610  0.001694       3       4  0.004934  dec-batch_dot  15  thns,tbns->tnbh        3.0    NN   NaN
718  0.000833   0.000015  0.000820  0.000853       3       4  0.002498     ort-einsum  15  thns,tbns->tnbh        4.0   NaN   0.0
719  0.001251   0.000012  0.001238  0.001268       3       4  0.003753     ort-matmul  15  thns,tbns->tnbh        4.0   NaN   0.0
df = df.sort_values("average").reset_index(drop=True)
df.head()
    average  deviation  min_exec  max_exec  repeat  number     total        name   N               eq  transpose  kind  gemm
0  0.000758   0.000015  0.000738  0.000771       3       4  0.002275  ort-matmul  15  hsnt,hbnt->hnbs        4.0   NaN   0.0
1  0.000770   0.000023  0.000739  0.000793       3       4  0.002310  ort-matmul  15  hnts,hbts->htbn        4.0   NaN   0.0
2  0.000778   0.000020  0.000758  0.000806       3       4  0.002334  ort-matmul  15  bnst,bhst->bshn        4.0   NaN   0.0
3  0.000783   0.000021  0.000760  0.000812       3       4  0.002350  ort-matmul  15  bnht,bsht->bhsn        4.0   NaN   0.0
4  0.000784   0.000011  0.000774  0.000799       3       4  0.002351  ort-matmul  15  hnst,hbst->hsbn        4.0   NaN   0.0
df.tail()
      average  deviation  min_exec  max_exec  repeat  number     total           name   N               eq  transpose  kind  gemm
715  0.011529   0.000882  0.010456  0.012617       3       4  0.034587  custom_einsum  15  sbnt,shnt->snhb        3.0    NN   NaN
716  0.011548   0.000422  0.010967  0.011953       3       4  0.034644  custom_einsum  15  htsb,hnsb->hsnt        3.0    NN   NaN
717  0.013971   0.001984  0.012279  0.016754       3       4  0.041912  custom_einsum  15  nbsh,ntsh->nstb        3.0    NN   NaN
718  0.014765   0.001483  0.013366  0.016818       3       4  0.044295   numpy.einsum  15  bnsh,btsh->bstn        NaN   NaN   NaN
719  0.015813   0.002921  0.012546  0.019636       3       4  0.047438   numpy.einsum  15  nbsh,ntsh->nstb        NaN   NaN   NaN
piv = df.pivot("eq", "name", "average").sort_values("numpy.einsum")

fig, ax = plt.subplots(1, 1, figsize=(14, 6))
piv.plot(logy=True, logx=True, ax=ax)
ax.set_title("Benchmark einsum function - bsnh,btnh->bnts");
[figure: benchmark by permuted equation for bsnh,btnh->bnts, sorted by numpy.einsum average time]
set(df['transpose'].dropna()), set(df['gemm'].dropna()), set(df['kind'].dropna())
({3.0, 4.0}, {0.0}, {'NN'})

Decomposition of bsnh,ctnh->nts

seq3 = decompose_einsum_equation("bsnh,ctnh->nts", strategy='numpy', clean=True)
RenderJsDot(seq3.to_dot(size=7))
onx3 = seq3.to_onnx("Y", "X1", "X2", dtype=numpy.float32)
%onnxview onx3

Benchmark size

df = benchmark("bsnh,ctnh->nts")
df.tail()
100%|██████████| 7/7 [00:39<00:00,  5.71s/it]
     average  deviation  min_exec  max_exec  repeat  number     total           name   N
37  0.043389   0.016879  0.030195  0.077480      10      10  0.433885  custom_einsum  40
38  0.015310   0.000222  0.014909  0.015622      10      10  0.153098     dec-matmul  40
39  0.013508   0.000425  0.013148  0.014576      10      10  0.135085  dec-batch_dot  40
40  0.032725   0.000266  0.032409  0.033212      10      10  0.327254     ort-einsum  40
41  0.057384   0.002703  0.053734  0.062845      10      10  0.573841     ort-matmul  40
piv = df.pivot("N", "name", "average")
piv2 = piv.copy()
baseline = piv["numpy.einsum"]
for c in piv2.columns:
    piv2[c] /= baseline

fig, ax = plt.subplots(1, 2, figsize=(14, 6))
piv.plot(logy=True, logx=True, ax=ax[0])
ax[0].set_title("Benchmark einsum function\nbsnh,ctnh->nts")
piv2.plot(logy=True, logx=True, ax=ax[1])
ax[1].set_title("Benchmark einsum function\n(ratio, baseline=numpy)");
[figure: benchmark of einsum implementations for bsnh,ctnh->nts; absolute timings (left), ratio to numpy.einsum (right)]

Benchmark permutation

df = benchmark_perm("bsnh,ctnh->nts", number=2, repeat=3, N=10)
100%|██████████| 120/120 [00:06<00:00, 17.41it/s]
df = df.sort_values("average").reset_index(drop=True)
df.head()
    average  deviation  min_exec  max_exec  repeat  number     total        name   N              eq  transpose  kind  gemm
0  0.000125   0.000008  0.000118  0.000136       3       2  0.000374  ort-matmul  10  bnst,chst->shn        4.0   NaN   0.0
1  0.000126   0.000007  0.000119  0.000136       3       2  0.000377  ort-matmul  10  bhst,cnst->snh        4.0   NaN   0.0
2  0.000141   0.000006  0.000136  0.000150       3       2  0.000422  ort-matmul  10  hbst,cnst->snb        5.0   NaN   0.0
3  0.000141   0.000007  0.000135  0.000151       3       2  0.000423  ort-matmul  10  nbst,chst->shb        5.0   NaN   0.0
4  0.000144   0.000007  0.000138  0.000154       3       2  0.000432  ort-matmul  10  btns,chns->nht        5.0   NaN   0.0
set(df['transpose'].dropna()), set(df['gemm'].dropna()), set(df['kind'].dropna())
({3.0, 4.0, 5.0, 6.0}, {0.0}, {'NN'})
piv = df.pivot("eq", "name", "average").sort_values("numpy.einsum")

fig, ax = plt.subplots(1, 1, figsize=(14, 6))
piv.plot(logy=True, logx=True, ax=ax)
ax.set_title("Benchmark einsum function");
[figure: benchmark by permuted equation for bsnh,ctnh->nts, sorted by numpy.einsum average time]

Best permutation

One of the best permutations is bnst,chst->shn.

seq4 = decompose_einsum_equation("bnst,chst->shn", strategy='numpy', clean=True)
RenderJsDot(seq4.to_dot(size=7))
onx4 = seq4.to_onnx("Y", "X1", "X2", dtype=numpy.float32)
%onnxview onx4