# Source code for mlprodict.onnx_conv.register_rewritten_converters

"""
Rewrites some of the converters implemented in
:epkg:`sklearn-onnx`.


:githublink:`%|py|6`
"""
from skl2onnx.common._registration import _converter_pool

try:
    from skl2onnx.common._registration import RegisteredConverter
except ImportError:  # pragma: no cover
    # sklearn-onnx <= 1.6.0
    RegisteredConverter = lambda fct, opts: fct

from .sklconv.tree_converters import (
    new_convert_sklearn_decision_tree_classifier,
    new_convert_sklearn_decision_tree_regressor,
    new_convert_sklearn_gradient_boosting_classifier,
    new_convert_sklearn_gradient_boosting_regressor,
    new_convert_sklearn_random_forest_classifier,
    new_convert_sklearn_random_forest_regressor,
)
from .sklconv.svm_converters import (
    new_convert_sklearn_svm_classifier,
    new_convert_sklearn_svm_regressor,
)


# Maps each sklearn-onnx operator alias to the rewritten converter that
# replaces it. Every replacement is wrapped in ``RegisteredConverter``
# together with the options allowed by the converter currently stored in
# ``_converter_pool`` under the same alias.
_overwritten_operators = {
    alias: RegisteredConverter(
        converter, _converter_pool[alias].get_allowed_options())
    for alias, converter in (
        # support vector machines
        ('SklearnOneClassSVM', new_convert_sklearn_svm_regressor),
        ('SklearnSVR', new_convert_sklearn_svm_regressor),
        ('SklearnSVC', new_convert_sklearn_svm_classifier),
        # decision trees
        ('SklearnDecisionTreeRegressor',
         new_convert_sklearn_decision_tree_regressor),
        ('SklearnDecisionTreeClassifier',
         new_convert_sklearn_decision_tree_classifier),
        # extra tree (single tree)
        ('SklearnExtraTreeRegressor',
         new_convert_sklearn_decision_tree_regressor),
        ('SklearnExtraTreeClassifier',
         new_convert_sklearn_decision_tree_classifier),
        # extra trees (ensemble)
        ('SklearnExtraTreesRegressor',
         new_convert_sklearn_random_forest_regressor),
        ('SklearnExtraTreesClassifier',
         new_convert_sklearn_random_forest_classifier),
        # gradient boosting
        ('SklearnGradientBoostingRegressor',
         new_convert_sklearn_gradient_boosting_regressor),
        ('SklearnGradientBoostingClassifier',
         new_convert_sklearn_gradient_boosting_classifier),
        # histogram-based gradient boosting
        ('SklearnHistGradientBoostingRegressor',
         new_convert_sklearn_random_forest_regressor),
        ('SklearnHistGradientBoostingClassifier',
         new_convert_sklearn_random_forest_classifier),
        # random forests
        ('SklearnRandomForestRegressor',
         new_convert_sklearn_random_forest_regressor),
        ('SklearnRandomForestClassifier',
         new_convert_sklearn_random_forest_classifier),
    )
}


def register_rewritten_operators(new_values=None):
    """
    Registers modified operators and returns the old values.

    :param new_values: operators to rewrite or None
        to rewrite default ones
    :return: old values
    :raises KeyError: if one of the operators to rewrite is not
        already present in the converter pool

    :githublink:`%|py|91`
    """
    # With no explicit mapping, fall back to the module-level table of
    # rewritten converters. An empty mapping is honoured as "nothing to
    # rewrite", hence the identity test against None.
    if new_values is None:
        new_values = _overwritten_operators

    # Validate every name before touching the pool so that a failure
    # leaves the pool unmodified.
    for rew in new_values:
        if rew not in _converter_pool:
            raise KeyError(  # pragma: no cover
                "skl2onnx was not imported and '{}' was not registered."
                "".format(rew))

    # Snapshot the converters about to be replaced; passing the returned
    # mapping back to this function restores them.
    old_values = {k: _converter_pool[k] for k in new_values}
    _converter_pool.update(new_values)
    return old_values