Source code for mlprodict.plotting.plotting_benchmark

"""
Useful plots.


:githublink:`%|py|5`
"""

import numpy
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm


def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel=None, **kwargs):
    """
    Creates a heatmap from a numpy array and two lists of labels.
    See :func:`plot_benchmark_metrics
    <mlprodict.plotting.plotting_benchmark.plot_benchmark_metrics>`
    for an example.

    :param data: a 2D numpy array of shape (N, M), any array-like
        accepted by *imshow* (e.g. nested lists) works as well
    :param row_labels: a list or array of length N with the
        labels for the rows.
    :param col_labels: a list or array of length M with the
        labels for the columns.
    :param ax: a `matplotlib.axes.Axes` instance to which the
        heatmap is plotted, if not provided, use current axes or
        create a new one. Optional.
    :param cbar_kw: a dictionary with arguments to
        `matplotlib.Figure.colorbar
        <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.colorbar.html>`_.
        Optional.
    :param cbarlabel: the label for the colorbar. Optional.
    :param kwargs: all other arguments are forwarded to `imshow
        <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.imshow.html>`_
    :return: ax, image, color bar


    :githublink:`%|py|30`
    """
    if ax is None:
        ax = plt.gca()  # pragma: no cover

    # Plot the heatmap
    im = ax.imshow(data, **kwargs)

    # Create colorbar
    if cbar_kw is None:
        cbar_kw = {}
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    if cbarlabel is not None:
        cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # numpy.shape also handles array-likes without a ``shape`` attribute
    # and leaves the object given to imshow untouched.
    n_rows, n_cols = numpy.shape(data)[:2]

    # We want to show all ticks...
    ax.set_xticks(numpy.arange(n_cols))
    ax.set_yticks(numpy.arange(n_rows))
    # ... and label them with the respective list entries.
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Turn spines off and create white grid.
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Minor ticks sit on the cell borders, the white grid drawn on them
    # visually separates the cells.
    ax.set_xticks(numpy.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(numpy.arange(n_rows + 1) - .5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return ax, im, cbar
def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "black"),
                     threshold=None, **textkw):
    """
    Annotates a heatmap.
    See :func:`plot_benchmark_metrics
    <mlprodict.plotting.plotting_benchmark.plot_benchmark_metrics>`
    for an example.

    :param im: the *AxesImage* to be labeled.
    :param data: data used to annotate (a 2D list or numpy array).
        If None, the image's data is used. Optional.
    :param valfmt: the format of the annotations inside the heatmap.
        This should either use the string format method, e.g.
        `"$ {x:.2f}"`, or be a `matplotlib.ticker.Formatter
        <https://matplotlib.org/api/ticker_api.html>`_
        (it is called as ``valfmt(value, None)``). Optional.
    :param textcolors: a list or array of two color specifications.
        The first is used for values below a threshold,
        the second for those above. Optional.
    :param threshold: value in data units according to which the colors
        from textcolors are applied. If None (the default) uses half of
        the normalized maximum of the data as separation. Optional.
    :param textkw: all other arguments are forwarded to each call
        to `text` used to create the text labels.
    :return: annotated objects


    :githublink:`%|py|92`
    """
    if isinstance(data, (list, numpy.ndarray)):
        # A list is explicitly accepted: convert it so that ``max``,
        # ``shape`` and 2D indexing below work (arrays are left as is).
        data = numpy.asanyarray(data)
    else:
        data = im.get_array()

    # Both sides of the comparison below live in normalized units.
    if threshold is not None:
        threshold = im.norm(threshold)  # pragma: no cover
    else:
        threshold = im.norm(data.max()) / 2.

    # Default alignment, can be overwritten by textkw.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        # Explicit import: ``import matplotlib`` alone does not guarantee
        # that the ``ticker`` submodule is loaded.
        from matplotlib.ticker import StrMethodFormatter
        valfmt = StrMethodFormatter(valfmt)

    # One text per cell, the color depends on which side of the
    # threshold the normalized value falls.
    texts = []
    n_rows, n_cols = data.shape[0], data.shape[1]
    for i in range(n_rows):
        for j in range(n_cols):
            value = data[i, j]
            kw.update(color=textcolors[int(im.norm(value) > threshold)])
            texts.append(im.axes.text(j, i, valfmt(value, None), **kw))

    return texts
