A converter for a TransformedTargetRegressor

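There is no easy way to convert a sklearn.preprocessing.FunctionTransformer or a sklearn.compose.TransformedTargetRegressor into ONNX unless the wrapped functions are written in such a way that their conversion is implicit. Below, they are written with mlprodict's numpy-like ONNX operators and decorated with onnxnumpy_default.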

from typing import Any
import numpy as np
from sklearn.compose import TransformedTargetRegressor
from sklearn.preprocessing import FunctionTransformer
from sklearn.linear_model import LinearRegression
from mlprodict.onnx_conv import to_onnx
from mlprodict import __max_supported_opset__ as TARGET_OPSET
from mlprodict.npy import onnxnumpy_default, NDArray
from mlprodict.onnxrt import OnnxInference
import mlprodict.npy.numpy_onnx_impl as npnx

TransformedTargetRegressor

@onnxnumpy_default
def onnx_log_1(x: NDArray[Any, np.float32]) -> NDArray[(None, None), np.float32]:
    """Return ``log(1 + x)`` element-wise using ONNX-convertible operators.

    The body is written only with ``npnx`` functions (from
    ``mlprodict.npy.numpy_onnx_impl``) and decorated with
    ``onnxnumpy_default``, so converting a model that holds this function
    needs no hand-written converter.

    :param x: float32 array (the annotation does not constrain its shape)
    :return: float32 array; the ``(None, None)`` annotation presumably
        declares a 2-D result with unknown dimensions — TODO confirm
    """
    # Do not replace npnx.log1p with np.log1p: the decorator presumably
    # relies on npnx operators to build the ONNX graph.
    return npnx.log1p(x)


@onnxnumpy_default
def onnx_exp_1(x: NDArray[Any, np.float32]) -> NDArray[(None, None), np.float32]:
    """Return ``exp(x) - 1`` element-wise, the mathematical inverse of log1p.

    It is paired with ``onnx_log_1`` as ``func`` / ``inverse_func`` of the
    TransformedTargetRegressor. Like ``onnx_log_1``, it uses only ``npnx``
    operators so the ONNX conversion is implicit.

    :param x: float32 array (the annotation does not constrain its shape)
    :return: float32 array; the ``(None, None)`` annotation presumably
        declares a 2-D result with unknown dimensions — TODO confirm
    """
    # The constant is an explicit np.float32 so both operands of the
    # subtraction share x's float32 dtype (presumably required by the
    # generated ONNX Sub node). exp(x) - 1 loses relative precision for
    # |x| close to 0 compared with a dedicated expm1.
    return npnx.exp(x) - np.float32(1)


# The linear model is trained on log1p(y); its predictions are mapped back
# to the original target scale with exp(.) - 1.
model = TransformedTargetRegressor(
    regressor=LinearRegression(),
    inverse_func=onnx_exp_1,
    func=onnx_log_1,
)

# Toy data: 6 samples x 3 float32 features; the target is each row's sum.
x = np.arange(18, dtype=np.float32).reshape(-1, 3)
y = np.sum(x, axis=1)

# fit() returns the estimator itself, so prediction can be chained on it.
expected = model.fit(x, y).predict(x)
print(expected)
[ 5.3555384  9.108676  15.0781555 24.572792  39.67432   63.693733 ]

Conversion to ONNX

# Convert the fitted TransformedTargetRegressor. rewrite_ops=True lets
# mlprodict use its own converters; presumably this is what makes the
# onnxnumpy-decorated functions convertible — confirm in the to_onnx docs.
onx = to_onnx(model, x, rewrite_ops=True, target_opset=TARGET_OPSET)
oinf = OnnxInference(onx)
got = oinf.run({'X': x})
print(got)

# Check that the ONNX runtime reproduces scikit-learn. The ONNX graph
# returns an (n, 1) column whereas predict() returns an (n,) vector, so
# flatten before comparing.
np.testing.assert_allclose(
    expected, got['variable'].ravel(), rtol=1e-5, atol=1e-6)
{'variable': array([[ 5.3555384],
       [ 9.108676 ],
       [15.0781555],
       [24.572792 ],
       [39.67432  ],
       [63.693733 ]], dtype=float32)}

FunctionTransformer

# A stateless transformer that applies onnx_log_1 (log1p) to its input.
model = FunctionTransformer(onnx_log_1)

# fit() returns the transformer itself, so transform can be chained;
# scikit-learn documents y as ignored by FunctionTransformer.fit.
expected = model.fit(x, y).transform(x)
print(expected)
[[0.        0.6931472 1.0986123]
 [1.3862944 1.609438  1.7917595]
 [1.9459101 2.0794415 2.1972246]
 [2.3025851 2.3978953 2.4849067]
 [2.5649493 2.6390574 2.7080503]
 [2.7725887 2.8332133 2.8903718]]

Conversion to ONNX

# Convert the fitted FunctionTransformer with the same options as above.
onx = to_onnx(model, x, rewrite_ops=True, target_opset=TARGET_OPSET)
oinf = OnnxInference(onx)
got = oinf.run({'X': x})
print(got)

# Check that the ONNX runtime reproduces scikit-learn. Both sides are
# (n, 3) here. atol covers the exact-zero entry produced by log1p(0).
np.testing.assert_allclose(
    expected, got['variable'], rtol=1e-5, atol=1e-6)
{'variable': array([[0.       , 0.6931472, 1.0986123],
       [1.3862944, 1.609438 , 1.7917595],
       [1.9459101, 2.0794415, 2.1972246],
       [2.3025851, 2.3978953, 2.4849067],
       [2.5649493, 2.6390574, 2.7080503],
       [2.7725887, 2.8332133, 2.8903718]], dtype=float32)}

Total running time of the script: (0 minutes 0.887 seconds)

Gallery generated by Sphinx-Gallery