Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
"""
@file
@brief Functions to help visualize performances.
"""
5import numpy
6import pandas
9def _model_name(name):
10 """
11 Extracts the main component of a model, removes
12 suffixes such ``Classifier``, ``Regressor``, ``CV``.
14 @param name string
15 @return shorter string
16 """
17 if name.startswith("Select"):
18 return "Select"
19 if name.startswith("Nu"):
20 return "Nu"
21 modif = 1
22 while modif > 0:
23 modif = 0
24 for suf in ['Classifier', 'Regressor', 'CV', 'IC',
25 'Transformer']:
26 if name.endswith(suf):
27 name = name[:-len(suf)]
28 modif += 1
29 return name
def plot_validate_benchmark(df):
    """
    Plots a graph which summarizes the performances of a benchmark
    validating a runtime for :epkg:`ONNX`.

    The function draws one horizontal bar chart per column whose name
    contains ``N=`` (``-min`` and ``-max`` columns are ignored), one bar
    per runtime and per label. Rows holding the median (``med_<runtime>``)
    and the average (``avg_<runtime>``) of every column are added around
    the models, and a separator row is inserted every time the
    main component of the model name changes.

    The input dataframe is modified: column ``label`` is added, as well
    as columns ``n_features`` and ``runtime`` if they are missing.
    A :epkg:`*py:RuntimeError` is raised if column ``RT/SKL-N=1`` is
    missing or only contains missing values.

    @param      df      output of function @see fn summary_report
    @return             fig, ax

    .. plot::

        from logging import getLogger
        from pandas import DataFrame
        import matplotlib.pyplot as plt
        from mlprodict.onnxrt.validate import enumerate_validated_operator_opsets, summary_report
        from mlprodict.tools.plotting import plot_validate_benchmark

        logger = getLogger('skl2onnx')
        logger.disabled = True

        rows = list(enumerate_validated_operator_opsets(
            verbose=0, models={"LinearRegression"}, opset_min=11,
            runtime=['python', 'onnxruntime1'], debug=False,
            benchmark=True, n_features=[None, 10]))

        df = DataFrame(rows)
        piv = summary_report(df)
        fig, ax = plot_validate_benchmark(piv)
        plt.show()
    """
    import matplotlib.pyplot as plt

    if 'n_features' not in df.columns:
        df["n_features"] = numpy.nan  # pragma: no cover
    if 'runtime' not in df.columns:
        df['runtime'] = '?'  # pragma: no cover

    # one label per (model, problem, scenario, optimisation, dimension)
    fmt = "{} [{}-{}|{}] D{}"
    df["label"] = df.apply(
        lambda row: fmt.format(
            row["name"], row["problem"], row["scenario"],
            row['optim'], row["n_features"]).replace("-default|", "-**]"), axis=1)
    df = df.sort_values(["name", "problem", "scenario", "optim",
                         "n_features", "runtime"],
                        ascending=False).reset_index(drop=True).copy()
    indices = ['label', 'runtime']
    values = [c for c in df.columns
              if 'N=' in c and '-min' not in c and '-max' not in c]
    try:
        df = df[indices + values]
    except KeyError as e:  # pragma: no cover
        raise RuntimeError(
            "Unable to find the following columns {}\nin {}".format(
                indices + values, df.columns)) from e

    if 'RT/SKL-N=1' not in df.columns:
        raise RuntimeError(  # pragma: no cover
            "Column 'RT/SKL-N=1' is missing, benchmark was probably not run.")
    na = df["RT/SKL-N=1"].isnull()
    dfp = df[~na]
    runtimes = list(sorted(set(dfp['runtime'])))
    if not runtimes:
        # Without this check, the code below fails on ``final`` being None.
        raise RuntimeError(
            "Column 'RT/SKL-N=1' only contains missing values, "
            "benchmark was probably not run.")

    # one block of columns per runtime, suffixed with '__<runtime>',
    # all blocks are merged on the label
    final = None
    for rt in runtimes:
        sub = dfp[dfp.runtime == rt].drop('runtime', axis=1).copy()
        col = list(sub.columns)
        sub.columns = [col[0]] + [name + "__" + rt for name in col[1:]]

        if final is None:
            final = sub
        else:
            final = final.merge(sub, on='label', how='outer')

    # let's add average and median
    ncol = (final.shape[1] - 1) // len(runtimes)
    dfp_legend = final.iloc[:len(runtimes) + 1, :].copy()
    # the legend needs len(runtimes) + 1 rows, the first row is
    # duplicated if the benchmark does not contain enough rows
    while dfp_legend.shape[0] < len(runtimes) + 1:
        dfp_legend = pandas.concat([dfp_legend, dfp_legend[:1]])
    rleg = dfp_legend.copy()
    dfp_legend.iloc[:, 1:] = numpy.nan
    rleg.iloc[:, 1:] = numpy.nan

    for r, runt in enumerate(runtimes):
        sli = slice(1 + ncol * r, 1 + ncol * r + ncol)
        cm = final.iloc[:, sli].mean().values
        dfp_legend.iloc[r + 1, sli] = cm
        rleg.iloc[r, sli] = final.iloc[:, sli].median()
        dfp_legend.iloc[r + 1, 0] = "avg_" + runt
        rleg.iloc[r, 0] = "med_" + runt
    dfp_legend.iloc[0, 0] = "------"
    rleg.iloc[-1, 0] = "------"

    # sort
    final = final.sort_values('label', ascending=False).copy()

    # add global statistics
    final = pandas.concat([rleg, final, dfp_legend]).reset_index(drop=True)

    # graph beginning
    total = final.shape[0] * 0.45
    fig, ax = plt.subplots(1, len(values), figsize=(14, total),
                           sharex=False, sharey=True)
    subh = 1.0 / len(runtimes)
    height = total / final.shape[0] * (subh + 0.1)
    decrt = {rt: height * i for i, rt in enumerate(runtimes)}
    colors = dict(zip(runtimes, ['blue', 'orange', 'cyan', 'yellow']))

    # draw lines between models
    vals = final.iloc[:, 1:].values.ravel()
    xlim = [min(0.5, min(vals)), max(2, max(vals))]
    # The scan must start from the first row: the index is explicitly
    # initialised instead of reusing a value leaked by a previous loop.
    i = 0
    while i < final.shape[0] - 1:
        i += 1
        label = final.iloc[i, 0]
        if '[' not in label:
            continue
        prev = final.iloc[i - 1, 0]
        if '[' not in prev:
            # previous row is a legend or a separator,
            # no need for another separator
            continue
        label = label.split()[0]
        prev = prev.split()[0]
        if _model_name(label) == _model_name(prev):
            continue

        blank = final.iloc[:1, :].copy()
        blank.iloc[0, 0] = '------'
        blank.iloc[0, 1:] = xlim[0]
        final = pandas.concat([final[:i], blank, final[i:]])
        # skips the row which was just inserted
        i += 1

    final = final.reset_index(drop=True).copy()
    x = numpy.arange(final.shape[0])

    done = set()
    for c in final.columns[1:]:
        place, runtime = c.split('__')
        if hasattr(ax, 'shape'):
            # several graphs, one per column in values
            index = values.index(place)
            if (index, runtime) in done:
                raise RuntimeError(  # pragma: no cover
                    "Issue with column '{}'\nlabels={}\nruntimes={}\ncolumns="
                    "{}\nvalues={}\n{}".format(
                        c, list(final.label), runtimes, final.columns, values, final))
            axi = ax[index]
            done.add((index, runtime))
        else:
            # only one graph
            if (0, runtime) in done:  # pragma: no cover
                raise RuntimeError(
                    "Issue with column '{}'\nlabels={}\nruntimes={}\ncolumns="
                    "{}\nvalues={}\n{}".format(
                        c, final.label, runtimes, final.columns, values, final))
            done.add((0, runtime))  # pragma: no cover
            axi = ax  # pragma: no cover
        if c in final.columns:
            yl = final.loc[:, c]
            # bars of different runtimes are shifted to avoid overlapping
            xl = x + decrt[runtime] / 2
            axi.barh(xl, yl, label=runtime, height=height,
                     color=colors[runtime])
            axi.set_title(place)

    def _plot_axis(axi, x, xlim):
        # vertical lines at ratios 1 and 2, logarithmic scale
        axi.plot([1, 1], [0, max(x)], 'g-')
        axi.plot([2, 2], [0, max(x)], 'r--')
        axi.set_xlim(xlim)
        axi.set_xscale('log')
        axi.set_ylim([min(x) - 2, max(x) + 1])

    def _plot_final(axi, x, final):
        # labels on the vertical axis
        axi.set_yticks(x)
        axi.set_yticklabels(final['label'])

    if hasattr(ax, 'shape'):
        for sub_ax in ax:
            _plot_axis(sub_ax, x, xlim)

        ax[min(ax.shape[0] - 1, 2)].legend()
        _plot_final(ax[0], x, final)
    else:  # pragma: no cover
        _plot_axis(ax, x, xlim)
        _plot_final(ax, x, final)
        ax.legend()

    fig.subplots_adjust(left=0.25)
    return fig, ax