Compares implementations of Einsum

This example benchmarks several equations with function numpy.einsum. It compares the numpy implementation to a custom implementation, the onnxruntime implementation and the opt-einsum optimisation. If available, tensorflow and pytorch are included as well. The custom implementation does not do any transpose. It uses parallelisation and SIMD optimisation when the summation happens on the last axis of both matrices. It only implements matrix multiplication. We also measure the improvement made with function einsum, which decomposes the equation into simpler operators (curve ``ort_dec``).

Available optimisation

The following code shows which optimisation the custom implementation uses (AVX or SSE) and the number of available processors, which is also the default number of threads used for parallelisation.

import numpy
import pandas
import matplotlib.pyplot as plt
from onnxruntime import InferenceSession
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx.algebra.onnx_ops import OnnxEinsum
from cpyquickhelper.numbers import measure_time
from tqdm import tqdm
from opt_einsum import contract
from mlprodict.testing.experimental_c_impl.experimental_c import (
    custom_einsum_float, code_optimisation)
from mlprodict.testing.einsum.einsum_fct import _einsum
# Report which SIMD flavour (AVX or SSE) the custom implementation was
# compiled with, and its default number of threads.
optim_info = code_optimisation()
print(optim_info)

Out:

AVX-omp=8

Einsum: common code

try:
    from tensorflow import einsum as tf_einsum, convert_to_tensor
except ImportError:
    tf_einsum = None
try:
    from torch import einsum as torch_einsum, from_numpy
except ImportError:
    torch_einsum = None


def build_ort_einsum(equation, op_version=14):  # opset=13, 14, ...
    """Build an onnxruntime callable computing ``einsum(equation, x, y)``.

    The ONNX graph holds a single Einsum node with float inputs ``x`` and
    ``y`` and output ``z``.

    :param equation: einsum equation with two inputs
    :param op_version: ONNX opset used for the node and the model
    :return: function ``(x, y) -> outputs`` returning the list produced
        by ``InferenceSession.run`` (a single element, ``z``)
    """
    node = OnnxEinsum('x', 'y', equation=equation,
                      op_version=op_version,
                      output_names=['z'])
    onx = node.to_onnx(inputs=[('x', FloatTensorType()),
                               ('y', FloatTensorType())],
                       target_opset=op_version)
    # Pin the CPU provider explicitly: onnxruntime >= 1.9 refuses to
    # create a session without ``providers`` when the build exposes
    # several execution providers, and every other implementation in
    # this benchmark runs on CPU.
    sess = InferenceSession(onx.SerializeToString(),
                            providers=['CPUExecutionProvider'])
    return lambda x, y: sess.run(None, {'x': x, 'y': y})


def build_ort_decomposed(equation, op_version=14):  # opset=13, 14, ...
    """Build an onnxruntime callable for the decomposed form of ``equation``.

    ``_einsum`` is called with ``optimize=True`` and ``verbose=True``
    (presumably the source of the ``rtbest=...`` progress output printed
    before each benchmark), then its ONNX graph is loaded into onnxruntime.

    :param equation: einsum equation with two inputs
    :param op_version: ONNX opset passed to ``_einsum``
    :return: function ``(x, y) -> outputs`` returning the list produced
        by ``InferenceSession.run``; the graph inputs are ``X0`` and ``X1``
    """
    cache = _einsum(equation, numpy.float32, opset=op_version,
                    optimize=True, verbose=True, runtime="python")
    # ``onnx_`` is presumably only set once the cached object is built.
    if not hasattr(cache, 'onnx_'):
        cache.build()
    # Pin the CPU provider explicitly: onnxruntime >= 1.9 refuses to
    # create a session without ``providers`` when the build exposes
    # several execution providers.
    sess = InferenceSession(cache.onnx_.SerializeToString(),
                            providers=['CPUExecutionProvider'])
    return lambda x, y: sess.run(None, {'X0': x, 'X1': y})


def loop_einsum_eq(fct, equation, xs, ys):
    """Call ``fct(equation, x, y)`` for every pair of ``zip(xs, ys)``.

    Results are discarded: this function only exists to be timed. Pairs
    are built with :func:`zip`, so extra items of the longer list are
    ignored.
    """
    for pair in zip(xs, ys):
        fct(equation, *pair)


def loop_einsum_eq_th(fct, equation, xs, ys):
    """Call ``fct(equation, x, y, nthread=-1)`` for every pair of ``zip(xs, ys)``.

    Results are discarded: this function only exists to be timed.
    ``nthread=-1`` is forwarded as is (presumably "use the default number
    of threads" for the custom implementation; confirm in its docs).
    """
    options = {'nthread': -1}
    for x_item, y_item in zip(xs, ys):
        fct(equation, x_item, y_item, **options)


def loop_einsum(fct, xs, ys):
    """Call ``fct(x, y)`` for every pair of ``zip(xs, ys)``, discarding results.

    Used to time callables whose equation is already fixed, such as the
    onnxruntime sessions.
    """
    pairs = zip(xs, ys)
    for args in pairs:
        fct(*args)


