Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

"""
@file
@brief Helpers to enumerate the node outputs of an :epkg:`ONNX` model
and to extract the sub-model which computes a subset of its results.
"""

6from onnx import helper 

7 

8 

def enumerate_model_node_outputs(model, add_node=False):
    """
    Walks through the nodes of an :epkg:`ONNX` graph and yields
    the names of the results every node produces, in graph order.

    @param      model       :epkg:`ONNX` model (must expose ``graph``)
    @param      add_node    when True, yields ``(output name, node)``
                            pairs instead of bare output names
    @return                 generator over output names or pairs

    Because this is a generator, the TypeError raised when *model*
    has no ``graph`` attribute only surfaces on the first iteration.
    """
    if hasattr(model, "graph"):
        graph = model.graph
    else:
        raise TypeError(  # pragma: no cover
            "Parameter model is not an ONNX model but "
            "{}".format(type(model)))
    for current in graph.node:
        for name in current.output:
            if add_node:
                yield (name, current)
            else:
                yield name

26 

27 

def _select_needed_nodes(model, outputs):
    """
    Returns the nodes of *model* required to compute *outputs*.

    @param      model       :epkg:`ONNX` model
    @param      outputs     list of result names to compute
    @return                 list of nodes, in the same relative order
                            as in ``model.graph.node``

    Raises ValueError if a name in *outputs* is neither produced by
    a node nor declared as a graph input.
    """
    known_vars = set(enumerate_model_node_outputs(model))
    known_vars.update(inp.name for inp in model.graph.input)
    for out in outputs:
        if out not in known_vars:
            raise ValueError(  # pragma: no cover
                "Output '{}' not found in model.".format(out))

    # Every result name known to be required so far. Empty names denote
    # omitted optional inputs/outputs in ONNX and must never match.
    needed_vars = set(name for name in outputs if name)

    # Nodes are tracked by position, not by name: node names are optional
    # in ONNX and often empty, so keying on ``node.name`` would merge
    # unrelated nodes and could drop a node a kept node depends on.
    indexed_nodes = list(enumerate(model.graph.node))
    kept = set()

    # Fixed-point propagation: a node is needed as soon as one of its
    # outputs is needed, and then all its inputs become needed. Walking
    # backwards converges in one pass on a topologically sorted graph;
    # the outer loop keeps the result correct for any node order.
    changed = True
    while changed:
        changed = False
        for index, node in reversed(indexed_nodes):
            if index in kept:
                continue
            if not any(out in needed_vars for out in node.output):
                continue
            kept.add(index)
            needed_vars.update(name for name in node.input if name)
            changed = True

    # Keep the original order: ONNX requires graph nodes to be
    # topologically sorted, which the reversed walk above is not.
    return [node for index, node in indexed_nodes if index in kept]


def _make_output_infos(model, outputs):
    """
    Builds one ``ValueInfoProto`` per name in *outputs*.

    When a name is already declared as an output or an input of
    ``model.graph``, that declaration (including its type) is copied;
    otherwise only the name is set.
    """
    declared = {}
    for info in model.graph.input:
        declared[info.name] = info
    # Graph outputs take precedence over inputs sharing the same name.
    for info in model.graph.output:
        declared[info.name] = info

    var_out = []
    for out in outputs:
        value_info = helper.ValueInfoProto()
        if out in declared:
            value_info.CopyFrom(declared[out])
        else:
            value_info.name = out
        var_out.append(value_info)
    return var_out


def select_model_inputs_outputs(model, outputs=None, inputs=None):
    """
    Takes a model and changes its outputs.

    @param      model       :epkg:`ONNX` model
    @param      outputs     new outputs, a result name or a list of names
                            (required, None is rejected)
    @param      inputs      new inputs; only None (keep the same inputs)
                            is supported
    @return                 modified model

    The function removes the nodes which are not needed to compute
    the new outputs. Graph inputs and initializers are copied unchanged,
    and the kept nodes stay in their original relative order.

    Raises NotImplementedError if *inputs* is not None, RuntimeError if
    *outputs* is None, and ValueError if a requested output is neither
    a node output nor a graph input.
    """
    if inputs is not None:
        raise NotImplementedError(  # pragma: no cover
            "Parameter inputs must be None, changing the inputs "
            "is not implemented.")
    if outputs is None:
        raise RuntimeError(  # pragma: no cover
            "Parameter outputs cannot be None.")
    if not isinstance(outputs, list):
        outputs = [outputs]

    keep_nodes = _select_needed_nodes(model, outputs)
    var_out = _make_output_infos(model, outputs)

    graph = helper.make_graph(keep_nodes, model.graph.name, model.graph.input,
                              var_out, model.graph.initializer)
    onnx_model = helper.make_model(graph)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:  # pragma: no cover
        values = {p.key: p.value for p in model.metadata_props}
        helper.set_model_props(onnx_model, values)

    # make_model adds a default opset; replace it with the original ones.
    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    if len(onnx_model.graph.input) != len(model.graph.input):  # pylint: disable=E1101
        # ModelProto has no ``input`` field; the inputs live on the graph.
        raise RuntimeError(  # pragma: no cover
            "Input mismatch {} != {}".format(
                len(onnx_model.graph.input), len(model.graph.input)))  # pylint: disable=E1101
    return onnx_model