Source code for mlinsights.mlbatch.cache_model

"""
Caches storing training results so that identical trainings (e.g. during a grid search) are not repeated.


:githublink:`%|py|5`
"""
import numpy

# Registry of named MLCache instances, keyed by cache name.
# It is filled and queried by MLCache.create_cache / get_cache /
# has_cache / remove_cache.
_caches = dict()


class MLCache:
    """
    Implements a cache to reduce the number of trainings a grid search
    has to do.

    Values are stored under a key computed by :meth:`as_key` from a
    dictionary of parameters. The static methods :meth:`create_cache`,
    :meth:`get_cache`, :meth:`has_cache` and :meth:`remove_cache`
    manage a module-level registry of named caches.
    """

    def __init__(self, name):
        """
        :param name: name of the cache
        """
        self.name = name
        # key (as returned by as_key) -> cached value
        self.cached = {}
        # key -> number of successful retrievals through get
        self.count_ = {}

    def cache(self, params, value):
        """
        Caches one object.

        :param params: dictionary of parameters
        :param value: value to cache
        :raises KeyError: if *params* is already cached
        """
        key = MLCache.as_key(params)
        if key in self.cached:
            raise KeyError(  # pragma: no cover
                "Key {0} already exists".format(params))
        self.cached[key] = value
        self.count_[key] = 0

    def get(self, params, default=None):
        """
        Retrieves an element from the cache.

        Every successful retrieval increments the counter returned by
        :meth:`count`, even when the cached value equals *default*.

        :param params: dictionary of parameters
        :param default: value returned if *params* is not cached
        :return: cached value, or *default* if it does not exist
        """
        key = MLCache.as_key(params)
        # Test membership instead of comparing the result with *default*:
        # ``res != default`` is ambiguous for numpy arrays and misses hits
        # whose cached value happens to equal *default*.
        if key not in self.cached:
            return default
        self.count_[key] += 1
        return self.cached[key]

    def count(self, params):
        """
        Retrieves the number of times an element was retrieved
        from the cache.

        :param params: dictionary of parameters
        :return: int, 0 if *params* was never cached
        """
        key = MLCache.as_key(params)
        return self.count_.get(key, 0)

    @staticmethod
    def as_key(params):
        """
        Converts a dictionary of parameters into a key.

        Keys are sorted so the result does not depend on insertion order.
        Accepted values are Python or numpy scalars (int, float, str,
        bool), tuples of such scalars, numpy arrays and None.

        :param params: dictionary, or a string already used as a key
        :return: key as a string
        :raises TypeError: if a value cannot be converted into a key
        """
        if isinstance(params, str):
            return params
        # numpy scalars (e.g. values from numpy.arange) print like their
        # Python counterparts, so they produce the same key.
        scalar_types = (int, float, str, numpy.number, numpy.bool_)
        els = []
        for k, v in sorted(params.items()):
            if isinstance(v, scalar_types):
                sv = str(v)
            elif isinstance(v, tuple):
                if not all(isinstance(e, scalar_types) for e in v):
                    raise TypeError(  # pragma: no cover
                        "Unable to create a key with value '{0}':{1}".format(k, v))
                # The tuple contributes to the key like any other value;
                # returning here would ignore every other parameter.
                sv = str(v)
            elif isinstance(v, numpy.ndarray):
                # id(v) may have been better but
                # it does not play well with joblib.
                # Shape and dtype are hashed along with the raw bytes so
                # that arrays sharing the same buffer content but having a
                # different layout (zeros((2, 2)) vs zeros(4)) differ.
                sv = hash((v.shape, v.dtype.str, v.tobytes()))
            elif v is None:
                sv = ""
            else:
                raise TypeError(  # pragma: no cover
                    "Unable to create a key with value '{0}':{1}".format(k, v))
            els.append((k, sv))
        return str(els)

    def __len__(self):
        """
        Returns the number of cached items.
        """
        return len(self.cached)

    def items(self):
        """
        Enumerates all cached items as ``(key, value)`` pairs.
        """
        yield from self.cached.items()

    def keys(self):
        """
        Enumerates all cached keys.
        """
        yield from self.cached

    @staticmethod
    def create_cache(name):
        """
        Creates a new cache and registers it under *name*.

        :param name: name
        :return: created cache
        :raises RuntimeError: if a cache with this name already exists
        """
        # The registry is only mutated, never rebound: no `global` needed.
        if name in _caches:
            raise RuntimeError(  # pragma: no cover
                "cache '{0}' already exists.".format(name))
        cache = MLCache(name)
        _caches[name] = cache
        return cache

    @staticmethod
    def remove_cache(name):
        """
        Removes a cache with a given name.

        :param name: name
        :raises KeyError: if no cache has this name
        """
        del _caches[name]

    @staticmethod
    def get_cache(name):
        """
        Gets a cache with a given name.

        :param name: name
        :return: the registered cache
        :raises KeyError: if no cache has this name
        """
        return _caches[name]

    @staticmethod
    def has_cache(name):
        """
        Tells if cache *name* is present.

        :param name: name
        :return: boolean
        """
        return name in _caches