def custom_einsum_float_tr(eq, x, y):
    """Run ``custom_einsum_float`` after moving the summation axis last.

    For the two equations listed below, both operands are transposed so
    that the computation becomes ``"bsnh,btnh->bnts"``, where the summed
    index ``h`` is the last axis of both inputs. Any other equation is
    forwarded unchanged. ``nthread=-1`` is always passed.

    :param eq: einsum equation
    :param x: first operand (numpy array)
    :param y: second operand (numpy array)
    :return: result of ``custom_einsum_float``
    """
    # equation -> axes permutation turning its operands into bsnh / btnh
    to_last_axis = {
        "bshn,bthn->bnts": (0, 1, 3, 2),
        "bhsn,bhtn->bnts": (0, 2, 3, 1),
    }
    perm = to_last_axis.get(eq)
    if perm is None:
        return custom_einsum_float(eq, x, y, nthread=-1)
    return custom_einsum_float("bsnh,btnh->bnts", x.transpose(perm),
                               y.transpose(perm), nthread=-1)


def _measure_einsum(stmt, ctx, dim, name):
    """Time ``stmt`` evaluated in ``ctx`` and label the measure.

    :param stmt: Python statement executed by ``measure_time``
    :param ctx: namespace the statement is evaluated in
    :param dim: value of N in the input shape ``(2, N, 12, 64)``
    :param name: label stored under key ``fct``
    :return: dictionary returned by ``measure_time`` with keys ``dim``
        and ``fct`` added
    """
    obs = measure_time(stmt, div_by_number=True, context=ctx,
                       repeat=5, number=1)
    obs['dim'] = dim
    obs['fct'] = name
    return obs


def benchmark_equation(equation, dims=None):
    """Benchmark every available einsum implementation on ``equation``.

    Both operands are float32 arrays of shape ``(2, N, 12, 64)``. Five
    random pairs are drawn for every N in ``dims``. tensorflow and torch
    are included only when they could be imported.

    :param equation: einsum equation with two inputs
    :param dims: values of N to benchmark, defaults to
        ``[8, 16, 32, 64, 100, 128, 200, 256, 500, 512]``
    :return: ``(df, rs, ax)``: the raw measures (one row per
        implementation and N), the speedup of every implementation
        relative to numpy.einsum (index ``dim``, one column per
        implementation), and the two matplotlib axes of the figure
    """
    if dims is None:
        dims = [8, 16, 32, 64, 100, 128, 200, 256, 500, 512]

    # onnxruntime sessions are built once per equation, outside timings.
    ort_einsum = build_ort_einsum(equation)
    ort_einsum_decomposed = build_ort_decomposed(equation)

    # (label, implementation, statement timed by measure_time)
    implementations = [
        ('numpy.einsum', numpy.einsum,
         "loop_einsum_eq(einsum, equation, xs, ys)"),
        ('opt-einsum', contract,
         "loop_einsum_eq(einsum, equation, xs, ys)"),
        # the onnxruntime callables already embed the equation
        ('ort_einsum', ort_einsum, "loop_einsum(einsum, xs, ys)"),
        ('ort_dec', ort_einsum_decomposed, "loop_einsum(einsum, xs, ys)"),
        ('c_einsum', custom_einsum_float,
         "loop_einsum_eq_th(einsum, equation, xs, ys)"),
        ('c_einsum_tr', custom_einsum_float_tr,
         "loop_einsum_eq(einsum, equation, xs, ys)"),
    ]

    res = []
    for dim in tqdm(dims):
        xs = [numpy.random.rand(2, dim, 12, 64).astype(numpy.float32)
              for _ in range(5)]
        ys = [numpy.random.rand(2, dim, 12, 64).astype(numpy.float32)
              for _ in range(5)]

        ctx = dict(equation=equation, xs=xs, ys=ys,
                   loop_einsum=loop_einsum, loop_einsum_eq=loop_einsum_eq,
                   loop_einsum_eq_th=loop_einsum_eq_th)
        for name, fct, stmt in implementations:
            ctx['einsum'] = fct
            res.append(_measure_einsum(stmt, ctx, dim, name))

        if tf_einsum is not None:
            # tensorflow works on its own tensors
            ctx['einsum'] = tf_einsum
            ctx['xs'] = [convert_to_tensor(x) for x in xs]
            ctx['ys'] = [convert_to_tensor(y) for y in ys]
            res.append(_measure_einsum(
                "loop_einsum_eq(einsum, equation, xs, ys)",
                ctx, dim, 'tf_einsum'))

        if torch_einsum is not None:
            # torch works on its own tensors, built from the numpy arrays
            ctx['einsum'] = torch_einsum
            ctx['xs'] = [from_numpy(x) for x in xs]
            ctx['ys'] = [from_numpy(y) for y in ys]
            res.append(_measure_einsum(
                "loop_einsum_eq(einsum, equation, xs, ys)",
                ctx, dim, 'torch_einsum'))

    # Dataframes. DataFrame.pivot only accepts keyword arguments
    # since pandas 2.0.
    df = pandas.DataFrame(res)
    piv = df.pivot(index='dim', columns='fct', values='average')

    # Speedup = numpy time / implementation time, higher is better.
    rs = piv.copy()
    for col in piv.columns:
        rs[col] = piv['numpy.einsum'] / piv[col]
    rs['numpy.einsum'] = 1.

    # Graphs.
    _, ax = plt.subplots(1, 2, figsize=(14, 5))
    piv.plot(logx=True, logy=True, ax=ax[0],
             title="Einsum benchmark\n%s -- (2, N, 12, 64)"
                   " lower better" % equation)
    ax[0].legend(prop={"size": 9})
    rs.plot(logx=True, logy=True, ax=ax[1],
            title="Einsum Speedup, baseline=numpy\n%s -- (2, N, 12, 64)"
                  " higher better" % equation)
    # dashed guides at speedups 0.5 and 2
    x_range = [min(rs.index), max(rs.index)]
    ax[1].plot(x_range, [0.5, 0.5], 'g--')
    ax[1].plot(x_range, [2., 2.], 'g--')
    ax[1].legend(prop={"size": 9})

    return df, rs, ax

