Tricky detail when converting a random forest from scikit-learn into ONNX

Links: notebook, html, PDF, python, slides, GitHub

scikit-learn uses a specific comparison when computing the prediction of a decision tree: it evaluates (float)x <= threshold (see tree.pyx / method apply_dense), where the threshold is a double. ONNX does not specify such details and compares x to the threshold, both having the same type. What should the converter do then?

# Notebook setup: add_notebook_menu presumably renders a navigation menu for
# the notebook (jyquickhelper helper). %matplotlib inline is an IPython
# magic, not plain Python: it only works inside Jupyter/IPython.
from jyquickhelper import add_notebook_menu
add_notebook_menu()
%matplotlib inline

Conversion to float

import numpy
# Sample 2N doubles (1 + i * delta) * factor around `factor`. The relative
# step delta = 1e-9 is much smaller than the float32 relative precision
# (~1.2e-7), so many consecutive doubles map to the same float32.
# `factor` is reused by the next cells.
N = 1000
delta = 1e-9
factor = 10
dxs = numpy.empty((2 * N,), dtype=numpy.float64)
fxs = numpy.empty((2 * N,), dtype=numpy.float32)
for i, x in enumerate(range(-N, N)):
    dx = (1. + x * delta) * factor
    dxs[i] = dx
    # Storing into a float32 array performs the double -> float conversion.
    fxs[i] = dx
import matplotlib.pyplot as plt
# Plot every double against its float32 conversion.
fig, ax = plt.subplots(1, 1, figsize=(12, 4))
ax.plot(dxs, fxs)
ax.set_title("conversion from double to float")
ax.set_xlabel("double")
ax.set_ylabel("float");
../_images/onnx_float_double_skl_decision_trees_5_0.png

Region where (float)x <= y

Let’s see what the comparison (float)x <= y looks like.

# Grid of (x, y) pairs around `factor` (10, set by the previous cell).
# `delta` is redefined here and reused by the following cells.
N = 100
delta = 36e-10
xs = []
ys = []
for x in range(-N, N):
    for y in range(-N, N):
        dx = (1. + x * delta) * factor
        dy = (1. + y * delta) * factor
        # scikit-learn's comparison: x is cast to float32, y stays a double.
        if numpy.float32(dx) <= numpy.float64(dy):
            xs.append(dx)
            ys.append(dy)
fig, ax = plt.subplots(1, 1, figsize=(12, 4))
ax.plot(xs, ys, ".")
ax.set_title("Region where (float)x <= y")
# Reference line from the smallest to the largest kept point.
ax.plot([min(xs), max(xs)], [min(ys), max(ys)], 'k--')
ax.set_xlabel("x")
ax.set_ylabel("y");
../_images/onnx_float_double_skl_decision_trees_9_0.png

Is it equivalent to (float)x <= (float)y?

def area_mismatch_rule(N, delta, factor, rule=None):
    """
    Compare ``(float)x <= y`` with ``(float)x <= rule(y)`` on a grid.

    Both coordinates take the ``2N`` values ``(1 + i * delta) * factor``
    for ``i`` in ``range(-N, N)``. The first comparison is scikit-learn's
    one (``y`` stays a double); the second one converts the threshold with
    *rule* first.

    :param N: half the number of grid values on each axis
    :param delta: relative step between two consecutive grid values
    :param factor: scale applied to every grid value
    :param rule: function converting the threshold ``y``, defaults to a
        cast to ``numpy.float32``
    :return: ``(xst, yst, xsf, ysf)``: coordinates of the points where both
        comparisons agree (``*t``) and where they disagree (``*f``)
    """
    convert = (lambda t: numpy.float32(t)) if rule is None else rule
    values = [(1. + i * delta) * factor for i in range(-N, N)]
    agree_x, agree_y, disagree_x, disagree_y = [], [], [], []
    for vx in values:
        fx = numpy.float32(vx)
        for vy in values:
            reference = bool(fx <= numpy.float64(vy))
            candidate = bool(fx <= convert(vy))
            if reference == candidate:
                agree_x.append(vx)
                agree_y.append(vy)
            else:
                disagree_x.append(vx)
                disagree_y.append(vy)
    return agree_x, agree_y, disagree_x, disagree_y

# factor = 1: share (in %) of the grid where (float)x <= y and
# (float)x <= (float)y disagree (default rule casts y to float32).
xst1, yst1, xsf1, ysf1 = area_mismatch_rule(100, delta, 1.)
"factor=%1.1f, error area %1.4f%s" % (1., len(xsf1) * 1.0 / (len(xst1) + len(xsf1)) * 100, "%")
'factor=1.0, error area 5.7525%'

Applied to a decision tree, this does not mean that the condition of each node would be wrongly evaluated in 5.75% of the cases: the actual error rate depends on how the thresholds are built, and the error area itself depends on the magnitude of the numbers (see factor=10 below).

# Same measure with factor = 10: the error area changes with the magnitude.
factor = 10
xst, yst, xsf, ysf = area_mismatch_rule(100, delta, factor)
"factor=%1.1f, error area %1.4f%s" % (factor, len(xsf) * 1.0 / (len(xst) + len(xsf)) * 100, "%")
'factor=10.0, error area 6.2025%'
# Left: factor 1 (xst1... from the first run), right: factor 10 (xst...).
# Each panel shows agreeing and disagreeing grid points plus a reference
# line from the smallest to the largest agreeing point.
fig, ax = plt.subplots(1, 2, figsize=(14, 4))
ax[0].plot(xst1, yst1, '.', label="agree")
ax[0].plot(xsf1, ysf1, '.', label="disagree")
ax[0].set_title("Factor=1, Region where\n(float)x <= y and (float)x <= (float)y agree")
ax[0].set_xlabel("x")
ax[0].set_ylabel("y")
ax[0].plot([min(xst1), max(xst1)], [min(yst1), max(yst1)], 'k--')
ax[0].legend()
ax[1].plot(xst, yst, '.', label="agree")
ax[1].plot(xsf, ysf, '.', label="disagree")
ax[1].set_title("Factor=%f, Region where\n(float)x <= y and (float)x <= (float)y agree" % factor)
ax[1].set_xlabel("x")
ax[1].set_ylabel("y")
ax[1].plot([min(xst), max(xst)], [min(yst), max(yst)], 'k--')
ax[1].legend();
../_images/onnx_float_double_skl_decision_trees_14_0.png

Good threshold

def good_threshold(dy):
    """
    Return the float32 threshold equivalent to scikit-learn's comparison.

    scikit-learn evaluates ``(float)x <= dy`` where ``dy`` is a double.
    For a float32 ``x``, this is equivalent to ``x <= t``, where ``t`` is
    the largest float32 lower than or equal to ``dy``. Using ``t`` lets the
    whole comparison be done in float32.

    :param dy: threshold (converted to a double)
    :return: ``numpy.float32``, the largest float32 ``<= dy``
        (NaN if *dy* is NaN)
    """
    # Make dy a strongly typed double. With NumPy >= 2 (NEP 50), a Python
    # float compared to a float32 scalar is cast to float32 first, so
    # ``fy <= dy`` would always be true and the correction would be lost.
    dy = numpy.float64(dy)
    fy = numpy.float32(dy)
    if fy <= dy:
        # Exact conversion or rounding down: fy is already the answer.
        return fy
    # Rounding went up (NaN also lands here and stays NaN): step back to
    # the previous float32. Stepping towards -inf also works when fy is
    # infinite, whereas the former target ``fy - eps`` was
    # ``inf - inf = nan`` in that case.
    return numpy.nextafter(fy, numpy.float32(-numpy.inf))

# good_threshold on an exact float32 (1.), a value rounded down to 1.0f
# (1 + 1e-8) and a value rounded up to 1.0f (1 - 1e-8).
good_threshold(1.), good_threshold(1 + 1e-8), good_threshold(1 - 1e-8)
(1.0, 1.0, 0.99999994)
# Mismatch area with rule=good_threshold, for factor 1, 1e20 and 10.
xst1, yst1, xsf1, ysf1 = area_mismatch_rule(100, delta, 1., good_threshold)
"factor=%1.1f, error area %1.4f%s" % (1., len(xsf1) * 1.0 / (len(xst1) + len(xsf1)) * 100, "%")
'factor=1.0, error area 0.0000%'
xst, yst, xsf, ysf = area_mismatch_rule(100, delta, 1e20, good_threshold)
"factor=%1.1g, error area %1.4f%s" % (1e20, len(xsf) * 1.0 / (len(xst) + len(xsf)) * 100, "%")
'factor=1e+20, error area 0.0000%'
factor = 10
xst, yst, xsf, ysf = area_mismatch_rule(100, delta, factor, good_threshold)
"factor=%1.1f, error area %1.4f%s" % (factor, len(xsf) * 1.0 / (len(xst) + len(xsf)) * 100, "%")
'factor=10.0, error area 0.0000%'
# Left: factor 1, right: factor 10 (the last area_mismatch_rule run), both
# computed with rule=good_threshold.
fig, ax = plt.subplots(1, 2, figsize=(14, 4))
ax[0].plot(xst1, yst1, '.', label="agree")
ax[0].plot(xsf1, ysf1, '.', label="disagree")
ax[0].set_title("Factor=1, Region where\n(float)x <= y and (float)x <= good_threshold(y) agree")
ax[0].set_xlabel("x")
ax[0].set_ylabel("y")
ax[0].plot([min(xst1), max(xst1)], [min(yst1), max(yst1)], 'k--')
ax[0].legend()
ax[1].plot(xst, yst, '.', label="agree")
ax[1].plot(xsf, ysf, '.', label="disagree")
ax[1].set_title("Factor=%f, Region where\n(float)x <= y and (float)x <= good_threshold(y) agree" % factor)
ax[1].set_xlabel("x")
ax[1].set_ylabel("y")
ax[1].plot([min(xst), max(xst)], [min(yst), max(yst)], 'k--')
ax[1].legend();
../_images/onnx_float_double_skl_decision_trees_20_0.png

Let’s draw the function:

# Plot the plain float32 cast and good_threshold for 2N doubles around 1.0
# (relative step 1e-9), to compare both conversions.
N = 1000
dxs2 = numpy.empty((2 * N,), dtype=numpy.float64)
fxs1 = numpy.empty((2 * N,), dtype=numpy.float32)
fxs2 = numpy.empty((2 * N,), dtype=numpy.float32)
for i, x in enumerate(range(-N, N)):
    dx = 1. + x * 1e-9
    dxs2[i] = dx
    fxs1[i] = numpy.float32(dx)
    fxs2[i] = good_threshold(dx)
fig, ax = plt.subplots(1, 1, figsize=(12, 4))
ax.plot(dxs2, fxs1, label="(float)")
ax.plot(dxs2, fxs2, label="good_threshold")
ax.set_title("Function good_threshold")
ax.set_xlabel("double")
ax.set_ylabel("float")
ax.legend();
../_images/onnx_float_double_skl_decision_trees_23_0.png

That explains some tricky lines in the skl2onnx package. Let’s check whether it still works with negative values.

# Same measure as area_mismatch_rule with rule=good_threshold, written
# inline for a grid around -1 (values -1 + i * delta, no factor).
N = 100
xst = []
yst = []
xsf = []
ysf = []
for x in range(-N, N):
    for y in range(-N, N):
        dx = -1. + x * delta
        dy = -1. + y * delta
        # c1: scikit-learn's comparison, c2: float32 comparison with the
        # corrected threshold.
        c1 = 1 if numpy.float32(dx) <= numpy.float64(dy) else 0
        c2 = 1 if numpy.float32(dx) <= good_threshold(dy) else 0
        key = abs(c1 - c2)
        if key == 1:
            xsf.append(dx)
            ysf.append(dy)
        else:
            xst.append(dx)
            yst.append(dy)

# Share (in %) of the grid where both comparisons disagree.
"error area %1.4f%s" % (len(xsf) * 1.0 / (len(xst) + len(xsf)) * 100, "%")
'error area 0.0000%'

It works.

What about comparing in double precision?

def area_mismatch_rule_double(N, delta, factor, rule=None):
    """
    Compare ``(float)x <= y`` with ``(double)x <= rule(y)`` on a grid.

    Both coordinates take the ``2N`` values ``(1 + i * delta) * factor``
    for ``i`` in ``range(-N, N)``. The first comparison is scikit-learn's
    one; the second keeps ``x`` as a double and converts the threshold with
    *rule*.

    :param N: half the number of grid values on each axis
    :param delta: relative step between two consecutive grid values
    :param factor: scale applied to every grid value
    :param rule: function converting the threshold ``y``, defaults to a
        cast to ``numpy.float64``
    :return: ``(xst, yst, xsf, ysf)``: coordinates of the points where both
        comparisons agree (``*t``) and where they disagree (``*f``)
    """
    convert = numpy.float64 if rule is None else rule
    grid = [(1. + k * delta) * factor for k in range(-N, N)]
    inside_x, inside_y, outside_x, outside_y = [], [], [], []
    for px in grid:
        as_float = numpy.float32(px)
        as_double = numpy.float64(px)
        for py in grid:
            sklearn_side = bool(as_float <= numpy.float64(py))
            rewritten_side = bool(as_double <= convert(py))
            if sklearn_side != rewritten_side:
                target_x, target_y = outside_x, outside_y
            else:
                target_x, target_y = inside_x, inside_y
            target_x.append(px)
            target_y.append(py)
    return inside_x, inside_y, outside_x, outside_y

xst1, yst1, xsf1, ysf1 = area_mismatch_rule_double(100, delta, 1.)
"factor=%1.1f, error area %1.4f%s" % (1., len(xsf1) * 1.0 / (len(xst1) + len(xsf1)) * 100, "%")
'factor=1.0, error area 3.1125%'
xst, yst, xsf, ysf = area_mismatch_rule_double(100, delta, 1e20)
"factor=%1.1g, error area %1.4f%s" % (1e20, len(xsf) * 1.0 / (len(xst) + len(xsf)) * 100, "%")
'factor=1e+20, error area 2.9900%'
xst, yst, xsf, ysf = area_mismatch_rule_double(100, delta, factor)
"factor=%1.1f, error area %1.4f%s" % (factor, len(xsf) * 1.0 / (len(xst) + len(xsf)) * 100, "%")
'factor=10.0, error area 3.1975%'

The mismatch rate is lower than with floats but still significant.

# Left: factor 1, right: factor 10. The green 'rule' curve is y = (float)x
# drawn over the sorted x values of the agreeing points.
fig, ax = plt.subplots(1, 2, figsize=(14, 4))
ax[0].plot(xst1, yst1, '.', label="agree")
ax[0].plot(xsf1, ysf1, '.', label="disagree")
xs = list(sorted(set(xst1)))
ys = [numpy.float32(x) for x in xs]
ax[0].plot(xs, ys, 'g', label='rule')
ax[0].set_title("Factor=1, Region where\n(float)x <= y and (double)x <= (double)y agree")
ax[0].set_xlabel("x")
ax[0].set_ylabel("y")
ax[0].plot([min(xst1), max(xst1)], [min(yst1), max(yst1)], 'k--')
ax[0].legend()
ax[1].plot(xst, yst, '.', label="agree")
ax[1].plot(xsf, ysf, '.', label="disagree")
xs = list(sorted(set(xst)))
ys = [numpy.float32(x) for x in xs]
ax[1].plot(xs, ys, 'g', label='rule')
ax[1].set_title("Factor=%f, Region where\n(float)x <= y and (double)x <= (double)y agree" % factor)
ax[1].set_xlabel("x")
ax[1].set_ylabel("y")
ax[1].plot([min(xst), max(xst)], [min(yst), max(yst)], 'k--')
ax[1].legend();
../_images/onnx_float_double_skl_decision_trees_32_0.png

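Let’s fix it in a similar way. First, define a function which finds the double at the border between two consecutive floats: doubles on one side of it are rounded to one float, doubles on the other side are rounded to the next one. The border cannot simply be taken as the midpoint of the two floats: whether the midpoint itself goes to the lower or the upper float depends on round-half-to-even.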

def find_switch_point(fy, nfy):
    """
    Finds the double so that ``(float)x != (float)(x + epsilon)``.

    Bisects the doubles between *fy* and *nfy* and returns ``x``, the last
    double (going from *fy* towards *nfy*) still rounded to ``(float)fy``.
    ``epsilon`` is one double step towards *nfy*. *nfy* is expected to be
    rounded to another float32, typically the neighbour of ``(float)fy``.

    :param fy: starting value (converted to a double)
    :param nfy: value on the other side of the border (converted to a double)
    :return: ``numpy.float64``
    :raises ValueError: if a midpoint is NaN (NaN bound, or ``-inf`` and
        ``inf`` as bounds); the bisection could never end in that case
    """
    a = numpy.float64(fy)
    b = numpy.float64(nfy)
    fa = numpy.float32(a)
    # Invariant: float32(a) == fa; b is only replaced by doubles rounded to
    # another float32. Stop as soon as a bisection step changes nothing,
    # i.e. a and b are consecutive (or equal) doubles.
    while True:
        m = (a + b) / 2
        if numpy.isnan(m):
            raise ValueError(
                "find_switch_point cannot bisect between %r and %r" % (fy, nfy))
        if numpy.float32(m) == fa:
            if m == a:
                break
            a = m
        elif m == b:
            break
        else:
            b = m
    return a

# Border between 1.0f and the next float32 (1.0000000876 is rounded to it).
find_switch_point(1, 1.0000000876)
1.0000000596046448
def good_threshold_double(dy):
    """
    Return a double threshold equivalent to scikit-learn's comparison.

    scikit-learn evaluates ``(float)x <= dy``. For a double ``x``,
    ``x <= t`` gives the same result when ``t`` is the largest double
    rounded to a float32 lower than or equal to ``dy``. This function
    returns that ``t``, using ``find_switch_point`` to locate the border
    between two consecutive float32 values.

    :param dy: threshold (converted to a double)
    :return: ``numpy.float64`` border, or ``numpy.float32`` ``(float)dy``
        when no border applies (NaN for instance)
    """
    # Make dy a strongly typed double. With NumPy >= 2 (NEP 50), a Python
    # float compared to a float32 scalar is cast to float32 first, which
    # would break ``fy > dy`` and ``fy <= dy`` below.
    dy = numpy.float64(dy)
    fy = numpy.float32(dy)
    if numpy.isnan(fy):
        # Every comparison with NaN is false, whatever the precision.
        # (Previously this built NaN bounds and never returned.)
        return fy
    # Neighbour float32 values. Stepping towards -inf / +inf also works
    # when fy is infinite, whereas ``fy - eps`` / ``fy + eps`` gave NaN.
    # NOTE(review): near the float32 overflow the returned border is only
    # approximate, because an infinite bound stops the bisection early.
    afy = numpy.nextafter(fy, numpy.float32(-numpy.inf))
    afy2 = find_switch_point(afy, fy)
    if fy > dy > afy2:
        # dy was rounded up to fy: the largest float32 <= dy is afy, and
        # doubles up to afy2 are rounded to afy.
        return afy2
    bfy = numpy.nextafter(fy, numpy.float32(numpy.inf))
    bfy2 = find_switch_point(fy, bfy)
    if fy <= dy <= bfy2:
        # dy was rounded down (or exactly) to fy: doubles up to bfy2 are
        # rounded to fy.
        return bfy2
    return fy

# The double threshold for 1.0 is the largest double still rounded to
# 1.0f; 1.0000000216 is below it and is indeed rounded to 1.0f.
good_threshold_double(1.0), numpy.float32(1.0000000216)
(1.0000000596046448, 1.0)
xst1, yst1, xsf1, ysf1 = area_mismatch_rule_double(100, delta, 1., good_threshold_double)
"factor=%1.1f, error area %1.4f%s" % (1., len(xsf1) * 1.0 / (len(xst1) + len(xsf1)) * 100, "%")
'factor=1.0, error area 0.0000%'
xst, yst, xsf, ysf = area_mismatch_rule_double(100, delta, 1e20, good_threshold_double)
"factor=%1.1g, error area %1.4f%s" % (1e20, len(xsf) * 1.0 / (len(xst) + len(xsf)) * 100, "%")
'factor=1e+20, error area 0.0000%'
xst, yst, xsf, ysf = area_mismatch_rule_double(100, delta, factor, good_threshold_double)
"factor=%1.1f, error area %1.4f%s" % (factor, len(xsf) * 1.0 / (len(xst) + len(xsf)) * 100, "%")
'factor=10.0, error area 0.0000%'
# Plot the last run (factor 10). The green 'rule' curve is
# good_threshold_double(x) over the sorted x values of the agreeing points.
fig, ax = plt.subplots(1, 1, figsize=(12, 4))
ax.plot(xst, yst, '.', label="agree")
ax.plot(xsf, ysf, '.', label="disagree")
xs = list(sorted(set(xst)))
ys = [good_threshold_double(x) for x in xs]
ax.plot(xs, ys, 'g', label='rule')
ax.set_title("Region where (float)x <= y and (double)x <= good_threshold_double(y) agree")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.plot([min(xst), max(xst)], [min(yst), max(yst)], 'k--')
ax.legend();
../_images/onnx_float_double_skl_decision_trees_39_0.png

All doubles equivalent to the same float

We can use the previous code to determine the interval of doubles which are all converted into the same float.

def double_interval_for_float(dy):
    """
    Return the interval of doubles rounded to the same float32 as *dy*.

    :param dy: double, expected to be a number (not NaN) within the
        float32 range
    :return: tuple ``(low, high)`` of ``numpy.float64``: ``low`` is the
        smallest and ``high`` the largest double ``x`` such that
        ``(float)x == (float)dy``
    """
    fy = numpy.float32(dy)
    # Neighbour float32 values, stepping towards -inf / +inf.
    afy = numpy.nextafter(fy, numpy.float32(-numpy.inf))
    bfy = numpy.nextafter(fy, numpy.float32(numpy.inf))
    # Largest double still rounded to the previous float32.
    afy2 = find_switch_point(afy, fy)
    # Largest double rounded to fy.
    bfy2 = find_switch_point(fy, bfy)
    # The lower bound is the double right after afy2. Adding the float64
    # machine epsilon (as before) is only one ulp in [1, 2): it is absorbed
    # for large values and skips doubles for small ones. Around 1.0 it even
    # excluded the tie 1 - 2**-25, which is rounded to 1.0f.
    return (numpy.nextafter(afy2, numpy.float64(numpy.inf)), bfy2)

# Interval of doubles rounded to 1.0f, then for 1 + 1e-8, which is also
# rounded to 1.0f and therefore gets the same interval.
double_interval_for_float(1.)
(0.9999999701976777, 1.0000000596046448)
double_interval_for_float(1. + 1e-8)
(0.9999999701976777, 1.0000000596046448)
# Interval of the next float32: start from the double right after the upper
# bound above (in [1, 2), adding the float64 eps moves by exactly one double).
eps = numpy.finfo(numpy.float64).eps
double_interval_for_float(1.0000000596046448 + eps)
(1.000000059604645, 1.000000178813934)

Verification

Let’s check that the rules work for random values of x.

def verification(rnd):
    """
    Compare the rewritten thresholds with scikit-learn's comparison.

    Every value ``x`` of *rnd* is used both as the feature and as the
    threshold. scikit-learn's ``(float)x <= x`` is compared with
    ``(float)x <= good_threshold(x)`` and with
    ``(double)x <= good_threshold_double(x)``.

    :param rnd: iterable of numbers
    :return: list of tuples ``(x, skl, flo, dou)``, one per value where at
        least one rewritten comparison differs from scikit-learn's
    """
    def _evaluate(value):
        # The three comparisons for one value, in the original order.
        reference = numpy.float32(value) <= value
        with_float = numpy.float32(value) <= good_threshold(value)
        with_double = numpy.float64(value) <= good_threshold_double(value)
        return value, reference, with_float, with_double

    rows = (_evaluate(value) for value in rnd)
    return [(x, skl, flo, dou) for x, skl, flo, dou in rows
            if skl != flo or skl != dou]

# 10 random values in [-0.5, 0.5), then 10 in [-5, 5). An empty list means
# no disagreement was found.
rnd = (numpy.random.rand(10) - 0.5)
verification(rnd)
[]
rnd = (numpy.random.rand(10) - 0.5) * 10
verification(rnd)
[]