Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Plotting for benchmarks. 

4""" 

5from .plot_helper import list_col_options, filter_df_options, options2label 

6from .plot_helper import ax_position, plt_colors, plt_styles, move_color, move_color_add 

7from ..benchmark.bench_helper import remove_almost_nan_columns 

8 

9 

def plot_bench_results(df, row_cols=None, col_cols=None, hue_cols=None,  # pylint: disable=R0914
                       cmp_col_values=('lib', 'skl'),
                       x_value='N', y_value='mean',
                       err_value=('lower', 'upper'),
                       title=None, box_side=6, labelsize=8,
                       fontsize="small", label_fct=None,
                       color_fct=None, ax=None):
    """
    Plots benchmark results.

    @param      df              benchmark results
    @param      row_cols        dataframe columns for graph rows
    @param      col_cols        dataframe columns for graph columns
    @param      hue_cols        dataframe columns for other options
    @param      cmp_col_values  it can be one column or one tuple ``(column, baseline name)``,
                                the curve whose name equals the baseline is drawn thicker
    @param      x_value         value for x-axis
    @param      y_value         value to plot on y-axis (such as *mean*, *min*, ...)
    @param      err_value       lower and upper bounds, None to skip the error band
    @param      title           graph title
    @param      box_side        graph side, the function adjusts the size of the graph
    @param      labelsize       size of the labels
    @param      fontsize        font size see `Text properties
                                <https://matplotlib.org/api/text_api.html#matplotlib.text.Text>`_
    @param      label_fct       if not None, it is a function which
                                modifies the label before printing it on the graph
    @param      color_fct       if not None, it is a function which modifies
                                a color based on the label and the previous color
    @param      ax              existing axis
    @return                     ax, the axes the curves were drawn on

    The function raises *ValueError* if *cmp_col_values* is None or if
    a pivot cannot be computed, *RuntimeError* if *ax* does not have
    the expected shape or if every graph is empty.

    .. exref::
        :title: Plot benchmark results

        .. plot::

            from pymlbenchmark.datasets import experiment_results
            from pymlbenchmark.plotting import plot_bench_results
            import matplotlib.pyplot as plt

            df = experiment_results('onnxruntime_LogisticRegression')

            plot_bench_results(df, row_cols='N', col_cols='method',
                               x_value='dim', hue_cols='fit_intercept',
                               title="LogisticRegression\\nBenchmark scikit-learn / onnxruntime")
            plt.show()
    """
    if cmp_col_values is None:
        raise ValueError(  # pragma: no cover
            "cmp_col_values cannot be None.")

    # cmp_col_values is either a column name or (column, baseline name).
    # Indexing a plain column name with [1] would pick its second
    # character, so the baseline is only read from a tuple.
    if isinstance(cmp_col_values, tuple):
        cmp_col = cmp_col_values[0]
        baseline = cmp_col_values[1] if len(cmp_col_values) > 1 else None
    else:
        cmp_col = cmp_col_values
        baseline = None

    # Pivot specifications (index, columns, values) do not depend on the
    # subplot: they are computed once. The bounds stay None when no
    # error band is requested.
    y_cols = [x_value, cmp_col, y_value]
    if err_value is None:
        lower_cols = None
        upper_cols = None
    else:
        lower_cols = [x_value, cmp_col, err_value[0]]
        upper_cols = [x_value, cmp_col, err_value[1]]

    if label_fct is None:
        def label_fct_(x):
            return x
        label_fct = label_fct_

    if color_fct is None:
        def color_fct_(la, col):
            return col
        color_fct = color_fct_

    import matplotlib.pyplot as plt  # pylint: disable=C0415
    if not isinstance(row_cols, (tuple, list)):
        row_cols = [row_cols]
    if not isinstance(col_cols, (tuple, list)):
        col_cols = [col_cols]
    if not isinstance(hue_cols, (tuple, list)):
        hue_cols = [hue_cols]

    all_cols = {
        c for c in (list(row_cols) + list(col_cols) + list(hue_cols))
        if c is not None}
    df = remove_almost_nan_columns(df, keep=all_cols)
    lrows_options = list_col_options(df, row_cols)
    lcols_options = list_col_options(df, col_cols)
    lhues_options = list_col_options(df, hue_cols)

    shape = (len(lrows_options), len(lcols_options))
    if ax is None:
        figsize = (shape[1] * box_side, shape[0] * box_side)
        fig, ax = plt.subplots(shape[0], shape[1], figsize=figsize)
    elif ax.shape != shape:
        raise RuntimeError(  # pragma: no cover
            "Shape mismatch ax.shape={} when expected values is {}".format(
                ax.shape, shape))
    else:
        fig = plt.gcf()
    colors = plt_colors()
    styles = plt_styles()

    nb_empty = 0
    nb_total = 0
    for row, row_opt in enumerate(lrows_options):

        sub = filter_df_options(df, row_opt)
        nb_total += 1
        if sub.shape[0] == 0:
            nb_empty += 1
            continue
        legy = options2label(row_opt, sep="\n")

        for col, col_opt in enumerate(lcols_options):
            sub2 = filter_df_options(sub, col_opt)
            if sub2.shape[0] == 0:
                continue
            legx = options2label(col_opt, sep="\n")

            pos = ax_position(shape, (row, col))
            a = ax[pos] if pos else ax

            # zip stops at the shortest sequence: hue options beyond
            # the number of available colors are not drawn.
            for hue_color, hue_opt in zip(colors, lhues_options):
                ds = filter_df_options(sub2, hue_opt)
                if ds.shape[0] == 0:
                    continue
                legh = options2label(hue_opt, sep="\n")

                piv = _pivot_or_raise(ds, y_cols, available=list(df.columns))
                lower_piv = (None if lower_cols is None
                             else _pivot_or_raise(ds, lower_cols))
                upper_piv = (None if upper_cols is None
                             else _pivot_or_raise(ds, upper_cols))

                # Curve names must be read before reset_index adds
                # x_value to the columns.
                ys = list(piv.columns)

                piv = piv.reset_index(drop=False)
                if upper_piv is not None:
                    upper_piv = upper_piv.reset_index(drop=False)
                if lower_piv is not None:
                    lower_piv = lower_piv.reset_index(drop=False)

                # First pass: error bands, drawn before the curves.
                for i, ly in enumerate(ys):
                    # Without hue option, every curve gets its own color.
                    color = colors[i] if hue_opt is None else hue_color
                    la = "{}-{}".format(ly, legh) if legh != '-' else ly
                    color_ = color_fct(la, color)
                    if upper_piv is not None and lower_piv is not None:
                        a.fill_between(piv[x_value], lower_piv[ly], upper_piv[ly],
                                       color=color_, alpha=0.1)
                    elif upper_piv is not None:
                        a.fill_between(piv[x_value], piv[ly], upper_piv[ly],
                                       color=color_, alpha=0.1)
                    elif lower_piv is not None:
                        a.fill_between(piv[x_value], lower_piv[ly], piv[ly],
                                       color=color_, alpha=0.1)

                # Second pass: the curves themselves, one style per curve.
                for i, (ly, style) in enumerate(zip(ys, styles)):
                    color = colors[i] if hue_opt is None else hue_color
                    la = "{}-{}".format(ly, legh) if legh != '-' else ly
                    color_ = color_fct(la, color)
                    # The baseline is emphasized with a thicker line.
                    lw = 4. if ly == baseline else 1.5
                    ms = lw * 3
                    nc_add = move_color_add(style[0])
                    nc = move_color(color_, nc_add)
                    piv.plot(x=x_value, y=ly, ax=a, marker=style[0],
                             style=style[1], logx=True, logy=True,
                             c=nc, lw=lw, ms=ms,
                             label=label_fct(la))

            a.legend(loc=0, fontsize=fontsize)
            # x labels only on the last row, y labels only on the first column.
            a.set_xlabel(label_fct("{}\n{}".format(x_value, legx))
                         if row == shape[0] - 1 else "",
                         fontsize=fontsize)
            a.set_ylabel(label_fct("{}\n{}".format(legy, y_value))
                         if col == 0 else "", fontsize=fontsize)
            if row == 0:
                a.set_title(legx, fontsize=fontsize)
            a.tick_params(labelsize=labelsize)
            for tick in a.yaxis.get_majorticklabels():
                tick.set_fontsize(labelsize)
            for tick in a.xaxis.get_majorticklabels():
                tick.set_fontsize(labelsize)
            plt.setp(a.get_xminorticklabels(), visible=False)
            plt.setp(a.get_yminorticklabels(), visible=False)

    if nb_empty == nb_total:  # pragma no cover
        raise RuntimeError("All graphs are empty for dataframe,\nrow_cols={},"
                           "\ncol_cols={},\nhue_cols={},\ncolumns={}".format(
                               row_cols, col_cols, hue_cols, df.columns))
    if title is not None:
        fig.suptitle(title, fontsize=labelsize)
    return ax


def _pivot_or_raise(ds, cols, available=None):
    """
    Pivots *ds* and converts failures into a *ValueError*
    which tells which columns were involved.

    @param      ds          dataframe to pivot
    @param      cols        list ``[index, columns, values]``
    @param      available   if not None, list of available columns
                            added to the error message
    @return                 pivoted dataframe
    """
    index, columns, values = cols
    try:
        # Keyword arguments: positional ones are rejected by pandas >= 2.0.
        return ds.pivot(index=index, columns=columns, values=values)
    except KeyError as e:  # pragma no cover
        raise ValueError(
            "Unable to find columns {} in {}".format(
                cols, ds.columns)) from e
    except ValueError as e:  # pragma no cover
        if available is None:
            raise ValueError("Unable to compute a pivot on columns {}\n{}".format(
                cols, ds[cols].head())) from e
        raise ValueError("Unable to compute a pivot on columns {}\nAvailable: {}\n{}".format(
            cols, available, ds[cols].head(n=10))) from e
107 if sub2.shape[0] == 0: 

108 continue 

109 legx = options2label(col_opt, sep="\n") 

110 

111 pos = ax_position(shape, (row, col)) 

112 a = ax[pos] if pos else ax 

113 

114 for color, hue_opt in zip(colors, lhues_options): 

115 ds = filter_df_options(sub2, hue_opt) 

116 if ds.shape[0] == 0: 

117 continue 

118 legh = options2label(hue_opt, sep="\n") 

119 

120 if isinstance(cmp_col_values, tuple): 

121 y_cols = [x_value, cmp_col_values[0], y_value] 

122 if err_value is not None: 

123 lower_cols = [x_value, cmp_col_values[0], err_value[0]] 

124 upper_cols = [x_value, cmp_col_values[0], err_value[1]] 

125 elif cmp_col_values is not None: 

126 y_cols = [x_value, cmp_col_values, y_value] 

127 if err_value is not None: 

128 lower_cols = [x_value, cmp_col_values, err_value[0]] 

129 upper_cols = [x_value, cmp_col_values, err_value[1]] 

130 else: 

131 raise ValueError( # pragma: no cover 

132 "cmp_col_values cannot be None.") 

133 

134 try: 

135 piv = ds.pivot(*y_cols) 

136 except ValueError as e: # pragma no cover 

137 raise ValueError("Unable to compute a pivot on columns {}\nAvailable: {}\n{}".format( 

138 y_cols, list(df.columns), ds[y_cols].head(n=10))) from e 

139 except KeyError as e: # pragma no cover 

140 raise ValueError( 

141 "Unable to find columns {} in {}".format( 

142 y_cols, ds.columns)) from e 

143 if lower_cols is not None: 

144 try: 

145 lower_piv = ds.pivot(*lower_cols) 

146 except ValueError as e: # pragma no cover 

147 raise ValueError("Unable to compute a pivot on columns {}\n{}".format( 

148 lower_cols, ds[lower_cols].head())) from e 

149 except KeyError as e: # pragma no cover 

150 raise ValueError("Unable to find columns {} in {}".format( 

151 lower_cols, ds.columns)) from e 

152 else: 

153 lower_piv = None 

154 if upper_cols is not None: 

155 try: 

156 upper_piv = ds.pivot(*upper_cols) 

157 except ValueError as e: # pragma no cover 

158 raise ValueError("Unable to compute a pivot on columns {}\n{}".format( 

159 upper_cols, ds[upper_cols].head())) from e 

160 except KeyError as e: # pragma no cover 

161 raise ValueError("Unable to find columns {} in {}".format( 

162 upper_cols, ds.columns)) from e 

163 else: 

164 upper_piv = None 

165 ys = list(piv.columns) 

166 

167 piv = piv.reset_index(drop=False) 

168 if upper_piv is not None: 

169 upper_piv = upper_piv.reset_index(drop=False) 

170 if lower_piv is not None: 

171 lower_piv = lower_piv.reset_index(drop=False) 

172 

173 for i, ly in enumerate(ys): 

174 if hue_opt is None: 

175 color = colors[i] 

176 la = "{}-{}".format(ly, legh) if legh != '-' else ly 

177 color_ = color_fct(la, color) 

178 if upper_piv is not None and lower_piv is not None: 

179 a.fill_between(piv[x_value], lower_piv[ly], upper_piv[ly], 

180 color=color_, alpha=0.1) 

181 elif upper_piv is not None: 

182 a.fill_between(piv[x_value], piv[ly], upper_piv[ly], 

183 color=color_, alpha=0.1) 

184 elif lower_piv is not None: 

185 a.fill_between(piv[x_value], lower_piv[ly], piv[ly], 

186 color=color_, alpha=0.1) 

187 

188 for i, (ly, style) in enumerate(zip(ys, styles)): 

189 if hue_opt is None: 

190 color = colors[i] 

191 la = "{}-{}".format(ly, legh) if legh != '-' else ly 

192 color_ = color_fct(la, color) 

193 lw = 4. if ly == cmp_col_values[1] else 1.5 

194 ms = lw * 3 

195 nc_add = move_color_add(style[0]) 

196 nc = move_color(color_, nc_add) 

197 piv.plot(x=x_value, y=ly, ax=a, marker=style[0], 

198 style=style[1], logx=True, logy=True, 

199 c=nc, lw=lw, ms=ms, 

200 label=label_fct(la)) 

201 

202 a.legend(loc=0, fontsize=fontsize) 

203 a.set_xlabel(label_fct("{}\n{}".format(x_value, legx)) 

204 if row == shape[0] - 1 else "", 

205 fontsize=fontsize) 

206 a.set_ylabel(label_fct("{}\n{}".format(legy, y_value)) 

207 if col == 0 else "", fontsize=fontsize) 

208 if row == 0: 

209 a.set_title(legx, fontsize=fontsize) 

210 a.tick_params(labelsize=labelsize) 

211 for tick in a.yaxis.get_majorticklabels(): 

212 tick.set_fontsize(labelsize) 

213 for tick in a.xaxis.get_majorticklabels(): 

214 tick.set_fontsize(labelsize) 

215 plt.setp(a.get_xminorticklabels(), visible=False) 

216 plt.setp(a.get_yminorticklabels(), visible=False) 

217 

218 if nb_empty == nb_total: # pragma no cover 

219 raise RuntimeError("All graphs are empty for dataframe,\nrow_cols={}," 

220 "\ncol_cols={},\nhue_cols={},\ncolumns={}".format( 

221 row_cols, col_cols, hue_cols, df.columns)) 

222 if title is not None: 

223 fig.suptitle(title, fontsize=labelsize) 

224 return ax