"""
Parsers for LightGBM booster.
"""
import numpy
from sklearn.base import ClassifierMixin
from skl2onnx._parse import _parse_sklearn_classifier, _parse_sklearn_simple_model
from skl2onnx.common._apply_operation import apply_concat, apply_cast
from skl2onnx.common.data_types import guess_proto_type
from skl2onnx.proto import onnx_proto


class WrappedLightGbmBooster:
    """
    A booster can be a classifier or a regressor.
    Trick to wrap it in a minimal class.
    """

    def __init__(self, booster):
        self.booster_ = booster
        self._model_dict = self.booster_.dump_model()
        self.classes_ = self._generate_classes(self._model_dict)
        self.n_features_ = len(self._model_dict['feature_names'])
        if self._model_dict['objective'].startswith('binary'):
            self.operator_name = 'LgbmClassifier'
        elif self._model_dict['objective'].startswith('regression'):  # pragma: no cover
            self.operator_name = 'LgbmRegressor'
        else:  # pragma: no cover
            raise NotImplementedError(
                'Unsupported LightGbm objective: {}'.format(
                    self._model_dict['objective']))
        if self._model_dict.get('average_output', False):
            self.boosting_type = 'rf'
        else:
            # Boosting types other than random forest do not affect
            # the conversion, `gbdt` is used as an arbitrary default.
            self.boosting_type = 'gbdt'

    def _generate_classes(self, model_dict):
        # A binary objective is stored with num_class == 1:
        # expand it into the two labels [0, 1].
        if model_dict['num_class'] == 1:
            return numpy.asarray([0, 1])
        return numpy.arange(model_dict['num_class'])
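

# A minimal usage sketch (assumed, not part of the original module): wrapping
# a raw `lightgbm.Booster` trained on a binary objective. It requires the
# *lightgbm* package; the data and training parameters are hypothetical.
def _wrap_booster_example():
    import lightgbm
    X = numpy.random.rand(100, 4).astype(numpy.float32)
    y = (X.sum(axis=1) > 2).astype(numpy.int32)
    booster = lightgbm.train({'objective': 'binary', 'verbose': -1},
                             lightgbm.Dataset(X, label=y))
    wrapped = WrappedLightGbmBooster(booster)
    # operator_name == 'LgbmClassifier', classes_ == array([0, 1])
    return wrapped.operator_name, wrapped.classes_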


class WrappedLightGbmBoosterClassifier(ClassifierMixin):
    """
    Trick to make a wrapped LightGBM booster look like
    a *scikit-learn* classifier.
    """

    def __init__(self, wrapped):  # pylint: disable=W0231
        for k in {'boosting_type', '_model_dict', 'operator_name',
                  'classes_', 'booster_', 'n_features_'}:
            setattr(self, k, getattr(wrapped, k))


class MockWrappedLightGbmBoosterClassifier(WrappedLightGbmBoosterClassifier):
    """
    Mocked lightgbm: stores an already dumped model
    instead of a real booster.
    """

    def __init__(self, tree):  # pylint: disable=W0231
        self.dumped_ = tree

    def dump_model(self):
        "mock dump_model method"
        self.visited = True
        return self.dumped_
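

# A minimal usage sketch (assumed, not part of the original module): the mock
# lets unit tests feed the dictionary normally returned by
# `Booster.dump_model` into code expecting a booster. The `tree` content
# below is a hypothetical, truncated structure.
def _mock_example():
    tree = {'objective': 'binary', 'num_class': 1,
            'feature_names': ['f0', 'f1'], 'tree_info': []}
    mock = MockWrappedLightGbmBoosterClassifier(tree)
    assert mock.dump_model() is tree
    assert mock.visited
    return mock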


def lightgbm_parser(scope, model, inputs, custom_parsers=None):
    """
    Agnostic parser for LightGBM Booster.
    """
    if hasattr(model, "fit"):
        raise TypeError(  # pragma: no cover
            "This converter does not apply on type '{}'."
            "".format(type(model)))

    if len(inputs) == 1:
        wrapped = WrappedLightGbmBooster(model)
        if wrapped._model_dict['objective'].startswith('binary'):
            wrapped = WrappedLightGbmBoosterClassifier(wrapped)
            return _parse_sklearn_classifier(
                scope, wrapped, inputs, custom_parsers=custom_parsers)
        if wrapped._model_dict['objective'].startswith('regression'):  # pragma: no cover
            return _parse_sklearn_simple_model(
                scope, wrapped, inputs, custom_parsers=custom_parsers)
        raise NotImplementedError(  # pragma: no cover
            "Objective '{}' is not implemented yet.".format(
                wrapped._model_dict['objective']))

    # Multiple input columns: insert a LightGBMConcat operator which
    # concatenates them into a single tensor, then parse again.
    this_operator = scope.declare_local_operator('LightGBMConcat')
    this_operator.raw_operator = model
    this_operator.inputs = inputs
    var = scope.declare_local_variable(
        'Xlgbm', inputs[0].type.__class__([None, None]))
    this_operator.outputs.append(var)
    return lightgbm_parser(scope, model, this_operator.outputs,
                           custom_parsers=custom_parsers)
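

# A minimal wiring sketch (assumed, not part of the original module): a parser
# is attached to a model class through the ``parser`` argument of skl2onnx's
# `update_registered_converter`. `shape_fct` and `convert_fct` stand for a
# LightGBM shape calculator and converter; they and the alias below are
# hypothetical placeholders.
def _register_parser_example(shape_fct, convert_fct):
    from skl2onnx import update_registered_converter
    update_registered_converter(
        WrappedLightGbmBoosterClassifier, 'LightGbmLGBMClassifier',
        shape_fct, convert_fct, parser=lightgbm_parser,
        options={'zipmap': [True, False], 'nocl': [True, False]})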


def shape_calculator_lightgbm_concat(operator):
    """
    Shape calculator for operator *LightGBMConcat*.
    The output shape is already declared by the parser,
    there is nothing to compute.
    """
    pass


def converter_lightgbm_concat(scope, operator, container):
    """
    Converter for operator *LightGBMConcat*.
    """
    op = operator.raw_operator
    options = container.get_options(op, dict(cast=False))
    # Any element type other than double falls back to float.
    proto_dtype = guess_proto_type(operator.inputs[0].type)
    if proto_dtype != onnx_proto.TensorProto.DOUBLE:  # pylint: disable=E1101
        proto_dtype = onnx_proto.TensorProto.FLOAT  # pylint: disable=E1101
    if options['cast']:
        concat_name = scope.get_unique_variable_name('cast_lgbm')
        apply_cast(scope, concat_name, operator.outputs[0].full_name,
                   container,
                   operator_name=scope.get_unique_operator_name('cast_lgbm'),
                   to=proto_dtype)
    else:
        concat_name = operator.outputs[0].full_name
    apply_concat(scope, [_.full_name for _ in operator.inputs],
                 concat_name, container, axis=1)
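

# A minimal wiring sketch (assumed, not part of the original module): the
# 'LightGBMConcat' alias declared by `lightgbm_parser` must be mapped to the
# two functions above, for instance through skl2onnx's low-level registration
# API.
def _register_lightgbm_concat_example():
    from skl2onnx.common._registration import (
        register_converter, register_shape_calculator)
    register_shape_calculator(
        'LightGBMConcat', shape_calculator_lightgbm_concat, overwrite=True)
    register_converter(
        'LightGBMConcat', converter_lightgbm_concat, overwrite=True)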