Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Statistics on :epkg:`ONNX` models. 

4""" 

5from collections import Counter 

6from onnx.helper import make_graph 

7from onnx import ValueInfoProto 

8from skl2onnx.common._topology import Variable 

9from ._onnx_optimisation_common import _apply_optimisation_on_graph 

10from .onnx_optimisation import onnx_remove_node 

11 

12 

def onnx_statistics(onnx_model, recursive=True, optim=True):
    """
    Computes statistics on :epkg:`ONNX` models,
    extracts information about the model such as
    the number of nodes.

    @param      onnx_model      onnx model, a graph (any object exposing
                                ``node``) or a single node
    @param      recursive       looks into subgraphs (attributes
                                named ``body``)
    @param      optim           adds statistics because of optimisation,
                                stored under keys suffixed with ``_optim``
    @return                     dictionary

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import pprint
        from sklearn.linear_model import LogisticRegression
        from sklearn.ensemble import RandomForestClassifier
        from sklearn.datasets import load_iris
        from mlprodict.onnxrt.optim.onnx_helper import onnx_statistics
        from mlprodict.onnx_conv import to_onnx

        iris = load_iris()
        X = iris.data
        y = iris.target
        lr = LogisticRegression()
        lr.fit(X, y)
        onx = to_onnx(lr, X[:1])
        pprint.pprint((lr, onnx_statistics(onx)))

        iris = load_iris()
        X = iris.data
        y = iris.target
        rf = RandomForestClassifier()
        rf.fit(X, y)
        onx = to_onnx(rf, X[:1], target_opset=12)
        pprint.pprint((rf, onnx_statistics(onx)))
    """
    # Model-level attributes copied into the statistics of a model; they
    # are never accumulated from a subgraph into its parent.
    atts = ['doc_string', 'ir_version', 'metadata_props', 'domain',
            'model_version', 'producer_name', 'producer_version']

    def update(sts, st):
        # Accumulates the counters of *st* into *sts*, skipping
        # 'size' and the model-level attributes.
        for k, v in st.items():
            if k == 'size' or k in atts:
                continue  # pragma: no cover
            if k in sts:
                sts[k] += v
            else:
                sts[k] = v

    if hasattr(onnx_model, 'graph'):
        # We're in a model.
        content = onnx_model.SerializeToString()
        nnodes = len(onnx_model.graph.node)
        ninits = len(onnx_model.graph.initializer)
        stats = {'size': len(content), 'nnodes': nnodes, 'ninits': ninits}
        for a in atts:
            v = getattr(onnx_model, a)
            if isinstance(v, str):
                li = None
            else:
                try:
                    li = list(v)
                except TypeError:
                    li = None
            # Empty repeated fields are not reported,
            # strings and scalars always are.
            if li is not None and len(li) == 0:
                continue
            stats[a] = v

        # One entry per imported opset: domain -> version.
        for opi in onnx_model.opset_import:
            stats[opi.domain] = opi.version

        graph = onnx_model.graph
    elif not hasattr(onnx_model, 'node'):  # pragma: no cover
        # We're in a node.
        stats = {'nnodes': 1}
        if (recursive and hasattr(onnx_model, 'attribute') and
                onnx_model.attribute):
            for att in onnx_model.attribute:
                if att.name == 'body':
                    # The caller's flags are propagated: recursive=False
                    # and optim=False used to be ignored here.
                    st = onnx_statistics(
                        att.g, recursive=recursive, optim=optim)
                    update(stats, st)
        return stats
    else:
        # We're in a graph.
        graph = onnx_model
        nnodes = len(graph.node)
        stats = {'nnodes': nnodes}

    # Number of occurrences of a few specific operators.
    counts = Counter(node.op_type for node in graph.node)
    for op in ['Cast', 'Identity', 'ZipMap', 'Reshape']:
        if op in counts:
            stats['op_' + op] = counts[op]

    # Recursive: only attributes named 'body' are followed.
    if recursive:
        for node in graph.node:
            if not hasattr(node, 'attribute'):
                continue  # pragma: no cover
            for att in node.attribute:
                if att.name != 'body':
                    continue
                substats = onnx_statistics(
                    att.g, recursive=True, optim=False)
                update(stats, {'subgraphs': 1})
                update(stats, substats)

    # optimisation: remove_identity nodes
    if optim:
        new_model = onnx_remove_node(
            onnx_model, recursive=recursive)
        st = onnx_statistics(new_model, recursive=recursive, optim=False)
        for key in ["op_Identity", "subgraphs", "size",
                    "nnodes", "ninits"]:
            if key in st:
                stats[key + "_optim"] = st[key]
    return stats

127 

128 

def change_input_first_dimension(onnx_model, N=None, debug_info=None):
    """
    Some models are converted under the assumption
    batch prediction is not necessary. This function
    changes the first dimension of the inputs of an ONNX graph.

    @param      onnx_model      model :epkg:`onnx`, a model (any object
                                exposing ``graph``) or a graph
    @param      N               new first dimension,
                                None to avoid changing it,
                                0 (or any negative value) to make
                                the first dimension undefined
    @param      debug_info      unused
    @return                     modified model onnx
    """
    def _make_value_info(variable):
        # Converts a skl2onnx Variable back into a ValueInfoProto.
        value_info = ValueInfoProto()
        value_info.name = variable.full_name
        value_info.type.CopyFrom(  # pylint: disable=E1101
            variable.type.to_onnx_type())  # pylint: disable=E1101
        if variable.type.doc_string:  # pylint: disable=E0611
            value_info.doc_string = variable.type.doc_string  # pragma: no cover
        return value_info

    if hasattr(onnx_model, 'graph'):
        # A model was given: the helper applies this function
        # to the graph it holds.
        return _apply_optimisation_on_graph(
            change_input_first_dimension, onnx_model, N=N)

    graph = onnx_model

    if N is None:
        # Documented behaviour: None leaves the first dimension alone.
        # Comparing None with 0 used to raise TypeError here.
        inputs = graph.input
    else:
        # Zero or negative values mean an undefined first dimension.
        first_dim = N if N > 0 else None
        variables = [Variable.from_pb(inp) for inp in graph.input]
        for variable in variables:
            variable.type.shape[0] = first_dim
        inputs = [_make_value_info(v) for v in variables]

    new_graph = make_graph(graph.node, graph.name,
                           inputs, graph.output, graph.initializer)

    new_graph.value_info.extend(graph.value_info)  # pylint: disable=E1101
    return new_graph