First equation: bsnh,btnh->bnts

The decomposition of this equation without the einsum function gives the following.

dfs = []
equation = "bsnh,btnh->bnts"
df, piv, ax = benchmark_equation(equation)
# DataFrame.pivot only accepts keyword arguments since pandas 2.0.
df.pivot(index="fct", columns="dim", values="average")
dfs.append(df)
Einsum benchmark bsnh,btnh->bnts -- (2, N, 12, 64) lower better, Einsum Speedup, baseline=numpy bsnh,btnh->bnts -- (2, N, 12, 64) higher better

Out:

  0%|          | 0/121 [00:00<?, ?it/s]
0.028 rtbest='bsnh,btnh->bnts':   0%|          | 0/121 [00:00<?, ?it/s]
0.028 rtbest='bsnh,btnh->bnts':   2%|1         | 2/121 [00:00<00:06, 18.69it/s]
0.028 rtbest='bsnh,btnh->bnts':   3%|3         | 4/121 [00:00<00:06, 19.17it/s]
0.028 rtbest='bsnh,btnh->bnts':   5%|4         | 6/121 [00:00<00:05, 19.32it/s]
0.028 rtbest='bsth,bnth->btns':   5%|4         | 6/121 [00:00<00:05, 19.32it/s]
0.028 rtbest='bsth,bnth->btns':   7%|6         | 8/121 [00:00<00:05, 19.34it/s]
0.028 rtbest='bsth,bnth->btns':   8%|8         | 10/121 [00:00<00:05, 19.43it/s]
0.028 rtbest='bsth,bnth->btns':  10%|9         | 12/121 [00:00<00:05, 19.51it/s]
0.028 rtbest='bsht,bnht->bhns':  10%|9         | 12/121 [00:00<00:05, 19.51it/s]
0.028 rtbest='bsht,bnht->bhns':  12%|#1        | 14/121 [00:00<00:05, 19.01it/s]
0.028 rtbest='bsht,bnht->bhns':  13%|#3        | 16/121 [00:00<00:05, 19.17it/s]
0.028 rtbest='bhnt,bsnt->bnsh':  13%|#3        | 16/121 [00:00<00:05, 19.17it/s]
0.028 rtbest='bhnt,bsnt->bnsh':  15%|#4        | 18/121 [00:00<00:05, 19.23it/s]
0.028 rtbest='bhnt,bsnt->bnsh':  17%|#6        | 20/121 [00:01<00:05, 19.34it/s]
0.028 rtbest='bhnt,bsnt->bnsh':  18%|#8        | 22/121 [00:01<00:05, 19.39it/s]
0.028 rtbest='bhnt,bsnt->bnsh':  20%|#9        | 24/121 [00:01<00:04, 19.42it/s]
0.028 rtbest='bnst,bhst->bshn':  20%|#9        | 24/121 [00:01<00:04, 19.42it/s]
0.028 rtbest='bnst,bhst->bshn':  21%|##1       | 26/121 [00:01<00:05, 18.97it/s]
0.028 rtbest='bnst,bhst->bshn':  23%|##3       | 28/121 [00:01<00:04, 19.11it/s]
0.028 rtbest='bnst,bhst->bshn':  25%|##4       | 30/121 [00:01<00:04, 19.23it/s]
0.028 rtbest='nshb,nthb->nhts':  25%|##4       | 30/121 [00:01<00:04, 19.23it/s]
0.028 rtbest='nshb,nthb->nhts':  26%|##6       | 32/121 [00:01<00:04, 19.28it/s]
0.028 rtbest='snhb,sthb->shtn':  26%|##6       | 32/121 [00:01<00:04, 19.28it/s]
0.028 rtbest='snhb,sthb->shtn':  28%|##8       | 34/121 [00:01<00:04, 19.36it/s]
0.028 rtbest='snhb,sthb->shtn':  30%|##9       | 36/121 [00:01<00:04, 19.46it/s]
0.028 rtbest='snhb,sthb->shtn':  31%|###1      | 38/121 [00:01<00:04, 19.52it/s]
0.028 rtbest='snhb,sthb->shtn':  33%|###3      | 40/121 [00:02<00:04, 19.04it/s]
0.028 rtbest='snhb,sthb->shtn':  35%|###4      | 42/121 [00:02<00:04, 19.18it/s]
0.028 rtbest='snhb,sthb->shtn':  36%|###6      | 44/121 [00:02<00:03, 19.26it/s]
0.028 rtbest='snhb,sthb->shtn':  38%|###8      | 46/121 [00:02<00:03, 19.35it/s]
0.028 rtbest='snhb,sthb->shtn':  40%|###9      | 48/121 [00:02<00:03, 19.43it/s]
0.028 rtbest='snhb,sthb->shtn':  41%|####1     | 50/121 [00:02<00:03, 19.52it/s]
0.028 rtbest='snhb,sthb->shtn':  43%|####2     | 52/121 [00:02<00:03, 19.03it/s]
0.028 rtbest='snhb,sthb->shtn':  45%|####4     | 54/121 [00:02<00:03, 19.21it/s]
0.028 rtbest='snhb,sthb->shtn':  46%|####6     | 56/121 [00:02<00:03, 19.35it/s]
0.028 rtbest='snhb,sthb->shtn':  48%|####7     | 58/121 [00:03<00:03, 19.42it/s]
0.028 rtbest='snhb,sthb->shtn':  50%|####9     | 60/121 [00:03<00:03, 19.47it/s]
0.028 rtbest='snhb,sthb->shtn':  51%|#####1    | 62/121 [00:03<00:03, 19.54it/s]
0.028 rtbest='snhb,sthb->shtn':  53%|#####2    | 64/121 [00:03<00:02, 19.60it/s]
0.028 rtbest='snhb,sthb->shtn':  55%|#####4    | 66/121 [00:03<00:02, 19.04it/s]
0.028 rtbest='snhb,sthb->shtn':  56%|#####6    | 68/121 [00:03<00:02, 19.18it/s]
0.028 rtbest='snhb,sthb->shtn':  58%|#####7    | 70/121 [00:03<00:02, 19.27it/s]
0.028 rtbest='snhb,sthb->shtn':  60%|#####9    | 72/121 [00:03<00:02, 19.36it/s]
0.028 rtbest='snhb,sthb->shtn':  61%|######1   | 74/121 [00:03<00:02, 19.41it/s]
0.028 rtbest='snhb,sthb->shtn':  63%|######2   | 76/121 [00:03<00:02, 19.45it/s]
0.028 rtbest='snhb,sthb->shtn':  64%|######4   | 78/121 [00:04<00:02, 18.96it/s]
0.028 rtbest='snhb,sthb->shtn':  66%|######6   | 80/121 [00:04<00:02, 19.10it/s]
0.028 rtbest='snhb,sthb->shtn':  68%|######7   | 82/121 [00:04<00:02, 19.20it/s]
0.028 rtbest='snhb,sthb->shtn':  69%|######9   | 84/121 [00:04<00:01, 19.31it/s]
0.028 rtbest='snhb,sthb->shtn':  71%|#######1  | 86/121 [00:04<00:01, 19.40it/s]
0.028 rtbest='snhb,sthb->shtn':  73%|#######2  | 88/121 [00:04<00:01, 19.49it/s]
0.028 rtbest='snhb,sthb->shtn':  74%|#######4  | 90/121 [00:04<00:01, 19.00it/s]
0.028 rtbest='snhb,sthb->shtn':  76%|#######6  | 92/121 [00:04<00:01, 19.13it/s]
0.028 rtbest='snhb,sthb->shtn':  78%|#######7  | 94/121 [00:04<00:01, 19.22it/s]
0.028 rtbest='snhb,sthb->shtn':  79%|#######9  | 96/121 [00:04<00:01, 19.32it/s]
0.028 rtbest='snhb,sthb->shtn':  81%|########  | 98/121 [00:05<00:01, 19.38it/s]
0.028 rtbest='snhb,sthb->shtn':  83%|########2 | 100/121 [00:05<00:01, 19.40it/s]
0.028 rtbest='snhb,sthb->shtn':  84%|########4 | 102/121 [00:05<00:00, 19.44it/s]
0.028 rtbest='snhb,sthb->shtn':  86%|########5 | 104/121 [00:05<00:00, 18.96it/s]
0.028 rtbest='snhb,sthb->shtn':  88%|########7 | 106/121 [00:05<00:00, 19.11it/s]
0.028 rtbest='snhb,sthb->shtn':  89%|########9 | 108/121 [00:05<00:00, 19.22it/s]
0.028 rtbest='tnsh,tbsh->tsbn':  89%|########9 | 108/121 [00:05<00:00, 19.22it/s]
0.028 rtbest='tnsh,tbsh->tsbn':  91%|######### | 110/121 [00:05<00:00, 19.31it/s]
0.028 rtbest='tnsh,tbsh->tsbn':  93%|#########2| 112/121 [00:05<00:00, 19.36it/s]
0.028 rtbest='tnsh,tbsh->tsbn':  94%|#########4| 114/121 [00:05<00:00, 19.43it/s]
0.028 rtbest='tnsh,tbsh->tsbn':  96%|#########5| 116/121 [00:06<00:00, 18.97it/s]
0.028 rtbest='tnsh,tbsh->tsbn':  98%|#########7| 118/121 [00:06<00:00, 19.10it/s]
0.028 rtbest='tnsh,tbsh->tsbn':  99%|#########9| 120/121 [00:06<00:00, 19.19it/s]
0.028 rtbest='tnsh,tbsh->tsbn': 100%|##########| 121/121 [00:06<00:00, 19.27it/s]

  0%|          | 0/10 [00:00<?, ?it/s]
 10%|#         | 1/10 [00:00<00:01,  5.48it/s]
 20%|##        | 2/10 [00:00<00:01,  6.11it/s]
 30%|###       | 3/10 [00:00<00:01,  4.24it/s]
 40%|####      | 4/10 [00:01<00:02,  2.30it/s]
 50%|#####     | 5/10 [00:02<00:04,  1.18it/s]
 60%|######    | 6/10 [00:05<00:05,  1.39s/it]
 70%|#######   | 7/10 [00:11<00:08,  2.81s/it]
 80%|########  | 8/10 [00:20<00:09,  4.90s/it]
 90%|######### | 9/10 [00:59<00:15, 15.42s/it]
