# Source code for mlinsights.sklapi.sklearn_base
# -*- coding: utf-8 -*-
"""
Implements a *learner* or a *transform* which follows the same API
as every :epkg:`scikit-learn` transform.
:githublink:`%|py|7`
"""
import textwrap
import warnings
from .sklearn_parameters import SkLearnParameters
class SkBase:
    """
    Pattern of a *learner* or a *transform* which follows the API
    of :epkg:`scikit-learn`.

    The keyword arguments given to the constructor are kept in
    attribute *P* and are what *get_params*, *set_params* and the
    equality helpers work on.
    """

    def __init__(self, **kwargs):
        """
        Stores the parameters, see
        :class:`SkLearnParameters <mlinsights.sklapi.sklearn_parameters.SkLearnParameters>`,
        it keeps a copy of the parameters to easily implement
        method *get_params* and to clone a model.

        :param kwargs: parameters which define the object
        """
        self.P = SkLearnParameters(**kwargs)

    def fit(self, X, y=None, sample_weight=None):
        """
        Trains a model. Must be overridden by subclasses.

        :param X: features
        :param y: target
        :param sample_weight: weight
        :return: self
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()  # pragma: no cover

    def get_params(self, deep=True):
        """
        Returns the parameters which define the object,
        all are needed to clone the object.

        :param deep: unused here
        :return: dict
        """
        return self.P.to_dict()

    def set_params(self, **values):
        """
        Replaces the parameters which define the object.
        The stored parameters are rebuilt from *values* only:
        a parameter which is not mentioned in *values* is not kept.

        :param values: values
        :return: self
        """
        self.P = SkLearnParameters(**values)
        return self

    def __eq__(self, o):
        """
        Compares two objects, more precisely,
        compares the parameters which define the object.
        Never raises because of a difference, it returns False instead.

        Note: the class defines ``__eq__`` without ``__hash__``,
        Python therefore makes its instances unhashable.

        :param o: object to compare to
        :return: boolean
        """
        return self.test_equality(o, False)

    def test_equality(self, o, exc=True):
        """
        Compares two objects and checks parameters have
        the same values.

        :param o: object to compare to
        :param exc: raises an exception if the parameters differ
        :return: boolean, always False (no exception) if *o* is
            not of the exact same class
        """
        if self.__class__ != o.__class__:
            return False
        p1 = self.get_params()
        p2 = o.get_params()
        return SkBase.compare_params(p1, p2, exc=exc)

    @staticmethod
    def _same_item(e1, e2, exc):
        """
        Compares two values stored inside a list or a dictionary.

        :param e1: first value
        :param e2: second value
        :param exc: forwarded to *test_equality* if *e1* implements it
        :return: boolean
        """
        if hasattr(e1, 'test_equality'):
            return e1.test_equality(e2, exc=exc)
        return e1 == e2

    @staticmethod
    def compare_params(p1, p2, exc=True):
        """
        Compares two sets of parameters.

        :param p1: dictionary
        :param p2: dictionary
        :param exc: raises an exception if a difference is met
        :return: boolean
        :raises KeyError: if *exc* is True and a key is only in one
            of the dictionaries
        :raises ValueError: if *exc* is True and two values differ
        """
        # Fast path, it also covers identical objects.
        if p1 == p2:
            return True

        for k in p1:
            if k not in p2:
                if exc:
                    raise KeyError("Key '{0}' was removed.".format(k))
                return False
        for k in p2:
            if k not in p1:
                if exc:
                    raise KeyError("Key '{0}' was added.".format(k))
                return False

        for k in sorted(p1):
            v1, v2 = p1[k], p2[k]
            if hasattr(v1, 'test_equality'):
                b = v1.test_equality(v2, exc=exc)
                if exc and v1 is not v2:
                    warnings.warn(  # pragma: no cover
                        "v2 is a clone of v1 not v1 itself for key '{0}' and class {1}."
                        "".format(k, type(v1)))
            elif isinstance(v1, list) and isinstance(v2, list) and len(v1) == len(v2):
                # Every pair is compared, including plain values which
                # do not implement test_equality.
                b = all(SkBase._same_item(e1, e2, exc)
                        for e1, e2 in zip(v1, v2))
            elif isinstance(v1, dict) and isinstance(v2, dict) and set(v1) == set(v2):
                # Both dictionaries share the same keys: values are
                # fetched by key, keys do not need to be sortable.
                b = all(SkBase._same_item(v1[key], v2[key], exc)
                        for key in v1)
            elif hasattr(v1, "get_params") and hasattr(v2, "get_params"):
                b = SkBase.compare_params(
                    v1.get_params(deep=False), v2.get_params(deep=False),
                    exc=exc)
            else:
                b = v1 == v2
            if not b:
                # Single exit for every kind of difference so that
                # exc=True behaves the same for nested containers.
                if exc:
                    raise ValueError(
                        "Values for key '{0}' are different.\n---\n{1}\n---\n{2}".format(k, v1, v2))
                return False
        return True

    def __repr__(self):
        """
        Returns ``ClassName(parameters)`` wrapped with :mod:`textwrap`.

        :return: string
        """
        res = "{0}({1})".format(self.__class__.__name__, str(self.P))
        return "\n".join(textwrap.wrap(res, subsequent_indent=" "))