Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Useful plots. 

4""" 

5 

6import numpy 

7import matplotlib 

8import matplotlib.pyplot as plt 

9from matplotlib.colors import LogNorm 

10 

11 

def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel=None, **kwargs):
    """
    Creates a heatmap from a 2D array and two lists of labels.
    @see fn plot_benchmark_metrics for an example.

    @param data a 2D array-like of shape (N, M); it is converted with
        `numpy.asanyarray` so nested lists are accepted too.
    @param row_labels a list or array of length N with the labels for the rows.
    @param col_labels a list or array of length M with the labels for the columns.
    @param ax a `matplotlib.axes.Axes` instance to which the heatmap is plotted,
        if not provided, use current axes or create a new one. Optional.
    @param cbar_kw a dictionary with arguments to `matplotlib.Figure.colorbar
        <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.colorbar.html>`_.
        Optional.
    @param cbarlabel the label for the colorbar. Optional.
    @param kwargs all other arguments are forwarded to `imshow
        <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.imshow.html>`_
    @return ax, image, color bar
    """
    if ax is None:
        ax = plt.gca()  # pragma: no cover

    # imshow accepts nested lists but `.shape` (used for the ticks below)
    # does not exist on them; asanyarray keeps masked arrays intact.
    data = numpy.asanyarray(data)
    n_rows, n_cols = data.shape[0], data.shape[1]

    # Plot the heatmap.
    im = ax.imshow(data, **kwargs)

    # Create the colorbar.
    if cbar_kw is None:
        cbar_kw = {}
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    if cbarlabel is not None:
        cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # Show one tick per cell and label them with the given labels.
    ax.set_xticks(numpy.arange(n_cols))
    ax.set_yticks(numpy.arange(n_rows))
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Let the horizontal axis labels appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Turn spines off and draw a white grid between cells using minor
    # ticks placed on the cell borders (half a cell off the major ticks).
    for spine in ax.spines.values():
        spine.set_visible(False)

    ax.set_xticks(numpy.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(numpy.arange(n_rows + 1) - .5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)
    return ax, im, cbar

69 

70 

def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "black"),
                     threshold=None, **textkw):
    """
    Annotates a heatmap.
    @see fn plot_benchmark_metrics for an example.

    @param im the *AxesImage* to be labeled.
    @param data data used to annotate (a list of lists or a 2D numpy array).
        Any other value (e.g. None) means the image's data is used. Optional.
    @param valfmt the format of the annotations inside the heatmap. This should either
        use the string format method, e.g. `"$ {x:.2f}"`, or be a
        `matplotlib.ticker.Formatter
        <https://matplotlib.org/api/ticker_api.html>`_. Optional.
    @param textcolors a list or array of two color specifications. The first is used for
        values below a threshold, the second for those above. Optional.
    @param threshold value in data units according to which the colors from textcolors are
        applied. If None (the default), half of the normalized maximum of
        the data is used as separation. Optional.
    @param textkw all other arguments are forwarded to each call to `text` used to create
        the text labels.
    @return list of the created text objects
    """
    if not isinstance(data, (list, numpy.ndarray)):
        data = im.get_array()
    # Lists are accepted above but support neither `.shape`, `.max()` nor
    # `data[i, j]`; asanyarray converts them while keeping masked arrays.
    data = numpy.asanyarray(data)

    # The threshold is compared against normalized values (im.norm).
    if threshold is not None:
        threshold = im.norm(threshold)  # pragma: no cover
    else:
        threshold = im.norm(data.max()) / 2.

    kw = dict(horizontalalignment="center", verticalalignment="center")
    kw.update(textkw)

    # Get the formatter in case a string is supplied.
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            # Image column j is the x coordinate, row i the y coordinate.
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts

115 

116 

def plot_benchmark_metrics(metric, xlabel=None, ylabel=None,
                           middle=1., transpose=False, ax=None,
                           cbar_kw=None, cbarlabel=None,
                           valfmt="{x:.2f}x"):
    """
    Plots a heatmap which represents a benchmark.
    See example below.

    @param metric dictionary ``{ (x,y): value }``, values are displayed
        with a logarithmic norm so they should be strictly positive;
        pairs missing from the dictionary are left at 0
    @param xlabel x label
    @param ylabel y label
    @param middle value mapped to the centre of the colormap (white for
        ``bwr``): the color range is widened so that it is symmetric around
        this value in log scale; None disables the adjustment
    @param transpose switches *x* and *y*
    @param ax axis to borrow
    @param cbar_kw a dictionary with arguments to `matplotlib.Figure.colorbar
        <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.colorbar.html>`_.
        Optional.
    @param cbarlabel the label for the colorbar. Optional.
    @param valfmt format for the annotations
    @return ax, colorbar

    .. exref::
        :title: Plot benchmark improvements
        :lid: plot-2d-benchmark-metric

        .. plot::

            import matplotlib.pyplot as plt
            from mlprodict.plotting.plotting_benchmark import plot_benchmark_metrics

            data = {(1, 1): 0.1, (10, 1): 1, (1, 10): 2,
                    (10, 10): 100, (100, 1): 100, (100, 10): 1000}

            fig, ax = plt.subplots(1, 2, figsize=(10, 4))
            plot_benchmark_metrics(data, ax=ax[0], cbar_kw={'shrink': 0.6})
            plot_benchmark_metrics(data, ax=ax[1], transpose=True,
                                   xlabel='X', ylabel='Y',
                                   cbarlabel="ratio")
            plt.show()
    """
    if transpose:
        metric = {(k[1], k[0]): v for k, v in metric.items()}
        # Forward every option (valfmt was previously dropped here).
        return plot_benchmark_metrics(metric, ax=ax, xlabel=ylabel, ylabel=xlabel,
                                      middle=middle, transpose=False,
                                      cbar_kw=cbar_kw, cbarlabel=cbarlabel,
                                      valfmt=valfmt)

    # Sorted distinct coordinates along each axis and their grid index.
    x = numpy.array(sorted({k[0] for k in metric}))
    y = numpy.array(sorted({k[1] for k in metric}))
    rx = {v: i for i, v in enumerate(x)}
    ry = {v: i for i, v in enumerate(y)}

    # Rows follow y, columns follow x.
    zm = numpy.zeros((len(y), len(x)), dtype=numpy.float64)
    for k, v in metric.items():
        zm[ry[k[1]], rx[k[0]]] = v

    xs = [str(_) for _ in x]
    ys = [str(_) for _ in y]
    vmin = min(metric.values())
    vmax = max(metric.values())
    if middle is not None:
        # A LogNorm maps `middle` to the centre of the colormap only when
        # vmin * vmax == middle ** 2: mirror each extreme around `middle`
        # in log space and keep the widest range.
        square = middle ** 2
        low = square / vmax
        high = square / vmin
        vmin = min(vmin, low)
        vmax = max(vmax, high)
    ax, im, cbar = heatmap(zm, ys, xs, ax=ax, cmap="bwr",
                           norm=LogNorm(vmin=vmin, vmax=vmax),
                           cbarlabel=cbarlabel, cbar_kw=cbar_kw)
    annotate_heatmap(im, valfmt=valfmt)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    return ax, cbar
190 return ax, cbar