Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Statistics on :epkg:`ONNX` models.
4"""
5from collections import Counter
6from onnx.helper import make_graph
7from onnx import ValueInfoProto
8from skl2onnx.common._topology import Variable
9from ._onnx_optimisation_common import _apply_optimisation_on_graph
10from .onnx_optimisation import onnx_remove_node
def onnx_statistics(onnx_model, recursive=True, optim=True):
    """
    Computes statistics on :epkg:`ONNX` models,
    extracts information about the model such as
    the number of nodes.

    @param      onnx_model      onnx model, graph or node
    @param      recursive       looks into subgraphs (attributes named ``body``)
    @param      optim           adds statistics because of optimisation
                                (computed after removing identity nodes)
    @return                     dictionary

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import pprint
        from sklearn.linear_model import LogisticRegression
        from sklearn.ensemble import RandomForestClassifier
        from sklearn.datasets import load_iris
        from mlprodict.onnxrt.optim.onnx_helper import onnx_statistics
        from mlprodict.onnx_conv import to_onnx

        iris = load_iris()
        X = iris.data
        y = iris.target
        lr = LogisticRegression()
        lr.fit(X, y)
        onx = to_onnx(lr, X[:1])
        pprint.pprint((lr, onnx_statistics(onx)))

        iris = load_iris()
        X = iris.data
        y = iris.target
        rf = RandomForestClassifier()
        rf.fit(X, y)
        onx = to_onnx(rf, X[:1], target_opset=12)
        pprint.pprint((rf, onnx_statistics(onx)))
    """
    atts = ['doc_string', 'ir_version', 'metadata_props', 'domain',
            'model_version', 'producer_name', 'producer_version']

    def update(sts, st):
        # Adds the counters of *st* into *sts*. The serialized size and the
        # model-level attributes are not additive, so they are skipped.
        for k, v in st.items():
            if k == 'size' or k in atts:
                continue  # pragma: no cover
            if k in sts:
                sts[k] += v
            else:
                sts[k] = v

    def add_subgraph(sts, subgraph):
        # Subgraphs are always inspected recursively but without the
        # optimisation pass, which is only run once on the outer object.
        substats = onnx_statistics(subgraph, recursive=True, optim=False)
        update(sts, {'subgraphs': 1})
        update(sts, substats)

    if hasattr(onnx_model, 'graph'):
        # A model: serialized size, node/initializer counts,
        # non-empty model attributes and opset versions per domain.
        content = onnx_model.SerializeToString()
        nnodes = len(onnx_model.graph.node)
        ninits = len(onnx_model.graph.initializer)
        stats = {'size': len(content), 'nnodes': nnodes, 'ninits': ninits}
        for a in atts:
            v = getattr(onnx_model, a)
            if isinstance(v, str):
                li = None
            else:
                try:
                    li = list(v)
                except TypeError:
                    li = None
            # Empty repeated fields (e.g. no metadata_props) are not reported.
            if li is not None and len(li) == 0:
                continue
            stats[a] = v

        for opi in onnx_model.opset_import:
            stats[opi.domain] = opi.version

        graph = onnx_model.graph
    elif not hasattr(onnx_model, 'node'):  # pragma: no cover
        # We're in a node: count it and, when recursive, the nodes of its
        # 'body' subgraph, exactly like subgraphs are handled for a graph.
        stats = {'nnodes': 1}
        if recursive and hasattr(onnx_model, 'attribute') and onnx_model.attribute:
            for att in onnx_model.attribute:
                if att.name == 'body':
                    add_subgraph(stats, att.g)
        return stats
    else:
        graph = onnx_model
        nnodes = len(graph.node)
        stats = {'nnodes': nnodes}

    # Number of identities and other operators worth tracking.
    counts = Counter(node.op_type for node in graph.node)
    for op in ['Cast', 'Identity', 'ZipMap', 'Reshape']:
        if op in counts:
            stats['op_' + op] = counts[op]

    # Recursive
    if recursive:
        for node in graph.node:
            if not hasattr(node, 'attribute'):
                continue  # pragma: no cover
            for att in node.attribute:
                if att.name != 'body':
                    continue
                add_subgraph(stats, att.g)

    # optimisation: remove_identity nodes
    if optim:
        new_model = onnx_remove_node(
            onnx_model, recursive=recursive)
        st = onnx_statistics(new_model, recursive=recursive, optim=False)
        for key in ["op_Identity", "subgraphs", "size",
                    "nnodes", "ninits"]:
            if key in st:
                stats[key + "_optim"] = st[key]
    return stats
def change_input_first_dimension(onnx_model, N=None, debug_info=None):
    """
    Some models are converted under the assumption
    batch prediction is not necessary. This function
    changes the first dimension of every input of an ONNX graph.

    @param      onnx_model      model :epkg:`onnx` (model or graph)
    @param      N               new first dimension,
                                None to avoid changing it
                                (the input is returned unchanged),
                                0 or a negative value to replace the
                                first dimension by None (unspecified)
    @param      debug_info      unused
    @return                     modified model onnx
    """
    if N is None:
        # Documented as "do not change". The former ``N <= 0`` comparison
        # raised TypeError for this default value.
        return onnx_model

    def _make_value_info(variable):
        value_info = ValueInfoProto()
        value_info.name = variable.full_name
        value_info.type.CopyFrom(  # pylint: disable=E1101
            variable.type.to_onnx_type())  # pylint: disable=E1101
        if variable.type.doc_string:  # pylint: disable=E0611
            value_info.doc_string = variable.type.doc_string  # pragma: no cover
        return value_info

    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            change_input_first_dimension, onnx_model, N=N)

    graph = onnx_model

    if N <= 0:
        N = None
    variables = [Variable.from_pb(inp) for inp in graph.input]
    for variable in variables:
        variable.type.shape[0] = N
    inputs = [_make_value_info(v) for v in variables]

    new_graph = make_graph(graph.node, graph.name,
                           inputs, graph.output, graph.initializer)

    new_graph.value_info.extend(graph.value_info)  # pylint: disable=E1101
    return new_graph