Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Shorten code in notebook :ref:`onnxsklearnconsortiumrst`.
4"""
5import os
6import sys
7from collections import OrderedDict
8import warnings
9from pyquickhelper.pycode.profiling import profile
10from pyquickhelper.helpgen.rst_converters import docstring2html
11from pyensae.graphhelper import draw_diagram
12from jyquickhelper import RenderJsDot
13import sklearn
14from skl2onnx.proto import TensorProto
15from onnx import helper
def graph_persistence_pickle():
    """
    Draws the persistence of a model with pickle: a model trained
    with scikit-learn on *machine 1* is pickled, then restored on
    *machine 2* where it is used to predict or to train again.
    See :ref:`onnxsklearnconsortiumrst`.

    @return     whatever *draw_diagram* returns for the blockdiag description
    """
    diagram = """
        blockdiag {
            default_fontsize = 20; node_width = 200; node_height = 100;
            model[label="trained model\\nscikit-learn"];
            pkl[label="pickled model"];
            rest[label="restored model\\nscikit-learn", textcolor="#00AAAA"];
            pkl -> rest;
            pred[label="predictions"];
            train[label="training"];
            group {
                label = "machine 1";
                color = "#FFAAAA";
                model -> pkl; pkl;
            }
            group {
                label = "machine 2";
                color = "#AAFFAA";
                rest -> pred; rest -> train;
            }
        }"""
    return draw_diagram(diagram)
def graph_persistence_pickle_issues():
    """
    Draws the same pickle scenario annotated with its issues:
    the restored model is flagged *UNSTABLE* and the predictions *SLOW*.
    See :ref:`onnxsklearnconsortiumrst`.

    @return     whatever *draw_diagram* returns for the blockdiag description
    """
    diagram = """
        blockdiag {
            default_fontsize = 20; node_width = 200; node_height = 100;
            pkl[label="pickled model"];
            rest[label="restored model\\nscikit-learn\\nUNSTABLE", textcolor="#00AAAA"];
            pkl -> rest;
            pred[label="predictions\\nSLOW"];
            train[label="training"];
            group {
                label = "machine 1";
                color = "#FFAAAA"; pkl;
            }
            group {
                label = "machine 2";
                color = "#AAFFAA";
                rest -> pred; rest -> train;
            }
        }"""
    return draw_diagram(diagram)
def graph_persistence_onnx():
    """
    Draws the persistence of a model with ONNX: the scikit-learn model
    is converted into ONNX on *machine 1* and executed by an ONNX runtime
    on *machine 2*, which can predict but cannot train.
    See :ref:`onnxsklearnconsortiumrst`.

    @return     whatever *draw_diagram* returns for the blockdiag description
    """
    diagram = """
        blockdiag {
            default_fontsize = 20; node_width = 200; node_height = 100;
            model[label="trained model\\nscikit-learn"];
            onnx[label="ONNX model"];
            rest[label="ONNX runtime", textcolor="#00AAAA"];
            onnx -> rest;
            pred[label="predictions"];
            notrain[label="cannot train", color="#FF0000"];
            group {
                label = "machine 1";
                color = "#FFAAAA";
                model -> onnx[label="conversion"];
                onnx;
            }
            group {
                label = "machine 2";
                color = "#AAFFAA";
                rest ;
                pred;
                rest -> pred;
                rest -> notrain[folded];
            }
        }"""
    return draw_diagram(diagram)
def graph_three_components():
    """
    Draws the three components of the ONNX ecosystem as a chain:
    ONNX (set of mathematical functions) -> converter (sklearn-onnx)
    -> runtime (onnxruntime, onnx.js, ...).
    See :ref:`onnxsklearnconsortiumrst`.

    @return     whatever *draw_diagram* returns for the blockdiag description
    """
    diagram = """
        blockdiag {
            default_fontsize = 20; node_width = 200; node_height = 100;
            onnx[label="ONNX\\n\\nset of mathematical functions", color="#FFFF00"];
            conv[label="converter\\n\\nsklearn-onnx", color="#FFFF00"];
            run[label="runtime\\n\\nonnxruntime\\nonnx.js\\n...", color="#FFFF00"];
            onnx -> conv -> run ;
        }"""
    return draw_diagram(diagram)
def profile_fct_graph(fct, title, highlights=None, nb=20, figsize=(10, 3)):
    """
    Returns a graph which profiles the execution of function *fct*.
    See :ref:`onnxsklearnconsortiumrst`.

    @param      fct         function to profile (run by pyquickhelper's *profile*)
    @param      title       title of the graph
    @param      highlights  labels to surround with a red dashed box; a label
                            missing from the graph is replaced by every
                            function name which contains it
    @param      nb          number of rows of the profiling results to plot
    @param      figsize     figure size
    @return                 axes returned by ``DataFrame.plot``
    @raises     ValueError  if a highlight matches no function name
    """
    # presumably stripped from function locations to shorten the labels
    root_paths = [os.path.dirname(sklearn.__file__),
                  "site-packages",
                  os.path.join(sys.prefix, "lib")]
    _, stats = profile(fct, as_df=True, rootrem=root_paths)  # pylint: disable=W0632
    name_col = 'namefct' if 'namefct' in stats.columns else 'fct'
    top = stats[[name_col, 'cum_tall']].head(n=nb).set_index(name_col)
    known = list(top.index)

    ax = top.plot(kind='bar', figsize=figsize, rot=30)
    ax.set_title(title)
    for tick in ax.get_xticklabels():
        tick.set_horizontalalignment('right')

    if not highlights:
        return ax

    # geometry of the dashed box drawn around a bar, in data coordinates
    height = 0.15
    left, right = -0.35, 0.3
    for wanted in highlights:
        if wanted in known:
            matches = [wanted]
        else:
            matches = [name for name in known
                       if isinstance(name, str) and wanted in name]
            if not matches:
                raise ValueError("Unable to find '{}' in '{}'?".format(
                    wanted, ", ".join(sorted(map(str, known)))))
        for name in matches:
            pos = top.index.get_loc(name)
            ax.plot([pos + left, pos + left], [0, height], 'r--')
            ax.plot([pos + right, pos + right], [0, height], 'r--')
            ax.plot([pos + left, pos + right], [height, height], 'r--')
    return ax
def onnx2str(model_onnx, nrows=15):
    """
    Displays the beginning of an ONNX graph.
    See :ref:`onnxsklearnconsortiumrst`.

    @param      model_onnx      any object, its ``str`` representation is used
    @param      nrows           maximum number of lines to keep
    @return                     the first *nrows* lines of ``str(model_onnx)``
                                followed by a line ``...`` if lines were dropped
    """
    text_lines = str(model_onnx).split('\n')
    if len(text_lines) <= nrows:
        return "\n".join(text_lines)
    return "\n".join(text_lines[:nrows] + ['...'])
def onnx2dotnb(model_onnx, width="100%", orientation="LR"):
    """
    Converts an ONNX graph into dot then into :epkg:`RenderJsDot`.
    See :ref:`onnxsklearnconsortiumrst`.

    @param      model_onnx      ONNX model
    @param      width           width given to :epkg:`RenderJsDot`
    @param      orientation     graph direction, passed as *rankdir*
    @return                     :epkg:`RenderJsDot`
    """
    from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer
    # operator nodes are drawn as filled yellow boxes
    producer = GetOpNodeProducer("docstring", color="yellow",
                                 fillcolor="yellow", style="filled")
    graph = model_onnx.graph
    pydot_graph = GetPydotGraph(graph, name=graph.name, rankdir=orientation,
                                node_producer=producer)
    return RenderJsDot(pydot_graph.to_string(), width=width)
def onnx2graph(onnx_model):
    """
    Converts an :epkg:`ONNX` model into a readable graph.

    Every node becomes an entry ``name[op_type]`` pointed to by each of
    its inputs and pointing to each of its outputs; these edges are then
    converted by function *edges2asciitree*. When several nodes produce
    the same entry (same name and operator type, typically unnamed nodes),
    the following ones receive a suffix ``#1``, ``#2``, ...

    @param      onnx_model      ONNX model
    @return                     graph defined with :epkg:`OrderedDict`
                                so that it can be processed by :epkg:`asciitree`
    """
    edges = {}
    node_keys = set()
    for node in onnx_model.graph.node:
        key = "%s[%s]" % (node.name, node.op_type)
        if key in node_keys:
            # Node names are optional: two unnamed nodes with the same
            # operator would otherwise share one key and the second one
            # would erase the outputs registered for the first one.
            suffix = 1
            while "%s#%d" % (key, suffix) in node_keys:
                suffix += 1
            key = "%s#%d" % (key, suffix)
        node_keys.add(key)

        for inp in node.input:
            consumers = edges.setdefault(inp, [])
            if key not in consumers:
                consumers.append(key)

        outputs = []
        for out in node.output:
            if out not in outputs:
                outputs.append(out)
        edges[key] = outputs

    return edges2asciitree(edges)
def edges2asciitree(edges):
    """
    Converts a set of edges into a combination
    of :epkg:`OrderedDict` which can be understood
    by :epkg:`asciitree`. This does not fully work if one node
    has multiple inputs: it appears under every parent but its
    own children may be missing under some of them.
    Edges closing a cycle are skipped so that the function
    always terminates. If there are several roots, they are gathered
    under an artificial node ``root`` (prefixed with ``_`` as long as
    that name is already used by the graph).

    @param      edges       set of edges, ``{node: [successors]}``
    @return                 :epkg:`OrderedDict`

    .. runpython::
        :showcode:

        data = {'X': ['LinearClassifier[LinearClassifier]'],
                'LinearClassifier[LinearClassifier]':
                    ['label', 'probability_tensor'],
                'probability_tensor': ['Normalizer[Normalizer]'],
                'Normalizer[Normalizer]': ['probabilities'],
                'label': ['Cast[Cast]'],
                'Cast[Cast]': ['output_label'],
                'probabilities': ['ZipMap[ZipMap]'],
                'ZipMap[ZipMap]': ['output_probability']}

        from jupytalk.talk_examples.sklearn2019 import edges2asciitree
        res = edges2asciitree(data)

        import pprint
        pprint.pprint(res)

        from asciitree import LeftAligned
        tr = LeftAligned()
        print(tr(res))
    """
    # a root is a node which is never the destination of an edge
    targets = set()
    for dests in edges.values():
        targets.update(dests)
    roots = [key for key in edges if key not in targets]

    if len(roots) > 1:
        # a unique top node is needed; its name must not clash with an
        # existing node, otherwise the graph would loop on itself
        top = 'root'
        while top in edges or top in targets:
            top = '_' + top
        edges = edges.copy()  # leaves the caller's dictionary untouched
        edges[top] = roots
        roots = [top]

    res = OrderedDict()
    # find[k]: the most recently created subtree of node k
    # ancestors[k]: nodes on the path from the root down to that subtree
    find = {}
    ancestors = {}
    for r in roots:
        res[r] = OrderedDict()
        find[r] = res[r]
        ancestors[r] = {r}

    modif = 1
    while modif > 0:
        modif = 0
        for k, eds in edges.items():
            if k not in find:
                continue
            subtree = find[k]
            path = ancestors[k]
            for edge in eds:
                if edge in path:
                    # edge k -> edge closes a cycle, expanding it would
                    # create new subtrees forever
                    continue
                if edge not in subtree:
                    subtree[edge] = OrderedDict()
                    find[edge] = subtree[edge]
                    ancestors[edge] = path | {edge}
                    modif += 1

    return res
def onnxdocstring2html(doc, start="number of targets."):
    """
    Converts the ONNX documentation into HTML with *docstring2html*.

    @param      doc     documentation to convert
    @param      start   if not empty, only the text following the last
                        occurrence of *start* is converted
    @return             result of *docstring2html*
    """
    if start:
        # str.split('') raises ValueError, an empty marker means "keep all"
        doc = doc.split(start)[-1]
    with warnings.catch_warnings():
        # the conversion may emit warnings, they are not shown
        warnings.filterwarnings("ignore")
        # ```` is an empty inline literal, not valid rst
        return docstring2html(doc.replace("Default value is ````", ""))
def rename_input_output(model_onnx):
    """
    Renames all input and output of an ONNX file.

    Every name used by the main graph (inputs, outputs, intermediate
    *value_info*, initializers, node inputs and outputs) loses its
    underscores. The model is copied first and only names are rewritten,
    so everything else is preserved: opsets, node names and domains
    (required by operators outside the default domain such as ai.onnx.ml),
    initializer data whatever field stores it, symbolic dimensions,
    non-tensor types (sequences, maps), documentation and metadata.

    @param      model_onnx      ONNX model, left unchanged
    @return                     a new ONNX model

    .. note::
        Two names may become identical once underscores are removed
        (``a_b`` and ``ab``). Names used inside subgraphs (attributes of
        control flow operators) are not renamed.
    """
    def clean_name(name):
        return name.replace("_", "")

    def rename_all(names):
        # repeated string field: rebuilt in place
        renamed = [clean_name(name) for name in names]
        del names[:]
        names.extend(renamed)

    onnx_model = type(model_onnx)()
    onnx_model.CopyFrom(model_onnx)
    graph = onnx_model.graph

    for container in (graph.input, graph.output, graph.value_info):
        for value_info in container:
            value_info.name = clean_name(value_info.name)
    for init in graph.initializer:
        init.name = clean_name(init.name)
    for node in graph.node:
        rename_all(node.input)
        rename_all(node.output)
    return onnx_model