1""" 

2@file 

3@brief Functions to help visualizing performances. 

4""" 

5import numpy 

6import pandas 

7 

8 

def _model_name(name):
    """
    Extracts the main component of a model, removes
    suffixes such as ``Classifier``, ``Regressor``, ``CV``.

    @param      name        string
    @return                 shorter string
    """
    if name.startswith("Select"):
        return "Select"
    if name.startswith("Nu"):
        return "Nu"
    modif = 1
    while modif > 0:
        modif = 0
        for suf in ['Classifier', 'Regressor', 'CV', 'IC',
                    'Transformer']:
            if name.endswith(suf):
                name = name[:-len(suf)]
                modif += 1
    return name
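
# Illustration added here (not in the original module): with the prefix and
# suffix rules above, _model_name is expected to behave as follows:
#   _model_name("GradientBoostingClassifier") == "GradientBoosting"
#   _model_name("LassoLarsIC") == "LassoLars"
#   _model_name("NuSVC") == "Nu"            # prefix rule
#   _model_name("SelectKBest") == "Select"  # prefix rule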


def plot_validate_benchmark(df):
    """
    Plots a graph which summarizes the performances of a benchmark
    validating a runtime for :epkg:`ONNX`.

    @param      df      output of function @see fn summary_report
    @return             fig, ax

    .. plot::

        from logging import getLogger
        from pandas import DataFrame
        import matplotlib.pyplot as plt
        from mlprodict.onnxrt.validate import enumerate_validated_operator_opsets, summary_report
        from mlprodict.tools.plotting import plot_validate_benchmark

        logger = getLogger('skl2onnx')
        logger.disabled = True

        rows = list(enumerate_validated_operator_opsets(
            verbose=0, models={"LinearRegression"}, opset_min=11,
            runtime=['python', 'onnxruntime1'], debug=False,
            benchmark=True, n_features=[None, 10]))

        df = DataFrame(rows)
        piv = summary_report(df)
        fig, ax = plot_validate_benchmark(piv)
        plt.show()
    """
    import matplotlib.pyplot as plt

    if 'n_features' not in df.columns:
        df["n_features"] = numpy.nan  # pragma: no cover
    if 'runtime' not in df.columns:
        df['runtime'] = '?'  # pragma: no cover

    # one readable label per row: "<name> [<problem>-<scenario>|<optim>] D<n_features>"
    fmt = "{} [{}-{}|{}] D{}"
    df["label"] = df.apply(
        lambda row: fmt.format(
            row["name"], row["problem"], row["scenario"],
            row['optim'], row["n_features"]).replace("-default|", "-**]"), axis=1)
    df = df.sort_values(["name", "problem", "scenario", "optim",
                         "n_features", "runtime"],
                        ascending=False).reset_index(drop=True).copy()
    indices = ['label', 'runtime']
    # keep the aggregate N=... columns, skip the -min/-max variants
    values = [c for c in df.columns
              if 'N=' in c and '-min' not in c and '-max' not in c]
    try:
        df = df[indices + values]
    except KeyError as e:  # pragma: no cover
        raise RuntimeError(
            "Unable to find the following columns {}\nin {}".format(
                indices + values, df.columns)) from e

    if 'RT/SKL-N=1' not in df.columns:
        raise RuntimeError(  # pragma: no cover
            "Column 'RT/SKL-N=1' is missing, benchmark was probably not run.")
    na = df["RT/SKL-N=1"].isnull()
    dfp = df[~na]
    runtimes = list(sorted(set(dfp['runtime'])))
    final = None
    for rt in runtimes:
        # suffix every metric column with the runtime name, then merge on label
        sub = dfp[dfp.runtime == rt].drop('runtime', axis=1).copy()
        col = list(sub.columns)
        for i in range(1, len(col)):
            col[i] += "__" + rt
        sub.columns = col

        if final is None:
            final = sub
        else:
            final = final.merge(sub, on='label', how='outer')

    # let's add average and median
    ncol = (final.shape[1] - 1) // len(runtimes)
    if len(runtimes) + 1 > final.shape[0]:
        dfp_legend = final.iloc[:len(runtimes) + 1, :].copy()
        while dfp_legend.shape[0] < len(runtimes) + 1:
            dfp_legend = pandas.concat([dfp_legend, dfp_legend[:1]])
    else:
        dfp_legend = final.iloc[:len(runtimes) + 1, :].copy()
    rleg = dfp_legend.copy()
    dfp_legend.iloc[:, 1:] = numpy.nan
    rleg.iloc[:, 1:] = numpy.nan

    for r, runt in enumerate(runtimes):
        sli = slice(1 + ncol * r, 1 + ncol * r + ncol)
        cm = final.iloc[:, sli].mean().values
        dfp_legend.iloc[r + 1, sli] = cm
        rleg.iloc[r, sli] = final.iloc[:, sli].median()
        dfp_legend.iloc[r + 1, 0] = "avg_" + runt
        rleg.iloc[r, 0] = "med_" + runt
    dfp_legend.iloc[0, 0] = "------"
    rleg.iloc[-1, 0] = "------"

    # sort
    final = final.sort_values('label', ascending=False).copy()

    # add global statistics
    final = pandas.concat([rleg, final, dfp_legend]).reset_index(drop=True)

    # graph beginning
    total = final.shape[0] * 0.45
    fig, ax = plt.subplots(1, len(values), figsize=(14, total),
                           sharex=False, sharey=True)
    x = numpy.arange(final.shape[0])
    subh = 1.0 / len(runtimes)
    height = total / final.shape[0] * (subh + 0.1)
    decrt = {rt: height * i for i, rt in enumerate(runtimes)}
    colors = {rt: c for rt, c in zip(
        runtimes, ['blue', 'orange', 'cyan', 'yellow'])}

    # draw lines between models
    vals = final.iloc[:, 1:].values.ravel()
    xlim = [min(0.5, min(vals)), max(2, max(vals))]
    i = 0
    while i < final.shape[0] - 1:
        i += 1
        label = final.iloc[i, 0]
        if '[' not in label:
            continue
        prev = final.iloc[i - 1, 0]
        if '[' not in prev:
            continue  # pragma: no cover
        label = label.split()[0]
        prev = prev.split()[0]
        if _model_name(label) == _model_name(prev):
            continue

        # insert a blank separator row when the model name changes
        blank = final.iloc[:1, :].copy()
        blank.iloc[0, 0] = '------'
        blank.iloc[0, 1:] = xlim[0]
        final = pandas.concat([final[:i], blank, final[i:]])
        i += 1

    final = final.reset_index(drop=True).copy()
    x = numpy.arange(final.shape[0])

    done = set()
    for c in final.columns[1:]:
        place, runtime = c.split('__')
        if hasattr(ax, 'shape'):
            index = values.index(place)
            if (index, runtime) in done:
                raise RuntimeError(  # pragma: no cover
                    "Issue with column '{}'\nlabels={}\nruntimes={}\ncolumns="
                    "{}\nvalues={}\n{}".format(
                        c, list(final.label), runtimes, final.columns, values, final))
            axi = ax[index]
            done.add((index, runtime))
        else:
            if (0, runtime) in done:  # pragma: no cover
                raise RuntimeError(
                    "Issue with column '{}'\nlabels={}\nruntimes={}\ncolumns="
                    "{}\nvalues={}\n{}".format(
                        c, final.label, runtimes, final.columns, values, final))
            done.add((0, runtime))  # pragma: no cover
            axi = ax  # pragma: no cover
        if c in final.columns:
            yl = final.loc[:, c]
            xl = x + decrt[runtime] / 2
            axi.barh(xl, yl, label=runtime, height=height,
                     color=colors[runtime])
        axi.set_title(place)

    def _plot_axis(axi, x, xlim):
        axi.plot([1, 1], [0, max(x)], 'g-')
        axi.plot([2, 2], [0, max(x)], 'r--')
        axi.set_xlim(xlim)
        axi.set_xscale('log')
        axi.set_ylim([min(x) - 2, max(x) + 1])

    def _plot_final(axi, x, final):
        axi.set_yticks(x)
        axi.set_yticklabels(final['label'])

    if hasattr(ax, 'shape'):
        for i in range(len(ax)):  # pylint: disable=C0200
            _plot_axis(ax[i], x, xlim)

        ax[min(ax.shape[0] - 1, 2)].legend()
        _plot_final(ax[0], x, final)
    else:  # pragma: no cover
        _plot_axis(ax, x, xlim)
        _plot_final(ax, x, final)
        ax.legend()

    fig.subplots_adjust(left=0.25)
    return fig, ax
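

# Hedged usage sketch, added for illustration only (not part of the original
# module). The column names mirror what plot_validate_benchmark reads from the
# output of @see fn summary_report; the ratios below are made-up values, where
# 'RT/SKL-N=k' is assumed to compare the runtime to scikit-learn on batches of
# k observations.
if __name__ == '__main__':  # pragma: no cover
    import matplotlib.pyplot as plt

    example = pandas.DataFrame([
        # one row per (model, runtime) pair
        dict(name="LinearRegression", problem="reg", scenario="default",
             optim='', n_features=10, runtime="python",
             **{"RT/SKL-N=1": 0.9, "RT/SKL-N=10": 1.1}),
        dict(name="LinearRegression", problem="reg", scenario="default",
             optim='', n_features=10, runtime="onnxruntime1",
             **{"RT/SKL-N=1": 0.7, "RT/SKL-N=10": 0.8}),
    ])
    fig, ax = plot_validate_benchmark(example)
    plt.show()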