def plot_benchmark_metrics(metric, xlabel=None, ylabel=None,
                           middle=1., transpose=False, ax=None,
                           cbar_kw=None, cbarlabel=None,
                           valfmt="{x:.2f}x"):
    """
    Plots a heatmap which represents a benchmark.
    See example below.

    :param metric: dictionary ``{ (x,y): value }``
    :param xlabel: x label
    :param ylabel: y label
    :param middle: force the white color to be this value,
        None to let the color scale span the observed values only
    :param transpose: switches *x* and *y*
    :param ax: axis to borrow
    :param cbar_kw: a dictionary with arguments to
        `matplotlib.Figure.colorbar
        <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.colorbar.html>`_.
        Optional.
    :param cbarlabel: the label for the colorbar. Optional.
    :param valfmt: format for the annotations
    :return: ax, colorbar

    .. exref::
        :title: Plot benchmark improvements
        :lid: plot-2d-benchmark-metric

        .. plot::

            import matplotlib.pyplot as plt
            from mlprodict.plotting.plotting_benchmark import plot_benchmark_metrics

            data = {(1, 1): 0.1, (10, 1): 1, (1, 10): 2,
                    (10, 10): 100, (100, 1): 100, (100, 10): 1000}

            fig, ax = plt.subplots(1, 2, figsize=(10, 4))
            plot_benchmark_metrics(data, ax=ax[0], cbar_kw={'shrink': 0.6})
            plot_benchmark_metrics(data, ax=ax[1], transpose=True,
                                   xlabel='X', ylabel='Y',
                                   cbarlabel="ratio")
            plt.show()


    :githublink:`%|py|156`
    """
    if transpose:
        # Swaps the keys and the labels, then draws the regular way.
        # valfmt must be forwarded or a custom format is lost.
        swapped = {(k[1], k[0]): v for k, v in metric.items()}
        return plot_benchmark_metrics(
            swapped, ax=ax, xlabel=ylabel, ylabel=xlabel,
            middle=middle, transpose=False, cbar_kw=cbar_kw,
            cbarlabel=cbarlabel, valfmt=valfmt)

    # Sorted distinct coordinates and their position on each axis.
    x = numpy.array(sorted(set(k[0] for k in metric)))
    y = numpy.array(sorted(set(k[1] for k in metric)))
    rx = {v: i for i, v in enumerate(x)}
    ry = {v: i for i, v in enumerate(y)}

    # Rows follow y, columns follow x; pairs absent from metric stay at 0.
    zm = numpy.zeros((len(y), len(x)), dtype=numpy.float64)
    for k, v in metric.items():
        zm[ry[k[1]], rx[k[0]]] = v

    xs = [str(_) for _ in x]
    ys = [str(_) for _ in y]
    vmin = min(metric.values())
    vmax = max(metric.values())
    if middle is not None:
        # On a log scale the center of the colormap is the geometric mean
        # sqrt(vmin * vmax): the range is widened on one side so that this
        # mean equals middle (same bounds as before when middle == 1).
        upper = middle ** 2 / vmin
        lower = middle ** 2 / vmax
        vmin = min(vmin, lower)
        vmax = max(vmax, upper)

    ax, im, cbar = heatmap(zm, ys, xs, ax=ax, cmap="bwr",
                           norm=LogNorm(vmin=vmin, vmax=vmax),
                           cbarlabel=cbarlabel, cbar_kw=cbar_kw)
    annotate_heatmap(im, valfmt=valfmt)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    return ax, cbar