Source code for mlinsights.mlbatch.cache_model
"""
Caches used to avoid repeating model trainings.
:githublink:`%|py|5`
"""
import numpy
# Module-level registry of named caches, managed through
# MLCache.create_cache, get_cache, has_cache and remove_cache.
_caches = {}


class MLCache:
    """
    Implements a cache to reduce the number of trainings
    a grid search has to do.
    """

    def __init__(self, name):
        """
        :param name: name of the cache
        """
        self.name = name
        # key (built by as_key) -> cached value
        self.cached = {}
        # key -> number of successful retrievals through get
        self.count_ = {}

    def cache(self, params, value):
        """
        Caches one object.

        :param params: dictionary of parameters (or a string key)
        :param value: value to cache
        :raises KeyError: if *params* is already cached
        """
        key = MLCache.as_key(params)
        if key in self.cached:
            raise KeyError(  # pragma: no cover
                "Key {0} already exists".format(params))
        self.cached[key] = value
        self.count_[key] = 0

    def get(self, params, default=None):
        """
        Retrieves an element from the cache and increments
        its retrieval counter when it is found.

        :param params: dictionary of parameters (or a string key)
        :param default: value returned if *params* is not cached
        :return: cached value, or *default* if it does not exist
        """
        key = MLCache.as_key(params)
        # Test membership instead of comparing with *default*.
        # A cached value may equal the default (e.g. None), and
        # ``!=`` on a numpy array is element-wise, so it cannot be
        # used as a condition.
        if key not in self.cached:
            return default
        self.count_[key] += 1
        return self.cached[key]

    def count(self, params):
        """
        Retrieves the number of times
        an element was retrieved from the cache.

        :param params: dictionary of parameters (or a string key)
        :return: int, 0 if *params* was never cached
        """
        key = MLCache.as_key(params)
        return self.count_.get(key, 0)

    @staticmethod
    def as_key(params):
        """
        Converts a dictionary of parameters into a key.

        Parameters are sorted by name, so the key does not depend on
        insertion order. Supported values are int, float, str, None,
        tuples of int/float/str, and numpy arrays.

        :param params: dictionary, or a string returned unchanged
        :return: key as a string
        :raises TypeError: if a value has an unsupported type
        """
        if isinstance(params, str):
            return params
        els = []
        for k, v in sorted(params.items()):
            if isinstance(v, (int, float, str)):
                sv = str(v)
            elif isinstance(v, tuple):
                if not all(isinstance(e, (int, float, str)) for e in v):
                    raise TypeError(  # pragma: no cover
                        "Unable to create a key with value '{0}':{1}".format(k, v))
                # Keep going: the key must include every parameter,
                # not only this tuple.
                sv = str(v)
            elif isinstance(v, numpy.ndarray):
                # id(v) may have been better but
                # it does not play well with joblib.
                # dtype and shape are part of the key, so arrays sharing
                # the same raw bytes but with a different layout do not
                # collide. tobytes replaces the deprecated tostring.
                sv = "{0}{1}:{2}".format(
                    v.dtype, v.shape, hash(v.tobytes()))
            elif v is None:
                sv = ""
            else:
                raise TypeError(  # pragma: no cover
                    "Unable to create a key with value '{0}':{1}".format(k, v))
            els.append((k, sv))
        return str(els)

    def __len__(self):
        """
        Returns the number of cached items.
        """
        return len(self.cached)

    def items(self):
        """
        Enumerates all cached items as ``(key, value)`` pairs.
        """
        yield from self.cached.items()

    def keys(self):
        """
        Enumerates all cached keys.
        """
        yield from self.cached

    @staticmethod
    def create_cache(name):
        """
        Creates a new cache and registers it under *name*.

        :param name: name
        :return: created cache
        :raises RuntimeError: if a cache with this name already exists
        """
        if name in _caches:
            raise RuntimeError(  # pragma: no cover
                "cache '{0}' already exists.".format(name))
        cache = MLCache(name)
        _caches[name] = cache
        return cache

    @staticmethod
    def remove_cache(name):
        """
        Removes a cache with a given name.

        :param name: name
        :raises KeyError: if no cache has this name
        """
        del _caches[name]

    @staticmethod
    def get_cache(name):
        """
        Gets a cache with a given name.

        :param name: name
        :return: registered cache
        :raises KeyError: if no cache has this name
        """
        return _caches[name]

    @staticmethod
    def has_cache(name):
        """
        Tells if cache *name* is present.

        :param name: name
        :return: boolean
        """
        return name in _caches