Source code for pyquickhelper.benchhelper.grid_benchmark

# -*- coding: utf-8 -*-
"""
Grid benchmark.


:githublink:`%|py|6`
"""

from time import perf_counter
from ..loghelper import noLOG
from .benchmark import BenchMark


[docs]class GridBenchMark(BenchMark): """ Compares a couple of machine learning models. :githublink:`%|py|15` """
[docs] def __init__(self, name, datasets, clog=None, fLOG=noLOG, path_to_images=".", cache_file=None, repetition=1, progressbar=None, **params): """ :param name: name of the test :param datasets: list of dictionary of dataframes :param clog: see :class:`CustomLog <pyquickhelper.loghelper.custom_log.CustomLog>` or string :param fLOG: logging function :param params: extra parameters :param path_to_images: path to images :param cache_file: cache file :param repetition: repetition of the experiment (to get confidence interval) :param progressbar: relies on *tqdm*, example *tnrange* If *cache_file* is specified, the class will store the results of the method :meth:`bench <pyquickhelper.benchhelper.benchmark.GridBenchMark.bench>`. On a second run, the function load the cache and run modified or new run (in *param_list*). *datasets* should be a dictionary with dataframes a values with the following keys: * ``'X'``: features * ``'Y'``: labels (optional) :githublink:`%|py|40` """ BenchMark.__init__(self, name=name, datasets=datasets, clog=clog, fLOG=fLOG, path_to_images=path_to_images, cache_file=cache_file, progressbar=progressbar, **params) if not isinstance(datasets, list): raise TypeError("datasets must be a list") # pragma: no cover for i, df in enumerate(datasets): if not isinstance(df, dict): raise TypeError( # pragma: no cover "Every dataset must be a dictionary, {0}th is not.".format(i)) if "X" not in df: raise KeyError( # pragma: no cover "Dictionary {0} should contain key 'X'.".format(i)) if "di" in df: raise KeyError( # pragma: no cover "Dictionary {0} should not contain key 'di'.".format(i)) if "name" not in df: raise KeyError( # pragma: no cover "Dictionary {0} should not contain key 'name'.".format(i)) if "shortname" not in df: raise KeyError( # pragma: no cover "Dictionary {0} should not contain key 'shortname'.".format(i)) self._datasets = datasets self._repetition = repetition
[docs] def init_main(self): """ initialisation :githublink:`%|py|70` """ skip = {"X", "Y", "weight", "name", "shortname"} self.fLOG("[MlGridBenchmark.init] begin") self._datasets_info = [] self._results = [] for i, dd in enumerate(self._datasets): X = dd["X"] N = X.shape[0] Nc = X.shape[1] info = dict(Nrows=N, Nfeat=Nc) for k, v in dd.items(): if k not in skip: info[k] = v self.fLOG( "[MlGridBenchmark.init] dataset {0}: {1}".format(i, info)) self._datasets_info.append(info) self.fLOG("[MlGridBenchmark.init] end")
[docs] def init(self): """ Skips it. :githublink:`%|py|92` """ pass # pragma: no cover
[docs] def run(self, params_list): """ Runs the benchmark. :githublink:`%|py|98` """ self.init_main() self.fLOG("[MlGridBenchmark.bench] start") self.fLOG("[MlGridBenchmark.bench] number of datasets", len(self._datasets)) self.fLOG("[MlGridBenchmark.bench] number of experiments", len(params_list)) unique = set() for i, pars in enumerate(params_list): if "name" not in pars: raise KeyError( # pragma: no cover "Dictionary {0} must contain key 'name'.".format(i)) if "shortname" not in pars: raise KeyError( # pragma: no cover "Dictionary {0} must contain key 'shortname'.".format(i)) if pars["name"] in unique: raise ValueError( # pragma: no cover "'{0}' is duplicated.".format(pars["name"])) unique.add(pars["name"]) if pars["shortname"] in unique: raise ValueError( # pragma: no cover "'{0}' is duplicated.".format(pars["shortname"])) unique.add(pars["shortname"]) # Multiplies the experiments. full_list = [] for i in range(len(self._datasets)): for pars in params_list: pc = pars.copy() pc["di"] = i full_list.append(pc) # Runs the bench res = BenchMark.run(self, full_list) self.fLOG("[MlGridBenchmark.bench] end") return res
[docs] def bench(self, **params): """ Runs an experiment multiple times, parameter *di* is the dataset to use. :githublink:`%|py|141` """ if "di" not in params: raise KeyError( "key 'di' is missing from params") # pragma: no cover results = [] for iexp in range(self._repetition): di = params["di"] shortname_model = params["shortname"] name_model = params["name"] shortname_ds = self._datasets[di]["shortname"] name_ds = self._datasets[di]["name"] cl = perf_counter() ds, appe, pars = self.preprocess_dataset(di, **params) split = perf_counter() - cl cl = perf_counter() output = self.bench_experiment(ds, **pars) train = perf_counter() - cl cl = perf_counter() metrics, appe_ = self.predict_score_experiment(ds, output) test = perf_counter() - cl metrics["time_preproc"] = split metrics["time_train"] = train metrics["time_test"] = test metrics["_btry"] = "{0}-{1}".format(shortname_model, shortname_ds) metrics["_iexp"] = iexp metrics["model_name"] = name_model metrics["ds_name"] = name_ds appe.update(appe_) appe["_iexp"] = iexp metrics.update(appe) appe["_btry"] = metrics["_btry"] if "_i" in metrics: del metrics["_i"] results.append((metrics, appe)) return results
[docs] def preprocess_dataset(self, dsi, **params): """ Splits the dataset into train and test. :param dsi: dataset index :param params: additional parameters :return: list of (dataset (like info), dictionary for metrics, parameters) :githublink:`%|py|192` """ ds = self._datasets[dsi] appe = self._datasets_info[dsi].copy() params = params.copy() if "di" in params: del params["di"] return ds, appe, params
[docs] def bench_experiment(self, info, **params): """ function to overload :param info: dictionary with at least key ``'X'`` :param params: additional parameters :return: output of the experiment :githublink:`%|py|207` """ raise NotImplementedError() # pragma: no cover
[docs] def predict_score_experiment(self, info, output, **params): """ function to overload :param info: dictionary with at least key ``'X'`` :param output: output of the benchmar :param params: additional parameters :return: output of the experiment, tuple of dictionaries :githublink:`%|py|218` """ raise NotImplementedError() # pragma: no cover