Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Useful plots.
4"""
6import numpy
7import matplotlib
8import matplotlib.pyplot as plt
9from matplotlib.colors import LogNorm
def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel=None, **kwargs):
    """
    Creates a heatmap from a 2D array and two lists of labels.
    See @see fn plot_benchmark_metrics for an example.

    @param data a 2D numpy array (or nested sequence accepted by
        ``numpy.shape``) of shape (N, M).
    @param row_labels a list or array of length N with the labels for the rows.
    @param col_labels a list or array of length M with the labels for the columns.
    @param ax a `matplotlib.axes.Axes` instance to which the heatmap is plotted,
        if not provided, use current axes or create a new one. Optional.
    @param cbar_kw a dictionary with arguments to `matplotlib.Figure.colorbar
        <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.colorbar.html>`_.
        Optional.
    @param cbarlabel the label for the colorbar. Optional.
    @param kwargs all other arguments are forwarded to `imshow
        <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.imshow.html>`_
    @return ax, image, color bar
    """
    if ax is None:
        ax = plt.gca()  # pragma: no cover

    # Plot the heatmap.
    im = ax.imshow(data, **kwargs)
    # imshow accepts nested lists too, so the shape is read through numpy
    # instead of relying on a ``.shape`` attribute.
    n_rows, n_cols = numpy.shape(data)[:2]

    # Create the colorbar.
    if cbar_kw is None:
        cbar_kw = {}
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    if cbarlabel is not None:
        cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # Show every tick and label it with the matching list entry.
    ax.set_xticks(numpy.arange(n_cols))
    ax.set_yticks(numpy.arange(n_rows))
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Let the horizontal axis labels appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Turn spines off and draw a white grid between the cells using minor
    # ticks placed on the cell borders (offset by half a cell).
    for spine in ax.spines.values():
        spine.set_visible(False)

    ax.set_xticks(numpy.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(numpy.arange(n_rows + 1) - .5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)
    return ax, im, cbar
def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "black"),
                     threshold=None, **textkw):
    """
    Annotates a heatmap.
    See @see fn plot_benchmark_metrics for an example.

    @param im the *AxesImage* to be labeled.
    @param data data used to annotate: a 2D numpy array or a nested list
        (converted with ``numpy.asarray``). Any other value, including None,
        means the image's data is used. Optional.
    @param valfmt the format of the annotations inside the heatmap. This should either
        use the string format method, e.g. `"$ {x:.2f}"`, or be a
        `matplotlib.ticker.Formatter
        <https://matplotlib.org/api/ticker_api.html>`_ (any callable
        ``valfmt(value, pos)`` returning a string). Optional.
    @param textcolors a list or array of two color specifications. The first is used for
        values below a threshold, the second for those above. Optional.
    @param threshold value in data units according to which the colors from textcolors are
        applied. If None (the default), half of the normalized maximum of the
        data is used as separation. Optional.
    @param textkw all other arguments are forwarded to each call to `text` used to create
        the text labels.
    @return list of the created text objects (empty if the data is empty)
    """
    if isinstance(data, list):
        # Lists are accepted but the code below needs ``.shape``, ``.max()``
        # and ``[i, j]`` indexing, so convert them to an array first.
        data = numpy.asarray(data)
    elif not isinstance(data, numpy.ndarray):
        data = im.get_array()
    if numpy.size(data) == 0:
        # Nothing to annotate; also avoids ``max()`` on an empty array.
        return []

    if threshold is not None:
        threshold = im.norm(threshold)  # pragma: no cover
    else:
        threshold = im.norm(data.max()) / 2.

    kw = dict(horizontalalignment="center", verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied.
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    texts = []
    n_rows, n_cols = data.shape[:2]
    for i in range(n_rows):
        for j in range(n_cols):
            value = data[i, j]
            # Pick the second color when the normalized value exceeds the
            # (normalized) threshold.
            kw.update(color=textcolors[int(im.norm(value) > threshold)])
            texts.append(im.axes.text(j, i, valfmt(value, None), **kw))
    return texts
def _benchmark_grid(metric):
    """
    Converts a dictionary ``{(x, y): value}`` into a dense grid.

    @param metric dictionary ``{ (x,y): value }``
    @return grid of shape ``(number of y, number of x)`` filled with zeros
        where a pair is missing, x labels, y labels (sorted unique keys
        converted to strings)
    """
    x_values = numpy.array(sorted({key[0] for key in metric}))
    y_values = numpy.array(sorted({key[1] for key in metric}))
    x_index = {v: i for i, v in enumerate(x_values)}
    y_index = {v: i for i, v in enumerate(y_values)}

    grid = numpy.zeros((len(y_values), len(x_values)), dtype=numpy.float64)
    for key, value in metric.items():
        grid[y_index[key[1]], x_index[key[0]]] = value
    return grid, [str(v) for v in x_values], [str(v) for v in y_values]


def _centered_log_limits(vmin, vmax, middle):
    """
    Widens ``[vmin, vmax]`` so that *middle* is the geometric mean of the
    bounds, i.e. it sits at the center of a logarithmic color scale.

    @param vmin smallest value
    @param vmax largest value
    @param middle value to center, None to keep the bounds unchanged
    @return vmin, vmax
    """
    if middle is None:
        return vmin, vmax
    # On a log scale, the mirror image of v around *middle* is middle**2 / v.
    # For middle == 1 this reduces to 1 / v.
    square = middle ** 2
    return min(vmin, square / vmax), max(vmax, square / vmin)


def plot_benchmark_metrics(metric, xlabel=None, ylabel=None,
                           middle=1., transpose=False, ax=None,
                           cbar_kw=None, cbarlabel=None,
                           valfmt="{x:.2f}x"):
    """
    Plots a heatmap which represents a benchmark.
    See example below.

    @param metric dictionary ``{ (x,y): value }``, values must be positive
        (logarithmic color scale)
    @param xlabel x label
    @param ylabel y label
    @param middle force the white color to be this value: the color scale
        is widened so that *middle* is at its center on a log scale,
        None disables this
    @param transpose switches *x* and *y* (keys and labels)
    @param ax axis to borrow
    @param cbar_kw a dictionary with arguments to `matplotlib.Figure.colorbar
        <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.colorbar.html>`_.
        Optional.
    @param cbarlabel the label for the colorbar. Optional.
    @param valfmt format for the annotations
    @return ax, colorbar
    @raises ValueError if *metric* is empty

    .. exref::
        :title: Plot benchmark improvements
        :lid: plot-2d-benchmark-metric

        .. plot::

            import matplotlib.pyplot as plt
            from mlprodict.plotting.plotting_benchmark import plot_benchmark_metrics

            data = {(1, 1): 0.1, (10, 1): 1, (1, 10): 2,
                    (10, 10): 100, (100, 1): 100, (100, 10): 1000}

            fig, ax = plt.subplots(1, 2, figsize=(10, 4))
            plot_benchmark_metrics(data, ax=ax[0], cbar_kw={'shrink': 0.6})
            plot_benchmark_metrics(data, ax=ax[1], transpose=True,
                                   xlabel='X', ylabel='Y',
                                   cbarlabel="ratio")
            plt.show()
    """
    if not metric:
        raise ValueError("metric must contain at least one value.")
    if transpose:
        # Swap both axes, labels included; every other option (valfmt
        # among them) is kept as given.
        metric = {(k[1], k[0]): v for k, v in metric.items()}
        xlabel, ylabel = ylabel, xlabel

    zm, xs, ys = _benchmark_grid(metric)
    vmin, vmax = _centered_log_limits(
        min(metric.values()), max(metric.values()), middle)
    ax, im, cbar = heatmap(zm, ys, xs, ax=ax, cmap="bwr",
                           norm=LogNorm(vmin=vmin, vmax=vmax),
                           cbarlabel=cbarlabel, cbar_kw=cbar_kw)
    annotate_heatmap(im, valfmt=valfmt)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    return ax, cbar