100%|##########| 10/10 [01:43<00:00, 24.23s/it]
100%|##########| 10/10 [01:43<00:00, 10.30s/it]

Second equation: bshn,bthn->bnts

The summation does not happen on the last axis but on the previous one. Is it worth transposing before doing the summation? The decomposition of this equation without the einsum function gives the following.

equation = "bshn,bthn->bnts"
df, piv, ax = benchmark_equation(equation)
# DataFrame.pivot only accepts keyword arguments since pandas 2.0.
df.pivot(index="fct", columns="dim", values="average")
dfs.append(df)
Einsum benchmark bshn,bthn->bnts -- (2, N, 12, 64) lower better, Einsum Speedup, baseline=numpy bshn,bthn->bnts -- (2, N, 12, 64) higher better

Out:

  0%|          | 0/121 [00:00<?, ?it/s]
0.027 rtbest='bshn,bthn->bnts':   0%|          | 0/121 [00:00<?, ?it/s]
0.027 rtbest='bshn,bthn->bnts':   2%|1         | 2/121 [00:00<00:06, 19.32it/s]
0.027 rtbest='bthn,bshn->bnst':   2%|1         | 2/121 [00:00<00:06, 19.32it/s]
0.027 rtbest='bthn,bshn->bnst':   4%|4         | 5/121 [00:00<00:05, 19.90it/s]
0.027 rtbest='bths,bnhs->bsnt':   4%|4         | 5/121 [00:00<00:05, 19.90it/s]
0.027 rtbest='bths,bnhs->bsnt':   6%|5         | 7/121 [00:00<00:05, 19.92it/s]
0.027 rtbest='btnh,bsnh->bhst':   6%|5         | 7/121 [00:00<00:05, 19.92it/s]
0.027 rtbest='bnsh,btsh->bhtn':   6%|5         | 7/121 [00:00<00:05, 19.92it/s]
0.027 rtbest='bnsh,btsh->bhtn':   8%|8         | 10/121 [00:00<00:05, 20.03it/s]
0.027 rtbest='btsh,bnsh->bhnt':   8%|8         | 10/121 [00:00<00:05, 20.03it/s]
0.027 rtbest='btsh,bnsh->bhnt':  11%|#         | 13/121 [00:00<00:05, 19.90it/s]
0.027 rtbest='btsh,bnsh->bhnt':  13%|#3        | 16/121 [00:00<00:05, 20.03it/s]
0.027 rtbest='btsh,bnsh->bhnt':  16%|#5        | 19/121 [00:00<00:05, 20.13it/s]
0.027 rtbest='btsh,bnsh->bhnt':  18%|#8        | 22/121 [00:01<00:04, 20.21it/s]
0.027 rtbest='btsh,bnsh->bhnt':  21%|##        | 25/121 [00:01<00:04, 20.26it/s]
0.027 rtbest='btsh,bnsh->bhnt':  23%|##3       | 28/121 [00:01<00:04, 20.00it/s]
0.027 rtbest='btsh,bnsh->bhnt':  26%|##5       | 31/121 [00:01<00:04, 20.04it/s]
0.027 rtbest='btsh,bnsh->bhnt':  28%|##8       | 34/121 [00:01<00:04, 19.93it/s]
0.027 rtbest='btsh,bnsh->bhnt':  30%|##9       | 36/121 [00:01<00:04, 19.87it/s]
0.027 rtbest='btsh,bnsh->bhnt':  32%|###2      | 39/121 [00:01<00:04, 19.96it/s]
0.027 rtbest='btsh,bnsh->bhnt':  34%|###3      | 41/121 [00:02<00:04, 19.59it/s]
0.027 rtbest='btsh,bnsh->bhnt':  36%|###5      | 43/121 [00:02<00:03, 19.66it/s]
0.027 rtbest='btsh,bnsh->bhnt':  38%|###8      | 46/121 [00:02<00:03, 19.77it/s]
0.027 rtbest='btsh,bnsh->bhnt':  40%|###9      | 48/121 [00:02<00:03, 19.78it/s]
0.027 rtbest='btsh,bnsh->bhnt':  41%|####1     | 50/121 [00:02<00:03, 19.74it/s]
0.027 rtbest='btsh,bnsh->bhnt':  43%|####2     | 52/121 [00:02<00:03, 19.75it/s]
0.027 rtbest='btsh,bnsh->bhnt':  45%|####5     | 55/121 [00:02<00:03, 19.52it/s]
0.027 rtbest='btsh,bnsh->bhnt':  47%|####7     | 57/121 [00:02<00:03, 19.54it/s]
0.027 rtbest='btsh,bnsh->bhnt':  49%|####8     | 59/121 [00:02<00:03, 19.58it/s]
0.027 rtbest='btsh,bnsh->bhnt':  50%|#####     | 61/121 [00:03<00:03, 19.59it/s]
0.027 rtbest='btsh,bnsh->bhnt':  52%|#####2    | 63/121 [00:03<00:02, 19.59it/s]
0.027 rtbest='btsh,bnsh->bhnt':  54%|#####3    | 65/121 [00:03<00:02, 19.61it/s]
0.027 rtbest='btsh,bnsh->bhnt':  55%|#####5    | 67/121 [00:03<00:02, 19.67it/s]
0.027 rtbest='btsh,bnsh->bhnt':  57%|#####7    | 69/121 [00:03<00:02, 19.29it/s]
0.027 rtbest='btsh,bnsh->bhnt':  59%|#####8    | 71/121 [00:03<00:02, 19.35it/s]
0.027 rtbest='btsh,bnsh->bhnt':  60%|######    | 73/121 [00:03<00:02, 19.43it/s]
0.027 rtbest='btsh,bnsh->bhnt':  63%|######2   | 76/121 [00:03<00:02, 19.72it/s]
0.027 rtbest='btsh,bnsh->bhnt':  65%|######5   | 79/121 [00:03<00:02, 19.88it/s]
0.027 rtbest='btsh,bnsh->bhnt':  68%|######7   | 82/121 [00:04<00:01, 19.65it/s]
0.027 rtbest='btsh,bnsh->bhnt':  69%|######9   | 84/121 [00:04<00:01, 19.67it/s]
0.027 rtbest='btsh,bnsh->bhnt':  71%|#######1  | 86/121 [00:04<00:01, 19.65it/s]
0.027 rtbest='btsh,bnsh->bhnt':  73%|#######2  | 88/121 [00:04<00:01, 19.63it/s]
0.027 rtbest='btsh,bnsh->bhnt':  74%|#######4  | 90/121 [00:04<00:01, 19.64it/s]
0.027 rtbest='btsh,bnsh->bhnt':  76%|#######6  | 92/121 [00:04<00:01, 19.69it/s]
0.027 rtbest='btsh,bnsh->bhnt':  79%|#######8  | 95/121 [00:04<00:01, 19.85it/s]
0.027 rtbest='btsh,bnsh->bhnt':  80%|########  | 97/121 [00:04<00:01, 19.43it/s]
0.027 rtbest='btsh,bnsh->bhnt':  82%|########1 | 99/121 [00:05<00:01, 19.56it/s]
0.027 rtbest='btsh,bnsh->bhnt':  84%|########4 | 102/121 [00:05<00:00, 19.79it/s]
0.027 rtbest='btsh,bnsh->bhnt':  87%|########6 | 105/121 [00:05<00:00, 19.93it/s]
0.027 rtbest='btsh,bnsh->bhnt':  88%|########8 | 107/121 [00:05<00:00, 19.84it/s]
0.027 rtbest='btsh,bnsh->bhnt':  91%|######### | 110/121 [00:05<00:00, 19.56it/s]
0.027 rtbest='btsh,bnsh->bhnt':  93%|#########2| 112/121 [00:05<00:00, 19.53it/s]
0.027 rtbest='btsh,bnsh->bhnt':  94%|#########4| 114/121 [00:05<00:00, 19.54it/s]
0.027 rtbest='btsh,bnsh->bhnt':  96%|#########5| 116/121 [00:05<00:00, 19.62it/s]
0.027 rtbest='btsh,bnsh->bhnt':  98%|#########8| 119/121 [00:06<00:00, 19.78it/s]
0.027 rtbest='btsh,bnsh->bhnt': 100%|##########| 121/121 [00:06<00:00, 19.71it/s]
0.027 rtbest='btsh,bnsh->bhnt': 100%|##########| 121/121 [00:06<00:00, 19.74it/s]

  0%|          | 0/10 [00:00<?, ?it/s]
 20%|##        | 2/10 [00:00<00:00,  8.16it/s]
 30%|###       | 3/10 [00:00<00:01,  4.25it/s]
 40%|####      | 4/10 [00:01<00:03,  1.83it/s]
 50%|#####     | 5/10 [00:04<00:05,  1.19s/it]
 60%|######    | 6/10 [00:07<00:08,  2.03s/it]
 70%|#######   | 7/10 [00:17<00:13,  4.62s/it]
 80%|########  | 8/10 [00:34<00:16,  8.34s/it]
 90%|######### | 9/10 [02:10<00:35, 35.50s/it]
