Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Implements a class able to compute the predictions
4from on an :epkg:`ONNX` model.
5"""
6from onnx import helper
def enumerate_model_node_outputs(model, add_node=False):
    """
    Walks the nodes of an ONNX model in graph order and yields the
    names of the outputs each node produces.

    @param      model       :epkg:`ONNX` model (must expose a ``graph``
                            attribute)
    @param      add_node    when False, yield bare output names;
                            when True, yield ``(output name, node)`` tuples
    @return                 generator over output names or tuples
    @raises     TypeError   if *model* has no ``graph`` attribute
                            (raised on first iteration, as this is a
                            generator)
    """
    if not hasattr(model, "graph"):
        raise TypeError(  # pragma: no cover
            "Parameter model is not an ONNX model but "
            "{}".format(type(model)))
    for current in model.graph.node:
        if add_node:
            for name in current.output:
                yield name, current
        else:
            yield from current.output
def select_model_inputs_outputs(model, outputs=None, inputs=None):
    """
    Takes a model and changes its outputs, removing every node which
    is not needed to compute the new outputs.

    @param      model       :epkg:`ONNX` model
    @param      outputs     new output name or list of output names
                            (required, must not be None)
    @param      inputs      new inputs; only None (keep the same inputs)
                            is supported for now
    @return                 modified model
    @raises     NotImplementedError if *inputs* is not None
    @raises     RuntimeError        if *outputs* is None
    @raises     ValueError          if a requested output is neither a
                                    node output nor a graph input

    The function removes unneeded nodes. Kept nodes stay in their
    original order so the new graph remains topologically sorted.
    When the model already declares a type for a selected output
    (graph outputs, inputs or value_info), that information is reused.
    """
    if inputs is not None:
        raise NotImplementedError(  # pragma: no cover
            "Parameter inputs must be None, changing the inputs "
            "is not implemented yet.")
    if outputs is None:
        raise RuntimeError(  # pragma: no cover
            "Parameter outputs cannot be None.")
    if not isinstance(outputs, list):
        outputs = [outputs]

    # mark_var[name] == 1 means the result `name` is needed.
    mark_var = {}
    for out in enumerate_model_node_outputs(model):
        mark_var[out] = 0
    for inp in model.graph.input:
        mark_var[inp.name] = 0
    for out in outputs:
        if out not in mark_var:
            raise ValueError(  # pragma: no cover
                "Output '{}' not found in model.".format(out))
        mark_var[out] = 1

    nodes = list(model.graph.node)
    # Nodes are tracked by position, not by name: ONNX node names are
    # optional and not guaranteed to be unique.
    mark_op = [False] * len(nodes)

    # Propagate the "needed" flag backwards from the requested outputs
    # until nothing changes. Walking the nodes in reverse lets a
    # topologically sorted graph converge in a single pass; the extra
    # passes only matter for graphs that are not sorted.
    modified = True
    while modified:
        modified = False
        for index in reversed(range(len(nodes))):
            if mark_op[index]:
                continue
            node = nodes[index]
            if not any(name and mark_var.get(name, 0) == 1
                       for name in node.output):
                continue
            mark_op[index] = True
            modified = True
            for name in node.input:
                # An empty name stands for an omitted optional input,
                # it does not reference any result.
                if name:
                    mark_var[name] = 1

    # Keep the original (topological) order of the nodes.
    keep_nodes = [node for node, keep in zip(nodes, mark_op) if keep]

    # Reuse known type/shape information for the new outputs.
    known_info = {}
    for container in (model.graph.value_info, model.graph.input,
                      model.graph.output):
        for info in container:
            known_info[info.name] = info

    var_out = []
    for out in outputs:
        value_info = helper.ValueInfoProto()
        if out in known_info:
            value_info.CopyFrom(known_info[out])
        else:
            value_info.name = out
        var_out.append(value_info)

    graph = helper.make_graph(keep_nodes, model.graph.name, model.graph.input,
                              var_out, model.graph.initializer)
    onnx_model = helper.make_model(graph)
    onnx_model.ir_version = model.ir_version
    onnx_model.producer_name = model.producer_name
    onnx_model.producer_version = model.producer_version
    onnx_model.domain = model.domain
    onnx_model.model_version = model.model_version
    onnx_model.doc_string = model.doc_string
    if len(model.metadata_props) > 0:  # pragma: no cover
        values = {p.key: p.value for p in model.metadata_props}
        helper.set_model_props(onnx_model, values)

    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in model.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    # Model-local functions (IR >= 8) must follow the nodes calling them.
    functions = getattr(model, "functions", None)
    if functions:
        onnx_model.functions.extend(functions)  # pylint: disable=E1101

    if len(onnx_model.graph.input) != len(model.graph.input):  # pylint: disable=E1101
        raise RuntimeError(  # pragma: no cover
            "Input mismatch {} != {}".format(
                len(onnx_model.graph.input),  # pylint: disable=E1101
                len(model.graph.input)))
    return onnx_model