Speed up scikit-learn inference with ONNX#

Is it possible to make scikit-learn faster with ONNX? That’s the question this example tries to answer. The scenario is the following:

  • a model is trained

  • it is converted into ONNX for inference

  • it selects a runtime to compute the prediction

The following runtimes are tested:

  • python: python runtime for ONNX

  • onnxruntime1: onnxruntime

  • numpy: the ONNX graph is converted into numpy code

  • numba: the numpy code is accelerated with numba

PCA#

Let’s look at a very simple model, a PCA.
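Before converting anything, it helps to recall what PCA.transform computes: centering followed by a projection onto the principal components. A quick check with scikit-learn (whiten left at its default):

```python
import numpy
from sklearn.decomposition import PCA

rng = numpy.random.RandomState(0)
X = rng.randn(100, 20).astype(numpy.float32)
pca = PCA(n_components=10).fit(X)

# transform is centering followed by a matrix product
manual = (X - pca.mean_) @ pca.components_.T
assert numpy.allclose(manual, pca.transform(X), atol=1e-5)
```

Two array operations is all the ONNX graph has to encode, which makes PCA a good candidate for this comparison.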

import numpy
from pandas import DataFrame
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.decomposition import PCA
from pyquickhelper.pycode.profiling import profile
from mlprodict.sklapi import OnnxSpeedupTransformer
from cpyquickhelper.numbers.speed_measure import measure_time
from tqdm import tqdm

Data and models to test.

data, _ = make_regression(1000, n_features=20)
data = data.astype(numpy.float32)
models = [
    ('sklearn', PCA(n_components=10)),
    ('python', OnnxSpeedupTransformer(
        PCA(n_components=10), runtime='python')),
    ('onnxruntime1', OnnxSpeedupTransformer(
        PCA(n_components=10), runtime='onnxruntime1')),
    ('numpy', OnnxSpeedupTransformer(
        PCA(n_components=10), runtime='numpy')),
    ('numba', OnnxSpeedupTransformer(
        PCA(n_components=10), runtime='numba'))]

Training.

for name, model in tqdm(models):
    model.fit(data)

Out:

  0%|          | 0/5 [00:00<?, ?it/s]
 40%|####      | 2/5 [00:00<00:00, 11.61it/s]
 80%|########  | 4/5 [00:01<00:00,  2.44it/s]
100%|##########| 5/5 [00:06<00:00,  1.86s/it]
100%|##########| 5/5 [00:06<00:00,  1.35s/it]

Profiling of runtime onnxruntime1.

def fct():
    for i in range(1000):
        models[2][1].transform(data)


res = profile(fct, pyinst_format="text")
print(res[1])

Out:

  _     ._   __/__   _ _  _  _ _/_   Recorded: 02:36:57 AM Samples:  545
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.614     CPU time: 2.101
/   _/                      v4.1.1

Program: somewhere/workspace/mlprodict/mlprodict_UT_39_std/_doc/examples/plot_speedup_pca.py

0.614 profile  ../pycode/profiling.py:457
`- 0.614 fct  plot_speedup_pca.py:67
      [42 frames hidden]  plot_speedup_pca, mlprodict, <built-in>
         0.400 run  mlprodict/onnxrt/ops_whole/session.py:97
         `- 0.397 [self]

Profiling of runtime numpy.

def fct():
    for i in range(1000):
        models[3][1].transform(data)


res = profile(fct, pyinst_format="text")
print(res[1])

Out:

  _     ._   __/__   _ _  _  _ _/_   Recorded: 02:36:58 AM Samples:  289
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.320     CPU time: 0.305
/   _/                      v4.1.1

Program: somewhere/workspace/mlprodict/mlprodict_UT_39_std/_doc/examples/plot_speedup_pca.py

0.319 profile  ../pycode/profiling.py:457
`- 0.319 fct  plot_speedup_pca.py:79
      [16 frames hidden]  plot_speedup_pca, mlprodict, sklearn,...
         0.281 numpy_mlprodict_ONNX_PCA  <string>:11
         |- 0.206 [self]
         |- 0.069 array  <built-in>:0

The class OnnxSpeedupTransformer converts the PCA into ONNX and then converts it into a python code using numpy. The code is the following.

print(models[3][1].numpy_code_)

Out:

import numpy
import scipy.special as scipy_special
import scipy.spatial.distance as scipy_distance
from mlprodict.onnx_tools.exports.numpy_helper import (
    argmax_use_numpy_select_last_index,
    argmin_use_numpy_select_last_index,
    array_feature_extrator,
    make_slice)


def numpy_mlprodict_ONNX_PCA(X):
    '''
    Numpy function for ``mlprodict_ONNX_PCA``.

    * producer: skl2onnx
    * version: 0
    * description:
    '''
    # initializers

    list_value = [0.24949690699577332, 0.28755664825439453, -0.2858494818210602, 0.31893104314804077, 0.2583758533000946, -0.11266523599624634, -0.055626265704631805, -0.1609230488538742, 0.1426141858100891, -0.055346425622701645, -0.35189685225486755, 0.06510689854621887, -0.17713449895381927, 0.110974982380867, -0.4618053138256073, 0.10629879683256149, 0.4018704295158386, -0.27363264560699463, 0.03810042887926102, -0.08029346168041229, -0.26192978024482727, 0.3145394027233124, 0.3087042272090912, 0.10290239006280899, 0.2555423080921173, -0.10654161870479584, 0.2804138660430908, 0.13779263198375702, 0.14337405562400818, 0.1666984260082245, -0.34065741300582886, -0.09624733775854111, -0.28282591700553894, 0.009630147367715836, 0.3342042863368988, -0.09198075532913208, 0.09750989824533463, -0.3125186264514923, 0.050983868539333344, 0.38411372900009155, -0.18291771411895752, -0.3588113784790039, -0.0898633748292923, -0.25906652212142944, -0.052607789635658264, 0.09125427901744843, -0.2113122195005417, -0.18517854809761047, 0.16711249947547913, 0.18703332543373108, 0.26028934121131897, 0.08020167797803879, -0.3732190430164337, -0.029572412371635437, -0.12425972521305084, -0.05784647539258003, 0.044847503304481506, 0.33086904883384705, -0.14477390050888062, 0.12189781665802002, -0.28346967697143555, -0.25543421506881714, 0.022770805284380913, 0.047992266714572906, 0.03133411705493927, -0.30805861949920654, 0.16875246167182922, 0.06574207544326782, 0.3677937388420105, -0.07037877291440964, -0.051086392253637314, -0.12809500098228455, -0.0792570412158966, -0.13049374520778656, -0.240582674741745, -0.2366640865802765, -0.32895636558532715, -0.2792789340019226, 0.12147863954305649, -0.2686566114425659, 0.10836990922689438, 0.33807337284088135, -0.270295113325119, -0.20316153764724731, -0.0981939435005188, -0.4210911691188812, 0.1569139063358307, -0.2666860818862915, -0.20914778113365173, 0.2618987262248993, 0.14967288076877594, -0.2256295382976532, -0.31239062547683716, 
-0.23140691220760345, -0.09412252902984619, 0.17502768337726593, 0.36465877294540405, 0.11804287135601044, -0.06709002703428268, -0.0740446075797081,
                  0.0070879459381103516, -0.12549898028373718, -0.3152640163898468, 0.13397471606731415, 0.18128454685211182, 0.2779339849948883, 0.09880752861499786, -0.003065187484025955, -0.06381163001060486, 0.12952569127082825, 0.26423344016075134, -0.026432909071445465, -0.16982239484786987, -0.04714816063642502, 0.10190736502408981, 0.10972216725349426, -0.29880037903785706, -0.25056594610214233, 0.2820923626422882, 0.04168280214071274, 0.2673230767250061, -0.0399121455848217, 0.18881763517856598, 0.2021227329969406, -0.5152174830436707, -0.29796159267425537, 0.07303787767887115, -0.014666000381112099, 0.1800355613231659, 0.24645911157131195, 0.23614808917045593, -0.08076481521129608, -0.0048734527081251144, 0.3981764018535614, -0.13116654753684998, 0.34717440605163574, 0.21401554346084595, -0.13270992040634155, 0.37103548645973206, 0.13514988124370575, 0.23844203352928162, -0.02685135044157505, 0.1361999660730362, -0.41368070244789124, 0.16272768378257751, 0.013388212770223618, 0.4593519866466522, -0.2639171779155731, 0.01295333169400692, -0.3770277500152588, 0.26066136360168457, -0.10737838596105576, 0.310264527797699, -0.057026635855436325, 0.10959860682487488, -0.1438370645046234, 0.04480040818452835, -0.35692939162254333, -0.12829402089118958, 0.2640196681022644, 0.01249966211616993, 0.4942779541015625, 0.09470222890377045, -0.4332524240016937, -0.08146287500858307, 0.3477427363395691, -0.06080116331577301, -0.02506687305867672, 0.4422890245914459, 0.14148502051830292, -0.13801909983158112, 0.22521330416202545, 0.08507519960403442, 0.285713255405426, -0.07882598042488098, 0.24952054023742676, -0.1052757054567337, -0.42015522718429565, -0.3527989685535431, -0.2750120460987091, 0.2388441264629364, -0.3014121949672699, 0.2672480642795563, 0.07443109154701233, 0.15369398891925812, 0.08493483066558838, 0.14471399784088135, -0.12882134318351746, -0.10991410911083221, 0.13637037575244904, -0.11312133073806763, -0.011874906718730927, 0.14303702116012573, 
-0.15590441226959229, -0.21574485301971436, 0.27530863881111145, -0.08601583540439606, 0.0031386837363243103, -0.3246106505393982, 0.4364517629146576]
    B = numpy.array(list_value, dtype=numpy.float32).reshape((20, 10))

    list_value = [-0.01973085105419159, 0.04642193019390106, -0.026720644906163216, -0.034755442291498184, 0.02267843671143055, 0.013573219068348408, -0.006667513865977526, 0.003996053244918585, -0.009159430861473083, -0.01723489724099636,
                  0.019976329058408737, 0.047059305012226105, -0.02513911761343479, 0.029120657593011856, 0.018971232697367668, 0.009760499931871891, -0.0002578531566541642, 0.06603453308343887, -0.010292286053299904, 0.004966238513588905]
    C = numpy.array(list_value, dtype=numpy.float32)

    # nodes

    D = X - C
    variable = D @ B

    return variable
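The generated function boils down to two ONNX nodes, Sub and MatMul. The numba runtime compiles exactly this kind of numpy code; a sketch of the idea, with a fallback when numba is not installed (the function and array names here are illustrative, not the ones mlprodict generates):

```python
import numpy as np
try:
    from numba import njit
except ImportError:  # fall back to plain numpy if numba is absent
    def njit(f):
        return f

@njit
def pca_transform(X, C, B):
    # same two operations as the generated function: center, then project
    return (X - C) @ B

rng = np.random.RandomState(0)
X = rng.randn(4, 20).astype(np.float32)
C = rng.randn(20).astype(np.float32)
B = rng.randn(20, 10).astype(np.float32)
out = pca_transform(X, C, B)
assert out.shape == (4, 10)
```

The first call triggers compilation, which is why the benchmark below runs each model once before timing it.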

Benchmark.

bench = []
for name, model in tqdm(models):
    for size in (1, 10, 100, 1000, 10000, 100000, 200000):
        data, _ = make_regression(size, n_features=20)
        data = data.astype(numpy.float32)

        # We run it a first time (numba compiles
        # the function during the first execution).
        model.transform(data)
        res = measure_time(
            lambda: model.transform(data), div_by_number=True,
            context={'data': data, 'model': model})
        res['name'] = name
        res['size'] = size
        bench.append(res)

df = DataFrame(bench)
piv = df.pivot(index="size", columns="name", values="average")
piv

Out:

  0%|          | 0/5 [00:00<?, ?it/s]
 20%|##        | 1/5 [00:39<02:37, 39.46s/it]
 40%|####      | 2/5 [01:07<01:38, 32.97s/it]
 60%|######    | 3/5 [01:20<00:47, 23.80s/it]
 80%|########  | 4/5 [01:46<00:24, 24.67s/it]
100%|##########| 5/5 [02:21<00:00, 28.46s/it]
100%|##########| 5/5 [02:21<00:00, 28.39s/it]
name       numba     numpy  onnxruntime1    python   sklearn
size
1       0.000020  0.000073      0.000377  0.000178  0.000285
10      0.000023  0.000080      0.000222  0.000140  0.000299
100     0.000035  0.000095      0.000260  0.000146  0.000340
1000    0.000146  0.000205      0.000517  0.000307  0.000529
10000   0.002076  0.001582      0.001833  0.004186  0.003000
100000  0.015375  0.015828      0.008872  0.016737  0.025106
200000  0.036508  0.032395      0.011852  0.033362  0.047196


Graph.

fig, ax = plt.subplots(1, 2, figsize=(10, 4))
piv.plot(title="Speedup PCA with ONNX (lower better)",
         logx=True, logy=True, ax=ax[0])
piv2 = piv.copy()
for c in piv2.columns:
    piv2[c] /= piv['sklearn']
print(piv2)
piv2.plot(title="baseline=scikit-learn (lower better)",
          logx=True, logy=True, ax=ax[1])
plt.show()
(Figure with two panels: "Speedup PCA with ONNX (lower better)" and "baseline=scikit-learn (lower better)")

Out:

name       numba     numpy  onnxruntime1    python  sklearn
size
1       0.068535  0.254596      1.322579  0.622805      1.0
10      0.077582  0.265771      0.741999  0.467786      1.0
100     0.104360  0.278008      0.763107  0.430638      1.0
1000    0.275575  0.387053      0.978690  0.581229      1.0
10000   0.692054  0.527487      0.611021  1.395231      1.0
100000  0.612397  0.630467      0.353381  0.666665      1.0
200000  0.773550  0.686407      0.251120  0.706884      1.0
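A quick way to read these tables is to ask which runtime wins at each batch size; with the pivot table computed above that is a one-liner with idxmin. A sketch rebuilding two rows from the absolute-timing table:

```python
from pandas import DataFrame

# two rows copied from the benchmark table (seconds per call)
piv = DataFrame(
    {"numba": [0.000020, 0.036508],
     "onnxruntime1": [0.000377, 0.011852],
     "sklearn": [0.000285, 0.047196]},
    index=[1, 200000])

best = piv.idxmin(axis=1)
assert best.loc[1] == "numba"              # small batches: numba wins
assert best.loc[200000] == "onnxruntime1"  # large batches: onnxruntime wins
```

On this machine, numba and numpy dominate for small batches where interpreter and session overhead matters most, while onnxruntime1 is the clear winner at scale.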

Total running time of the script: ( 2 minutes 32.279 seconds)

Gallery generated by Sphinx-Gallery