numpy.digitize as a tree

Function numpy.digitize transforms a real variable into a discrete one by returning the bucket the variable falls into. This bucket can be efficiently retrieved by doing a binary search over the bins. That is equivalent to a decision tree. Function digitize2tree builds that tree.

Simple example

import warnings
import numpy
from pandas import DataFrame, pivot, pivot_table
import matplotlib.pyplot as plt
from onnxruntime import InferenceSession
from sklearn.tree import export_text
from skl2onnx import to_onnx
from cpyquickhelper.numbers.speed_measure import measure_time
from mlinsights.mltree import digitize2tree
from tqdm import tqdm

# Reference result: bucket index of every value, computed by numpy itself.
bins = numpy.array([0.0, 1.0, 2.5, 4.0, 7.0])
x = numpy.array([0.2, 6.4, 3.0, 1.6])
expected = numpy.digitize(x, bins, right=True)

# Same mapping expressed as a tree; the values are passed as a
# single-column 2D array.
tree = digitize2tree(bins, right=True)
pred = tree.predict(x[:, numpy.newaxis])
print(expected, pred)

Out:

[1 4 3 2] [1. 4. 3. 2.]

The tree looks like the following.

# Dump the decision rules of the tree as text, calling the feature "x".
rules = export_text(tree, feature_names=['x'])
print(rules)

Out:

|--- x <= 2.50
|   |--- x <= 1.00
|   |   |--- x <= 0.00
|   |   |   |--- value: [0.00]
|   |   |--- x >  0.00
|   |   |   |--- value: [1.00]
|   |--- x >  1.00
|   |   |--- value: [2.00]
|--- x >  2.50
|   |--- x <= 4.00
|   |   |--- x <= 2.50
|   |   |   |--- value: [2.00]
|   |   |--- x >  2.50
|   |   |   |--- value: [3.00]
|   |--- x >  4.00
|   |   |--- x <= 7.00
|   |   |   |--- x <= 4.00
|   |   |   |   |--- value: [3.00]
|   |   |   |--- x >  4.00
|   |   |   |   |--- value: [4.00]
|   |   |--- x >  7.00
|   |   |   |--- value: [5.00]

Benchmark

Let’s measure the processing time. numpy should be much faster than scikit-learn, as scikit-learn adds many verifications. However, the benchmark also includes a conversion of the tree into ONNX and measures the processing time with onnxruntime.

def _timed(stmt, context, name, n_bins, shape, repeat, number):
    """Time *stmt* with ``measure_time`` and label the resulting measure.

    :param stmt: statement to time, given as a string
    :param context: names made available to *stmt*
    :param name: label of the implementation being measured
    :param n_bins: number of bins used for this measure
    :param shape: number of values in the input
    :param repeat: forwarded to ``measure_time``
    :param number: forwarded to ``measure_time``
    :return: what ``measure_time`` returned, with the keys
        ``name``, ``n_bins`` and ``shape`` added
    """
    ti = measure_time(stmt, context=context, div_by_number=True,
                      repeat=repeat, number=number)
    ti['name'] = name
    ti['n_bins'] = n_bins
    ti['shape'] = shape
    return ti


obs = []

for shape in tqdm([1, 10, 100, 1000, 10000, 100000]):
    x = numpy.random.random(shape).astype(numpy.float32)
    # The tree and the ONNX session both receive the values as a single
    # column; the reshape does not depend on n_bins, so it is done once here.
    x2d = x.reshape((-1, 1))
    # Fewer repetitions for the larger inputs.
    repeat = number = 100 if shape < 1000 else 10

    for n_bins in [1, 10, 100]:
        bins = (numpy.arange(n_bins) / n_bins).astype(numpy.float32)

        # Baseline: numpy.digitize on the flat array.
        obs.append(_timed(
            "numpy.digitize(x, bins, right=True)",
            {'numpy': numpy, "x": x, "bins": bins},
            'numpy', n_bins, shape, repeat, number))

        # Same bucketing done by the tree built from the bins.
        tree = digitize2tree(bins, right=True)
        obs.append(_timed(
            "tree.predict(x)",
            {"x": x2d, "tree": tree},
            'sklearn', n_bins, shape, repeat, number))

        # Conversion of the tree to ONNX; FutureWarning raised during
        # the conversion is silenced.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            onx = to_onnx(tree, x2d, target_opset=15)

        sess = InferenceSession(onx.SerializeToString())
        obs.append(_timed(
            "sess.run(None, {'X': x})",
            {"x": x2d, "sess": sess},
            'ort', n_bins, shape, repeat, number))


# One row per input size, one column per (n_bins, name) pair, filled
# from the "average" field of the collected measures.
df = DataFrame(obs)
piv = df.pivot_table(values=["average"], index="shape",
                     columns=["n_bins", "name"])
print(piv)
Traceback (most recent call last):
  File "somewhere/workspace/mlinsights/mlinsights_UT_39_std/_doc/examples/plot_digitize.py", line 92, in <module>
    sess = InferenceSession(onx.SerializeToString())
  File "/usr/local/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 280, in __init__
    self._create_inference_session(providers, provider_options)
  File "/usr/local/lib/python3.9/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 309, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Failed to load model with error: /onnxruntime_src/onnxruntime/core/graph/model.cc:111 onnxruntime::Model::Model(onnx::ModelProto&&, const PathString&, const IOnnxRuntimeOpSchemaRegistryList*, const onnxruntime::logging::Logger&) Unknown model file format version.

Plotting

# Distinct bin counts found in the results; one subplot is drawn for each.
n_bins = sorted(set(df.n_bins))
# squeeze=False keeps ``ax`` two-dimensional, so the indexing below also
# works when there is only one subplot.
fig, ax = plt.subplots(1, len(n_bins), figsize=(14, 4), squeeze=False)

for i, nb in enumerate(n_bins):
    # Rows are input sizes, columns are the measured implementations.
    piv = pivot(data=df[df.n_bins == nb], index="shape",
                columns="name", values="average")
    piv.plot(title="Benchmark digitize / onnxruntime\nn_bins=%d" % nb,
             logx=True, logy=True, ax=ax[0, i])
plt.show()

Total running time of the script: ( 0 minutes 7.531 seconds)

Gallery generated by Sphinx-Gallery