Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Optimisation of :epkg:`ONNX` graphs. 

4""" 

5from onnx.helper import make_graph 

6from ._onnx_optimisation_common import ( # pylint: disable=E0611 

7 _rename_node_input, 

8 _rename_node_output, 

9 _apply_optimisation_on_graph, 

10 _apply_remove_node_fct_node 

11) 

12 

13 

def onnx_remove_node_identity(onnx_model, recursive=True, debug_info=None, **options):
    """
    Removes as many *Identity* nodes as possible.
    The function looks into every node, and into every subgraph if
    *recursive* is True, for identity nodes. An identity node is kept
    when its output is a graph output and its input cannot be renamed
    into that output: a graph input, another graph output, an initializer
    or, inside a subgraph, a result coming from the outer scope (any name
    no node of this graph produces). Every other identity node is removed
    and the remaining nodes get their inputs or outputs renamed accordingly.

    @param      onnx_model      onnx model or graph
    @param      recursive       looks into subgraphs
    @param      debug_info      debug information (private)
    @param      options         additional options, only forwarded when
                                *onnx_model* has a ``graph`` attribute
    @return                     new onnx model, or a new graph if
                                *onnx_model* is a graph
    """
    type_name = str(type(onnx_model)).split('.')[-1].strip("'>")
    if debug_info is None:
        debug_info = [type_name]
    else:
        # Build a new list: the caller's list must not be modified.
        debug_info = debug_info + [type_name]

    if hasattr(onnx_model, 'graph'):
        # A model: the shared helper processes its graph by calling
        # this function back.
        return _apply_optimisation_on_graph(
            onnx_remove_node_identity, onnx_model,
            recursive=recursive, debug_info=debug_info, **options)

    graph = onnx_model

    inputs = set(i.name for i in graph.input)
    outputs = set(o.name for o in graph.output)

    nodes = list(graph.node)
    rem = 1
    while rem > 0:
        rem = 0
        # Snapshot (index, input, output) of every remaining Identity node.
        idnodes = [(i, node.input[0], node.output[0])
                   for i, node in enumerate(nodes)
                   if node is not None and node.op_type == 'Identity']
        restart = False
        for i, inp, out in idnodes:
            if restart:
                # An Identity node was renamed, its cached names are stale:
                # rebuild the snapshot.
                break
            if nodes[i] is None:
                # Already removed.
                continue  # pragma: no cover

            if out not in outputs:
                # *out* is an intermediate result: its consumers read *inp*
                # instead and *out* disappears with the Identity node.
                # NOTE(review): only explicit node inputs are inspected; a node
                # reading *out* solely from inside one of its subgraphs is not
                # matched here - confirm whether that case needs handling.
                for j, node in enumerate(nodes):
                    if j == i or node is None or out not in node.input:
                        continue
                    nodes[j] = _rename_node_input(node, out, inp)
                    rem += 1
                    if nodes[j].op_type == 'Identity':
                        restart = True
                nodes[i] = None
                rem += 1
                continue

            # *out* is a graph output and cannot be renamed, so *inp* must be
            # renamed into *out*. That requires *inp* to be neither a graph
            # input nor a graph output, and to be produced by a node of this
            # graph: initializers (not necessarily listed in graph.input) and
            # outer-scope names (subgraphs) have no producer here, and
            # removing the node would leave *out* undefined.
            if inp in inputs or inp in outputs:
                continue
            if not any(node is not None and inp in node.output
                       for node in nodes):
                continue
            for j, node in enumerate(nodes):
                # The Identity node itself is skipped: it is dropped below.
                if j == i or node is None:
                    continue
                if inp in node.output:
                    node = _rename_node_output(node, inp, out)
                    nodes[j] = node
                    rem += 1
                    if node.op_type == 'Identity':
                        restart = True  # pragma: no cover
                if inp in node.input:
                    node = _rename_node_input(node, inp, out)
                    nodes[j] = node
                    rem += 1
                    if node.op_type == 'Identity':
                        restart = True
            nodes[i] = None
            rem += 1

    if recursive:
        # Handles subgraphs held by node attributes.
        for i, node in enumerate(nodes):
            if node is None or not node.attribute:
                continue
            nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_identity,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    nodes = [node for node in nodes if node is not None]
    graph = make_graph(nodes, onnx_model.name,
                       onnx_model.input, onnx_model.output,
                       onnx_model.initializer)

    graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    return graph