# Source code for mlprodict.onnx_conv.sklconv.svm_converters

"""
Rewrites some of the converters implemented in
:epkg:`sklearn-onnx`.


:githublink:`%|py|6`
"""
import numpy
from skl2onnx.operator_converters.support_vector_machines import (
    convert_sklearn_svm_regressor,
    convert_sklearn_svm_classifier)
from skl2onnx.common.data_types import guess_numpy_type


def _op_type_domain_regressor(dtype):
    """
    Defines *op_type* and *op_domain* based on `dtype`.

    :param dtype: element type, compared with ``==`` against
        ``numpy.float32`` and ``numpy.float64``
    :return: tuple *(op_type, op_domain, op_version)*,
        ``('SVMRegressor', 'ai.onnx.ml', 1)`` for float32 and
        ``('SVMRegressorDouble', 'mlprodict', 1)`` for float64
    :raises RuntimeError: if `dtype` is neither float32 nor float64
    """
    if dtype == numpy.float32:
        return 'SVMRegressor', 'ai.onnx.ml', 1
    if dtype == numpy.float64:
        return 'SVMRegressorDouble', 'mlprodict', 1
    raise RuntimeError(  # pragma: no cover
        "Unsupported dtype {}.".format(dtype))
def _op_type_domain_classifier(dtype):
    """
    Defines *op_type* and *op_domain* based on `dtype`.

    :param dtype: element type, compared with ``==`` against
        ``numpy.float32`` and ``numpy.float64``
    :return: tuple *(op_type, op_domain, op_version)*,
        ``('SVMClassifier', 'ai.onnx.ml', 1)`` for float32 and
        ``('SVMClassifierDouble', 'mlprodict', 1)`` for float64
    :raises RuntimeError: if `dtype` is neither float32 nor float64
    """
    if dtype == numpy.float32:
        return 'SVMClassifier', 'ai.onnx.ml', 1
    if dtype == numpy.float64:
        return 'SVMClassifierDouble', 'mlprodict', 1
    raise RuntimeError(  # pragma: no cover
        "Unsupported dtype {}.".format(dtype))
def new_convert_sklearn_svm_regressor(scope, operator, container):
    """
    Rewrites the converters implemented in
    :epkg:`sklearn-onnx` to support an operator supporting
    doubles.

    :param scope: forwarded unchanged to
        ``convert_sklearn_svm_regressor``
    :param operator: forwarded unchanged; the type of its first
        input (``operator.inputs[0].type``) selects the ONNX operator
    :param container: forwarded unchanged to
        ``convert_sklearn_svm_regressor``
    """
    dtype = guess_numpy_type(operator.inputs[0].type)
    # Anything that is not double precision is handled as float32.
    if dtype != numpy.float64:
        dtype = numpy.float32
    op_type, op_domain, op_version = _op_type_domain_regressor(dtype)
    convert_sklearn_svm_regressor(
        scope, operator, container, op_type=op_type, op_domain=op_domain,
        op_version=op_version)
def new_convert_sklearn_svm_classifier(scope, operator, container):
    """
    Rewrites the converters implemented in
    :epkg:`sklearn-onnx` to support an operator supporting
    doubles.

    :param scope: forwarded unchanged to
        ``convert_sklearn_svm_classifier``
    :param operator: forwarded unchanged; the type of its first
        input (``operator.inputs[0].type``) selects the ONNX operator
    :param container: forwarded unchanged to
        ``convert_sklearn_svm_classifier``
    """
    dtype = guess_numpy_type(operator.inputs[0].type)
    # Anything that is not double precision is handled as float32.
    if dtype != numpy.float64:
        dtype = numpy.float32
    op_type, op_domain, op_version = _op_type_domain_classifier(dtype)
    convert_sklearn_svm_classifier(
        scope, operator, container, op_type=op_type, op_domain=op_domain,
        op_version=op_version)