# Source code for onnxmltools.convert.xgboost.convert
# SPDX-License-Identifier: Apache-2.0
from uuid import uuid4
import xgboost
import onnx
from ..common.onnx_ex import get_maximum_opset_supported
from ..common._topology import convert_topology
from ._parse import parse_xgboost, WrappedBooster
# Invoke the registration of all our converters and shape calculators
# from . import shape_calculators
from . import operator_converters, shape_calculators
def convert(model, name=None, initial_types=None, doc_string='', target_opset=None,
            targeted_onnx=onnx.__version__, custom_conversion_functions=None,
            custom_shape_calculators=None):
    """
    Produce an ONNX model equivalent to the given xgboost model.

    :param model: A xgboost model. A raw ``xgboost.Booster`` is wrapped in
        ``WrappedBooster`` before being parsed.
    :param name: The name of the graph (type: GraphProto) in the produced ONNX model
        (type: ModelProto). When None, a random hexadecimal UUID string is used.
    :param initial_types: a python list. Each element is a tuple of a variable name and a
        type defined in data_types.py. Required despite the None default.
    :param doc_string: A string attached onto the produced ONNX model
    :param target_opset: number, for example, 7 for ONNX 1.2, and 8 for ONNX 1.3. When
        omitted (or otherwise falsy), the value returned by
        ``get_maximum_opset_supported()`` is used.
    :param targeted_onnx: A string (for example, '1.1.2' and '1.2') used to specify the
        targeted ONNX version of the produced model. If ONNXMLTools cannot find a
        compatible ONNX python package, an error may be thrown.
    :param custom_conversion_functions: a dictionary for specifying the user customized
        conversion function
    :param custom_shape_calculators: a dictionary for specifying the user customized
        shape calculator
    :return: An ONNX model (type: ModelProto) which is equivalent to the input xgboost model
    :raises ValueError: if ``initial_types`` is None
    """
    if initial_types is None:
        # Adjacent string literals are used rather than a backslash continuation
        # inside the literal, so source indentation can never leak into the message.
        raise ValueError('Initial types are required. See usage of convert(...) in '
                         'onnxmltools.convert.xgboost.convert for details')

    if name is None:
        name = str(uuid4().hex)

    # Give a bare Booster the wrapper interface that the parser is handed.
    if isinstance(model, xgboost.Booster):
        model = WrappedBooster(model)

    # Any falsy target_opset (None, 0) selects the maximum supported opset.
    target_opset = target_opset if target_opset else get_maximum_opset_supported()

    topology = parse_xgboost(model, initial_types, target_opset,
                             custom_conversion_functions, custom_shape_calculators)
    topology.compile()
    onnx_model = convert_topology(topology, name, doc_string, target_opset, targeted_onnx)
    return onnx_model