Coverage for src/pymlbenchmark/plotting/plot_bench_xtime.py: 93%

109 statements  

« prev     ^ index     » next       coverage.py v7.2.1, created at 2023-03-08 00:27 +0100

1""" 

2@file 

3@brief Plotting for benchmarks. 

4""" 

5from .plot_helper import list_col_options, filter_df_options, options2label 

6from .plot_helper import ax_position, plt_colors, move_color, remove_common_prefix 

7from ..benchmark.bench_helper import remove_almost_nan_columns 

8 

9 

def plot_bench_xtime(df, row_cols=None, col_cols=None, hue_cols=None,
                     cmp_col_values=('lib', 'skl'),
                     x_value='mean', y_value='xtime',
                     parallel=(1., 0.5), title=None,
                     box_side=6, labelsize=10,
                     fontsize="small", label_fct=None,
                     color_fct=None, ax=None):
    """
    Plots benchmark acceleration.

    The function pivots *df* on column ``cmp_col_values[0]`` with
    *x_value* as the pivoted value. For every value of that column other
    than the baseline, it draws a scatter plot of the ratio
    ``value / baseline value`` (y-axis) against the baseline value
    (x-axis), both axes on a log scale. There is one graph per combination
    of *row_cols* and *col_cols*, and one series per combination of
    *hue_cols* and compared value.

    @param      df              benchmark results
    @param      row_cols        dataframe columns for graph rows
    @param      col_cols        dataframe columns for graph columns
    @param      hue_cols        dataframe columns for other options
    @param      cmp_col_values  it can be one column or one tuple
                                ``(column, baseline name)``; if only a
                                column is given, the baseline is the first
                                value (in sorted order) belonging to
                                ``{'no', 'base', 'baseline', 'skl'}``,
                                or the smallest value if none does
    @param      x_value         column holding the measured value (such as
                                *mean*, *min*, ...); the x-axis shows it
                                for the baseline
    @param      y_value         name of the ratio
                                ``value / baseline value`` plotted on the
                                y-axis
    @param      parallel        ratios drawn as horizontal black reference
                                lines (solid for 1, dashed otherwise),
                                ``None`` to draw none
    @param      title           graph title
    @param      box_side        graph side, the function adjusts the
                                size of the graph
    @param      labelsize       size of the labels
    @param      fontsize        font size see `Text properties
                                <https://matplotlib.org/api/
                                text_api.html#matplotlib.text.Text>`_
    @param      ax              existing axis, must be an array of axes
                                whose shape matches the grid of graphs
    @param      label_fct       if not None, it is a function which
                                modifies the label before printing it on the graph
    @param      color_fct       if not None, it is a function which modifies
                                a color based on the label and the previous color
    @return                     ax
    @raises     RuntimeError    if *ax* has the wrong shape, if the pivot
                                table is empty or if every graph row is empty

    .. exref::
        :title: Plot benchmark improvements

        .. plot::

            from pymlbenchmark.datasets import experiment_results
            from pymlbenchmark.plotting import plot_bench_xtime
            import matplotlib.pyplot as plt

            df = experiment_results('onnxruntime_LogisticRegression')

            plot_bench_xtime(df, row_cols='N', col_cols='method',
                             hue_cols='fit_intercept',
                             title="LogisticRegression\\nAcceleration scikit-learn / onnxruntime")
            plt.show()
    """
    if label_fct is None:
        def label_fct_(x):
            return x
        label_fct = label_fct_

    if color_fct is None:
        def color_fct_(la, col):
            return col
        color_fct = color_fct_

    import itertools  # pylint: disable=C0415
    import matplotlib.pyplot as plt  # pylint: disable=C0415

    if not isinstance(row_cols, (tuple, list)):
        row_cols = [row_cols]
    if not isinstance(col_cols, (tuple, list)):
        col_cols = [col_cols]
    if not isinstance(hue_cols, (tuple, list)):
        hue_cols = [hue_cols]

    all_cols = set(c for c in (list(row_cols) + list(col_cols) + list(hue_cols))
                   if c is not None)
    df = remove_almost_nan_columns(df, keep=all_cols)
    lrows_options = list_col_options(df, row_cols)
    lcols_options = list_col_options(df, col_cols)
    lhues_options = list_col_options(df, hue_cols)

    shape = (len(lrows_options), len(lcols_options))
    # plt.subplots squeezes a single row of axes into a 1-D array,
    # so a user-provided ax may have either shape.
    shape2 = shape if shape[0] > 1 else shape[1:]
    if ax is None:
        figsize = (shape[1] * box_side, shape[0] * box_side)
        fig, ax = plt.subplots(shape[0], shape[1], figsize=figsize)
    elif not hasattr(ax, 'shape') or ax.shape not in (shape, shape2):
        raise RuntimeError(  # pragma: no cover
            "Shape mismatch ax.shape={} when expected values is {} or {}".format(
                getattr(ax, 'shape', None), shape, shape2))
    else:
        # The title must go on the figure owning the given axes,
        # which is not necessarily the current figure.
        axes_list = list(getattr(ax, 'flat', []))
        fig = getattr(axes_list[0], 'figure', None) if axes_list else None
        if fig is None:
            fig = plt.gcf()
    colors = plt_colors()

    if isinstance(cmp_col_values, str):
        values = tuple(sorted(set(df[cmp_col_values].dropna())))
        candidates = [v for v in values if v in {
            'no', 'base', 'baseline', 'skl'}]
        bl = candidates[0] if len(candidates) > 0 else values[0]
        cmp_col_values = (cmp_col_values, bl)
    cmp_col, baseline = cmp_col_values[0], cmp_col_values[1]

    # Statistics other than x_value / y_value would otherwise end up in the
    # pivot index and split the rows.
    dropc = "lower,max,max3,mean,median,min,min3,repeat,upper".split(',')
    dropc = [c for c in dropc
             if c not in [x_value, y_value] and c in df.columns]
    df = df.drop(dropc, axis=1)
    index = [c for c in df.columns if c not in [x_value, y_value, cmp_col]]
    piv = df.pivot_table(index=index, values=x_value, columns=cmp_col)
    piv = piv.reset_index(drop=False)
    if piv.shape[0] == 0:  # pragma no cover
        raise RuntimeError(
            "pivot table is empty,\nindex={},\nx_value={},\ncolumns={},"
            "\ndf.columns={}".format(index, x_value, cmp_col, df.columns))
    # NaN never becomes a pivot column; dropping it also keeps sorted()
    # from comparing a float NaN with string values.
    vals = sorted(set(df[cmp_col].dropna()))

    nb_empty = 0
    nb_total = 0
    for row, row_opt in enumerate(lrows_options):

        sub = filter_df_options(piv, row_opt)
        nb_total += 1
        if sub.shape[0] == 0:
            nb_empty += 1
            continue
        legy = options2label(row_opt, sep="\n")

        for col, col_opt in enumerate(lcols_options):
            sub2 = filter_df_options(sub, col_opt)
            if sub2.shape[0] == 0:
                continue
            legx = options2label(col_opt, sep="\n")

            pos = ax_position(shape, (row, col))
            a = ax[pos] if pos else ax
            drop_rename = []

            if parallel is not None:
                # Horizontal reference lines spanning the baseline range.
                mi = sub2[baseline].min()
                ma = sub2[baseline].max()
                for p in parallel:
                    style = '-' if p == 1 else "--"
                    la = "%1.1fx" % (1. / p)
                    drop_rename.append(la)
                    a.plot([mi, ma], [p, p], style, color='black',
                           label=label_fct(la))

            ic = 0
            # itertools.cycle reuses colors: a plain zip would silently drop
            # the hue options left once the colors are exhausted.
            for color, hue_opt in zip(itertools.cycle(colors), lhues_options):
                ds = filter_df_options(sub2, hue_opt).copy()
                if ds.shape[0] == 0:
                    continue
                legh = options2label(hue_opt, sep="\n")

                im = 0
                for ly in vals:
                    if ly == baseline:
                        continue

                    # Ratio between the compared value and the baseline,
                    # stored under the column the plot reads.
                    ds[y_value] = ds[ly] / ds[baseline]
                    if hue_opt is None:
                        color = colors[ic % len(colors)]
                        ic += 1
                    la = "{}-{}".format(ly, legh) if legh != '-' else ly
                    color_ = color_fct(la, color)
                    # Markers cycle when more than three values are compared.
                    marker = '.x+'[im % 3]
                    im += 1
                    nc = move_color(color_, 80 * (im - 1))
                    ds.plot(x=remove_common_prefix(baseline),
                            y=y_value, ax=a, marker=marker,
                            logx=True, logy=True, c=nc, lw=2,
                            label=label_fct(la), kind="scatter")

            a.set_xlabel(label_fct("{}\n{}".format(x_value, legx))
                         if row == shape[0] - 1 else "",
                         fontsize=fontsize)
            a.set_ylabel(label_fct("{}\n{}".format(legy, y_value))
                         if col == 0 else "", fontsize=fontsize)

            leg = a.legend(loc=0, fontsize=fontsize)

            # shortens the legend labels
            texts = leg.get_texts()
            leg_labels = remove_common_prefix([t.get_text() for t in texts],
                                              drop_rename)
            for t, la in zip(texts, leg_labels):
                t.set_text(la)

            # changes label size
            a.tick_params(labelsize=labelsize)
            for tick in a.yaxis.get_majorticklabels():
                tick.set_fontsize(labelsize)
            for tick in a.xaxis.get_majorticklabels():
                tick.set_fontsize(labelsize)
            plt.setp(a.get_xminorticklabels(), visible=False)
            plt.setp(a.get_yminorticklabels(), visible=False)

    if nb_empty == nb_total:  # pragma no cover
        raise RuntimeError(
            "All graphs are empty for dataframe,\nrow_cols={},\ncol_cols={},"
            "\nhue_cols={},\ncolumns={}".format(
                row_cols, col_cols, hue_cols, df.columns))
    if title is not None:
        fig.suptitle(title, fontsize=labelsize)
    return ax