100%|##########| 10/10 [04:05<00:00, 59.90s/it]
100%|##########| 10/10 [04:05<00:00, 24.52s/it]

Third equation: bhsn,bhtn->bnts

The summation does not happen on the last axis but on the second one. It is worth transposing before multiplying. The decomposition of this equation without the einsum function gives the following.

equation = "bhsn,bhtn->bnts"
df, piv, ax = benchmark_equation(equation)
# DataFrame.pivot only accepts keyword arguments since pandas 2.0.
df.pivot(index="fct", columns="dim", values="average")
dfs.append(df)
Einsum benchmark bhsn,bhtn->bnts -- (2, N, 12, 64) lower better, Einsum Speedup, baseline=numpy bhsn,bhtn->bnts -- (2, N, 12, 64) higher better

Out:

  0%|          | 0/121 [00:00<?, ?it/s]
0.028 rtbest='bhsn,bhtn->bnts':   0%|          | 0/121 [00:00<?, ?it/s]
0.028 rtbest='bhsn,bhtn->bnts':   2%|1         | 2/121 [00:00<00:06, 18.91it/s]
0.028 rtbest='bhns,bhts->bstn':   2%|1         | 2/121 [00:00<00:06, 18.91it/s]
0.028 rtbest='bhns,bhts->bstn':   3%|3         | 4/121 [00:00<00:06, 19.29it/s]
0.028 rtbest='bhns,bhts->bstn':   5%|4         | 6/121 [00:00<00:06, 18.44it/s]
0.028 rtbest='bhst,bhnt->btns':   5%|4         | 6/121 [00:00<00:06, 18.44it/s]
0.028 rtbest='bhst,bhnt->btns':   7%|6         | 8/121 [00:00<00:05, 18.92it/s]
0.028 rtbest='bsnh,bsth->bhtn':   7%|6         | 8/121 [00:00<00:05, 18.92it/s]
0.028 rtbest='bsnh,bsth->bhtn':   8%|8         | 10/121 [00:00<00:05, 19.20it/s]
0.028 rtbest='bsnh,bsth->bhtn':  11%|#         | 13/121 [00:00<00:05, 19.30it/s]
0.028 rtbest='bsnh,bsth->bhtn':  12%|#2        | 15/121 [00:00<00:05, 19.41it/s]
0.028 rtbest='bsnh,bsth->bhtn':  14%|#4        | 17/121 [00:00<00:05, 19.49it/s]
0.028 rtbest='bsnh,bsth->bhtn':  16%|#5        | 19/121 [00:00<00:05, 19.54it/s]
0.028 rtbest='bsnh,bsth->bhtn':  17%|#7        | 21/121 [00:01<00:05, 19.60it/s]
0.028 rtbest='bsnh,bsth->bhtn':  19%|#9        | 23/121 [00:01<00:04, 19.64it/s]
0.028 rtbest='bsnh,bsth->bhtn':  21%|##        | 25/121 [00:01<00:04, 19.68it/s]
0.028 rtbest='bsnh,bsth->bhtn':  22%|##2       | 27/121 [00:01<00:04, 19.38it/s]
0.028 rtbest='bsnh,bsth->bhtn':  24%|##3       | 29/121 [00:01<00:04, 19.44it/s]
0.028 rtbest='bsnh,bsth->bhtn':  26%|##5       | 31/121 [00:01<00:04, 19.50it/s]
0.028 rtbest='bsnh,bsth->bhtn':  27%|##7       | 33/121 [00:01<00:04, 19.42it/s]
0.028 rtbest='bsnh,bsth->bhtn':  29%|##8       | 35/121 [00:01<00:04, 19.38it/s]
0.028 rtbest='bsnh,bsth->bhtn':  31%|###       | 37/121 [00:01<00:04, 19.35it/s]
0.028 rtbest='bsnh,bsth->bhtn':  32%|###2      | 39/121 [00:02<00:04, 19.45it/s]
0.028 rtbest='bsnh,bsth->bhtn':  34%|###3      | 41/121 [00:02<00:04, 19.06it/s]
0.028 rtbest='bsnh,bsth->bhtn':  36%|###5      | 43/121 [00:02<00:04, 19.14it/s]
0.028 rtbest='bsnh,bsth->bhtn':  37%|###7      | 45/121 [00:02<00:03, 19.26it/s]
0.028 rtbest='bsnh,bsth->bhtn':  39%|###8      | 47/121 [00:02<00:03, 19.25it/s]
0.028 rtbest='bsnh,bsth->bhtn':  40%|####      | 49/121 [00:02<00:03, 19.31it/s]
0.028 rtbest='bsnh,bsth->bhtn':  42%|####2     | 51/121 [00:02<00:03, 19.29it/s]
0.028 rtbest='bsnh,bsth->bhtn':  44%|####3     | 53/121 [00:02<00:03, 19.35it/s]
0.028 rtbest='bsnh,bsth->bhtn':  45%|####5     | 55/121 [00:02<00:03, 18.99it/s]
0.028 rtbest='bsnh,bsth->bhtn':  47%|####7     | 57/121 [00:02<00:03, 19.02it/s]
0.028 rtbest='bsnh,bsth->bhtn':  49%|####8     | 59/121 [00:03<00:03, 19.06it/s]
0.028 rtbest='bsnh,bsth->bhtn':  50%|#####     | 61/121 [00:03<00:03, 19.10it/s]
0.028 rtbest='bsnh,bsth->bhtn':  52%|#####2    | 63/121 [00:03<00:03, 19.13it/s]
0.028 rtbest='bsnh,bsth->bhtn':  54%|#####3    | 65/121 [00:03<00:02, 19.15it/s]
0.028 rtbest='bsnh,bsth->bhtn':  55%|#####5    | 67/121 [00:03<00:02, 19.19it/s]
0.028 rtbest='bsnh,bsth->bhtn':  57%|#####7    | 69/121 [00:03<00:02, 18.85it/s]
0.028 rtbest='bsnh,bsth->bhtn':  59%|#####8    | 71/121 [00:03<00:02, 18.92it/s]
0.028 rtbest='bsnh,bsth->bhtn':  60%|######    | 73/121 [00:03<00:02, 18.98it/s]
0.028 rtbest='bsnh,bsth->bhtn':  62%|######1   | 75/121 [00:03<00:02, 19.17it/s]
0.028 rtbest='bsnh,bsth->bhtn':  64%|######3   | 77/121 [00:04<00:02, 19.30it/s]
0.028 rtbest='bsnh,bsth->bhtn':  65%|######5   | 79/121 [00:04<00:02, 19.40it/s]
0.028 rtbest='bsnh,bsth->bhtn':  67%|######6   | 81/121 [00:04<00:02, 19.47it/s]
0.028 rtbest='bsnh,bsth->bhtn':  69%|######8   | 83/121 [00:04<00:01, 19.03it/s]
0.028 rtbest='bsnh,bsth->bhtn':  70%|#######   | 85/121 [00:04<00:01, 19.09it/s]
0.028 rtbest='bsnh,bsth->bhtn':  72%|#######1  | 87/121 [00:04<00:01, 19.10it/s]
0.028 rtbest='bsnh,bsth->bhtn':  74%|#######3  | 89/121 [00:04<00:01, 19.15it/s]
0.028 rtbest='bsnh,bsth->bhtn':  75%|#######5  | 91/121 [00:04<00:01, 19.17it/s]
0.028 rtbest='bsnh,bsth->bhtn':  77%|#######6  | 93/121 [00:04<00:01, 19.31it/s]
0.028 rtbest='bsnh,bsth->bhtn':  79%|#######8  | 95/121 [00:04<00:01, 19.34it/s]
0.028 rtbest='bsnh,bsth->bhtn':  80%|########  | 97/121 [00:05<00:01, 18.94it/s]
0.028 rtbest='bsnh,bsth->bhtn':  82%|########1 | 99/121 [00:05<00:01, 19.12it/s]
0.028 rtbest='bsnh,bsth->bhtn':  83%|########3 | 101/121 [00:05<00:01, 19.24it/s]
0.028 rtbest='bsnh,bsth->bhtn':  85%|########5 | 103/121 [00:05<00:00, 19.32it/s]
0.028 rtbest='bsnh,bsth->bhtn':  87%|########6 | 105/121 [00:05<00:00, 19.40it/s]
0.028 rtbest='bsnh,bsth->bhtn':  88%|########8 | 107/121 [00:05<00:00, 19.33it/s]
0.028 rtbest='bsnh,bsth->bhtn':  90%|######### | 109/121 [00:05<00:00, 19.36it/s]
0.028 rtbest='bsnh,bsth->bhtn':  92%|#########1| 111/121 [00:05<00:00, 18.94it/s]
0.028 rtbest='bsnh,bsth->bhtn':  93%|#########3| 113/121 [00:05<00:00, 18.98it/s]
0.028 rtbest='bsnh,bsth->bhtn':  95%|#########5| 115/121 [00:05<00:00, 19.01it/s]
0.028 rtbest='bsnh,bsth->bhtn':  97%|#########6| 117/121 [00:06<00:00, 19.17it/s]
0.028 rtbest='bsnh,bsth->bhtn':  98%|#########8| 119/121 [00:06<00:00, 19.25it/s]
0.028 rtbest='bsnh,bsth->bhtn': 100%|##########| 121/121 [00:06<00:00, 19.23it/s]
0.028 rtbest='bsnh,bsth->bhtn': 100%|##########| 121/121 [00:06<00:00, 19.22it/s]

  0%|          | 0/10 [00:00<?, ?it/s]
 10%|#         | 1/10 [00:00<00:01,  6.49it/s]
 20%|##        | 2/10 [00:00<00:01,  5.27it/s]
 30%|###       | 3/10 [00:00<00:01,  3.90it/s]
 40%|####      | 4/10 [00:01<00:02,  2.54it/s]
 50%|#####     | 5/10 [00:02<00:02,  1.76it/s]
 60%|######    | 6/10 [00:03<00:03,  1.32it/s]
 70%|#######   | 7/10 [00:05<00:03,  1.16s/it]
 80%|########  | 8/10 [00:08<00:03,  1.65s/it]
 90%|######### | 9/10 [00:15<00:03,  3.61s/it]
100%|##########| 10/10 [00:23<00:00,  4.97s/it]
100%|##########| 10/10 [00:23<00:00,  2.39s/it]

Conclusion

PyTorch seems quite efficient on these examples. The custom implementation was a way to investigate how einsum is implemented and to find ways to optimize it.

# Gather the raw measures of every equation and save them.
merged = pandas.concat(dfs)
name = "einsum"
prefix = "plot_%s" % name
merged.to_csv(prefix + ".csv", index=False)
merged.to_excel(prefix + ".xlsx", index=False)
# pyplot.savefig writes the current figure only, i.e. the one created
# by the last call to benchmark_equation.
plt.savefig(prefix + ".png")

plt.show()
plot op einsum

Total running time of the script: ( 6 minutes 37.525 seconds)

Gallery generated by Sphinx-Gallery