Coverage for onnxcustom/utils/onnx_split.py: 92%
298 statements
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 01:42 +0100
1"""
2@file
3@brief Helpers to split an ONNX models.
4"""
5import textwrap
6import numpy
7from onnx import ( # pylint: disable=E0611
8 ModelProto, shape_inference, TensorProto, ValueInfoProto)
9from onnx.helper import make_graph, make_model
class OnnxSegment:
    """
    A segments of an onnx graph assuming
    it is the concatenation of all segments.

    :param parent: an instance of OnnxSplitting
    :param begin: names of the first extremity,
        None for the inputs of the main graph
    :param end: names of the second extremity,
        None for the outputs of the main graph
    :param size: total size of the segment
    :param involved: list of result names involved in this segment
    :param nodes: involved nodes, list of tuple `(int, NodeProt)`
    """

    def __init__(self, parent, begin, end, size=0, involved=None, nodes=None):
        # Both extremities must be result names (or None for the graph
        # border), and at least one of them must be set.
        for label, value in (("begin", begin), ("end", end)):
            if value is not None and not isinstance(value, str):
                raise ValueError(
                    f"{label}={value!r} must be a string or None.")
        if begin is None and end is None:
            raise ValueError(
                "A segment cannot contain this whole model, "
                "begin and end are both None.")
        if nodes is not None and not nodes:
            raise ValueError(
                f"A segment has no node, begin={begin!r}, "
                f"end={end!r}, involved={involved!r}.")
        self.parent = parent
        self.begin = begin
        self.end = end
        self.involved = involved
        self.size = size
        self.nodes = nodes

    def __repr__(self):
        # Summary of the segment: extremities, size and involved results,
        # wrapped to keep the representation readable.
        body = (f"{self.begin!r}, {self.end!r}, size={self.size!r}, "
                f"{self.involved!r})")
        wrapped = textwrap.wrap(body, subsequent_indent=" ")
        return f"{self.__class__.__name__}(...,\n " + "\n".join(wrapped)
class OnnxSplitting:
    """
    The final goal is to split an onnx model into
    equivalent pieces.

    :param onnx_model: onnx_model
    :param verbose: displays information during the split
    :param fLOG: logging function
    """

    def __init__(self, onnx_model, verbose=0, fLOG=None):
        self.onnx_model = onnx_model
        self.verbose = verbose
        self.fLOG = fLOG or print
        self._init()

    @staticmethod
    def _key(idn, node):
        # Unique identifier for a node: operator type, node name and the
        # node position in the graph (names alone may be empty or repeated).
        return f"{node.op_type}-{node.name}-{idn}"

    def _init(self):
        """
        Precomputes everything the split needs: serialized sizes,
        cutting points, segments and inferred shapes.

        :raises TypeError: if the model is not a :epkg:`ModelProto`
        :raises NotImplementedError: if a node belongs to a custom domain
        """
        onnx_model = self.onnx_model
        if not isinstance(onnx_model, ModelProto):
            raise TypeError(
                f"onnx_model must a ModelProto not a {type(onnx_model)}.")
        node_list = list(enumerate(onnx_model.graph.node))

        # sizes: serialized byte size of every initializer and node,
        # used later to build segments of balanced sizes
        sizes = {}
        for init in onnx_model.graph.initializer:
            sizes[init.name] = len(init.SerializeToString())
        for init in onnx_model.graph.sparse_initializer:
            sizes[init.name] = len(init.SerializeToString())
        for idn, node in node_list:
            sizes[self._key(idn, node)] = len(node.SerializeToString())
        self.sizes = sizes

        # only working for standard domain (supporting shape inference)
        for node in onnx_model.graph.node:
            if node.domain not in {'', 'ai.onnx', 'ai.onnx.ml'}:
                raise NotImplementedError(
                    f"Node {node.op_type!r} from domain {node.domain!r} "
                    f"is not supported yet.")

        # cut points: results breaking the connexity of the graph
        if self.verbose > 0:
            self.fLOG(
                f"[OnnxSplitting] look for cutting points in {len(node_list)} nodes.")
        self.cutting_points = self._get_cutting_points(node_list)

        if self.verbose:
            self.fLOG(
                f"[OnnxSplitting] # cuttings points: {len(self.cutting_points)}")

        # segments: one segment between two consecutive cutting points,
        # plus one segment on each side of the first/last cutting point
        if self.verbose > 1:
            import tqdm  # pylint: disable=C0415
            loop = tqdm.tqdm(range(len(self.cutting_points)))
        else:
            loop = range(len(self.cutting_points))
        segments = []
        for i in loop:  # pylint: disable=C0200
            segments.append(
                self._make_segment(
                    None if i == 0 else self.cutting_points[i - 1],
                    self.cutting_points[i]))
        segments.append(self._make_segment(self.cutting_points[-1], None))
        self.segments = segments
        if self.verbose > 0:
            # bug fix: the original logged len(sizes), the number of sized
            # objects (initializers + nodes), not the number of segments
            self.fLOG(f"[OnnxSplitting] # segments = {len(segments)}")
            self.fLOG("[OnnxSplitting] run shape_inference")
        self.shapes = shape_inference.infer_shapes(onnx_model)

        if self.verbose > 0:
            sizes = [seg.size for seg in self.segments]
            self.fLOG(f"[OnnxSplitting] # segments = {len(sizes)}, "
                      f"min,avg,max size=[{min(sizes)}, "
                      f"{sum(sizes) / len(sizes)}, {max(sizes)}]")

    @staticmethod
    def _connex_components(vertices, adja):
        """
        Computes the connected components of an undirected graph.

        :param vertices: iterable of vertex names
        :param adja: adjacency as a dictionary `(a, b) -> 0 or 1`,
            0 meaning the edge is (temporarily) removed
        :return: dictionary `vertex -> component id`; two vertices share
            the same id iff they belong to the same component
        """
        vert = {v: i for i, v in enumerate(vertices)}
        more = True
        while more:
            more = False
            for k, v in adja.items():
                if v == 0:
                    continue
                a, b = k
                if vert[a] == vert[b]:
                    continue
                more = True
                # propagate the smallest component id along the edge
                if vert[a] < vert[b]:
                    vert[b] = vert[a]
                else:
                    vert[a] = vert[b]
        return vert

    @staticmethod
    def is_small(tensor):
        """
        Tells if a tensor is small. In that case, all edges to this
        constant are ignored when looking for cutting points.
        The algorithm assumes it can be duplicated in multiple parts.
        It is usually single float constant or shapes.
        """
        if isinstance(tensor, TensorProto):
            if tensor.HasField("segment"):
                raise ValueError(  # pragma: no cover
                    "Currently not supporting loading segments.")
            if tensor.data_type == TensorProto.UNDEFINED:  # pylint: disable=E1101
                raise TypeError(  # pragma: no cover
                    "The element type in the input tensor is not defined.")
            dims = tensor.dims
        elif isinstance(tensor, ValueInfoProto):
            dim = tensor.type.tensor_type.shape.dim
            dims = [d.dim_value for d in dim]
            # a symbolic (non-integer) dimension means the size is unknown
            if any(map(lambda x: not isinstance(x, int), dims)):
                return False
        else:
            raise TypeError(  # pragma: no cover
                f"Unexpected type {type(tensor)}.")
        total = numpy.prod(dims)
        if total < 32:
            # Covers small constants, reshaping...
            return True
        return False

    def _get_cutting_points(self, node_list):
        """
        Looks for the results which, once removed, split the graph into
        exactly two connected components. Small constants and graph
        outputs cannot be cutting points.

        :param node_list: list of tuples `(position, NodeProto)`
        :return: list of result names, in graph order
        """
        # let's avoid adding small constant
        small_tensors = {
            i.name: self.is_small(i)
            for i in self.onnx_model.graph.initializer}
        small_tensors.update({
            i.name: self.is_small(i)
            for i in self.onnx_model.graph.sparse_initializer})
        small_tensors.update({
            i.name: self.is_small(i)
            for i in self.onnx_model.graph.input})
        set_small = set(k for k, v in small_tensors.items() if v)
        for idn, node in node_list:
            # inputless nodes with a tiny serialization behave like small
            # constants: they can be duplicated in every part
            if len(node.input) == 0 and len(node.SerializeToString()) < 128:
                key = self._key(idn, node)
                set_small.add(key)
                set_small |= set(node.output)

        # adjacency matrix
        no_cutting = (
            set(small_tensors) |
            set(o.name for o in self.onnx_model.graph.output))
        constant_type = {'Constant', 'ConstantOfShape'}
        adja = {}
        vertices = set()
        ordered_names = []
        for idn, node in node_list:
            key = self._key(idn, node)
            if key in set_small:
                continue
            if (node.op_type not in constant_type and
                    len(node.output) == 1 and
                    len(node.input) > 0):
                # only single output can be cutting points
                ordered_names.extend(
                    o for o in node.output if o not in no_cutting)
            vertices.add(key)
            vertices |= set(i for i in node.input if i not in set_small)
            vertices |= set(o for o in node.output if o not in set_small)
            for i in node.input:
                if i in set_small:
                    continue
                adja[i, key] = 1
            for o in node.output:
                if o in set_small:
                    continue
                adja[key, o] = 1

        # checking the connexity: a candidate is kept if removing its
        # incoming edges leaves exactly two components
        if self.verbose > 1:
            import tqdm  # pylint: disable=C0415
            loop = tqdm.tqdm(ordered_names)
        else:
            loop = ordered_names
        cutting_points = []
        for name in loop:
            keys = []
            for a, b in adja:
                if b == name:
                    keys.append((a, b))

            # remove the links
            for a, b in keys:
                adja[a, b] = 0

            connex = self._connex_components(vertices, adja)
            connex_id = set(connex.values())
            if len(connex_id) == 2:
                cutting_points.append(name)

            # put back the links
            for a, b in keys:
                adja[a, b] = 1

        return cutting_points

    def _make_segment(self, name1, name2):
        """
        Builds the :class:`OnnxSegment` going from result *name1*
        (excluded, None means the graph inputs) to result *name2*
        (included, None means the graph outputs).
        """
        nodes = []
        for idn, node in enumerate(self.onnx_model.graph.node):
            nodes.append((idn, node))
            if name2 is not None and name2 in node.output:
                break

        if name2 is None:
            names = set(i.name for i in self.onnx_model.graph.output)
        else:
            names = {name2}

        # walk the nodes backwards and keep those contributing to `names`
        size = 0
        subset = []
        for idn, node in reversed(nodes):
            if set(node.output) & names:
                size += self.sizes[self._key(idn, node)]
                if len(node.output) == 1 and node.output[0] == name1:
                    # this node produces the previous cut point, it belongs
                    # to the previous segment
                    continue
                subset.append((idn, node))
                if len(node.input) == 1 and node.input[0] == name1:
                    continue
                for i in node.input:
                    if i in self.sizes:
                        size += self.sizes[i]
                names |= set(node.input)
        subset.sort()  # original order must be kept
        involved = names if name2 is None else names - {name2}
        return OnnxSegment(self, begin=name1, end=name2, involved=involved,
                           size=size, nodes=subset)

    def _split_2(self, a, b):
        """
        Splits the segments into two groups of the same size.

        :param a: first segment (included)
        :param b: second segment (excluded)
        :return: split index
        """
        if a >= b - 1:
            raise RuntimeError(f"a={a}, b={b}, unable to split.")
        if a == b - 2:
            return a + 1
        # cumulated sizes from the left and from the right, the best split
        # minimizes the absolute difference between both sides
        sizes = numpy.array([s.size for s in self.segments[a:b]])
        sizes_for = numpy.cumsum(sizes)
        sizes_bck = numpy.cumsum(sizes[::-1])[::-1].copy()
        diff = numpy.abs(sizes_bck - sizes_for)
        pos = numpy.argmin(diff)
        # pos is the beginning of the interval
        pos += 1
        pos += a
        if pos == a:
            pos = a + 1
        elif pos == b:
            pos = b - 1
        return pos

    def split_segment(self, n_parts=None, cut_points=None):
        """
        Splits the segments into `n_parts` segments

        :param n_parts: number of parts to get
        :param cut_points: uses this particular cut points
        :return: list of segments indices
        """
        if n_parts is None and cut_points is None:
            raise ValueError("n_parts or cut_points must be specified.")
        if n_parts is not None and cut_points is not None:
            raise ValueError("n_parts and cut_points cannot "
                             "be specified at the same time.")
        if cut_points is not None:
            possible = set(self.cutting_points)
            for name in cut_points:
                if name not in possible:
                    text = "\n".join(textwrap.wrap(str(self.cutting_points)))
                    raise ValueError(
                        f"Cut point {name!r} is not considered as a cutting "
                        f"points. Possible candidates:\n{text}")
            # NOTE(review): assumes cut_points are given in graph order,
            # the extremities are not sorted afterwards
            memo = {s.begin: i for i, s in enumerate(
                self.segments) if s.begin is not None}
            extremities = [0]
            for name in cut_points:
                extremities.append(memo[name])
            extremities.append(len(self.segments))
            return extremities

        if self.verbose > 10:
            self.fLOG("[OnnxSplitting] cutting points")
            self.fLOG("\n".join(textwrap.wrap(
                ", ".join(map(str, self.cutting_points)))))
        # halve every interval until n_parts intervals are obtained,
        # n_parts must then be a power of two
        extremities = [0, len(self.segments)]
        n = n_parts
        while n > 1:
            if n % 2 != 0:
                raise NotImplementedError(
                    f"n_parts={n_parts} is not a power of 2.")
            new_ext = [extremities[0]]
            for i in range(1, len(extremities)):
                a, b = extremities[i - 1:i + 1]
                if self.verbose > 1:
                    size = sum(s.size for s in self.segments[a:b])
                    names = self.segments[a].begin, self.segments[b - 1].end
                    self.fLOG(f"[OnnxSplitting] split into n={n}, from a={a} to b={b}, "
                              f"size={size}, {names[0]!r} -> {names[1]!r}")
                pos = self._split_2(a, b)
                if self.verbose > 1:
                    size_a = sum(s.size for s in self.segments[a:pos])
                    size_b = sum(s.size for s in self.segments[pos:b])
                    self.fLOG(f"[OnnxSplitting] found pos={pos}, size_1={size_a}, "
                              f"size_2={size_b}={size_b/size:1.2f}, "
                              f"split={self.segments[pos].begin!r}")
                new_ext.extend([pos, b])
            extremities = new_ext
            n = n // 2
        return extremities

    def make_onnx(self, extremities):
        """
        Builds onnx subparts based on the segmentation
        defined by extremities.

        :param extremities: example, `[0, 3, 5]`,
            first onnx part contains segments `0:3=[0, 1, 2]`,
            second onnx part contains segments `3:5=[3, 4]`
        :return: list of onnx subgraphs (:epkg:`ModelProto`)
        """
        res = []
        for i in range(1, len(extremities)):
            a, b = extremities[i - 1:i + 1]
            onx = self._make_onnx(a, b, i - 1)
            res.append(onx)
            if self.verbose > 0:
                n_nodes = len(self.onnx_model.graph.node)
                total = sum(s.size for s in self.segments)
                size = sum(self.segments[i].size for i in range(a, b))
                self.fLOG(f"[OnnxSplitting] part {i}: "
                          f"#nodes={len(onx.graph.node)}"  # pylint: disable=E1101
                          f"/{n_nodes}, size={size}/{total}={size/total:1.2f}")
        return res

    def _make_onnx(self, a, b, index=None):
        """
        Builds one onnx subpart including segments from a to b (excluded).
        """
        if index is None:
            index = a

        # common parts: known shapes for every intermediate result
        value_info = {o.name: o for o in self.onnx_model.graph.output}
        value_info.update({
            info.name: info
            for info in self.shapes.graph.value_info})  # pylint: disable=E1101

        segs = self.segments[a:b]
        involved = set()
        for seg in segs:
            involved |= seg.involved

        # initializers restricted to the results used by these segments
        new_inits = [init for init in self.onnx_model.graph.initializer
                     if init.name in involved]
        new_sp_inits = [init for init in self.onnx_model.graph.sparse_initializer
                        if init.name in involved]

        # nodes
        nodes = []
        for seg in segs:
            for _, node in seg.nodes:
                nodes.append(node)

        # inputs, outputs: the cut point becomes an input (resp. output)
        # of the subgraph unless it is the first (resp. last) part
        existing_inputs = [i for i in self.onnx_model.graph.input
                           if i.name in involved]
        if a == 0:
            new_inputs = existing_inputs
        else:
            new_inputs = [value_info[segs[0].begin]] + existing_inputs

        if b == len(self.segments):
            new_outputs = [i for i in self.onnx_model.graph.output
                           if i.name in involved]
        else:
            new_outputs = [value_info[segs[-1].end]]

        model = self.onnx_model
        graph = make_graph(
            nodes, f"{model.graph.name}-{index}",
            new_inputs, new_outputs,
            new_inits, doc_string=model.graph.doc_string,
            sparse_initializer=new_sp_inits,
            value_info=model.graph.value_info)
        new_model = make_model(graph, opset_imports=model.opset_import)
        new_model.ir_version = model.ir_version
        new_model.producer_name = model.producer_name
        new_model.producer_version = model.producer_version
        new_model.domain = model.domain
        new_model.model_version = model.model_version
        new_model.doc_string = model.doc_string
        return new_model
def split_onnx(onnx_model, n_parts=None, cut_points=None,
               verbose=0, stats=False, fLOG=None):
    """
    Splits an ONNX model into *n_parts* consecutive subgraphs.
    Chained altogether, they are equivalent to the given model.

    :param onnx_model: onnx model
    :param n_parts: number of subgraphs
    :param cut_points: use different cutting points that the ones
        the algorithm can found, it must be in the available set
        of cutting points
    :param verbose: display information related to the split
    :param stats: returns statistics as well, return of the
        function is a tuple
    :param fLOG: logging function
    :return: list of onnx model
    :raises NotImplementedError: if the model contains local functions
    :raises TypeError: if *n_parts* or *cut_points* has the wrong type
    """
    if len(onnx_model.functions) > 0:
        # bug fix: the original embedded a bare generator expression in an
        # f-string (`f"{f.name for f in ...}"`), which is invalid syntax
        names = {f.name for f in onnx_model.functions}
        raise NotImplementedError(
            f"The function does not work if the model contains function: "
            f"{names}.")
    if n_parts is not None and not isinstance(n_parts, int):
        raise TypeError(
            f"n_parts must be None or an integer not {type(n_parts)}.")
    if cut_points is not None and not isinstance(cut_points, (list, tuple)):
        # bug fix: the message reported type(n_parts) instead of
        # type(cut_points)
        raise TypeError(
            f"cut_points must be None or a list not {type(cut_points)}.")
    if verbose > 0:
        (fLOG or print)(
            f"[split_onnx] prepare splitting "
            f"{len(onnx_model.graph.node)} nodes in {n_parts} parts.")
    spl_onnx = OnnxSplitting(onnx_model, verbose=verbose, fLOG=fLOG or print)
    if n_parts is not None and len(spl_onnx.cutting_points) < n_parts:
        raise RuntimeError(  # pragma: no cover
            f"Unable to split the onnx model, there are less cutting points "
            f"{len(spl_onnx.cutting_points)} than the number of requested "
            f"splits ({n_parts}).")
    if verbose > 0:
        (fLOG or print)(
            f"[split_onnx] starts splitting "
            f"{len(onnx_model.graph.node)} nodes in {n_parts} parts.")
    exts = spl_onnx.split_segment(n_parts, cut_points=cut_points)
    if verbose > 0:
        names = [spl_onnx.segments[i].begin for i in exts[1:-1]]
        (fLOG or print)(f"[split_onnx] splits: {exts}, names={names}")
    res = spl_onnx.make_onnx(exts)
    if stats:
        more = dict(
            split=spl_onnx,
            segments=[dict(size=s.size, nodes=len(s.nodes),
                           involved=s.involved)
                      for s in spl_onnx.segments],
            cutting_points=spl_onnx.cutting_points,
            extremities=exts,
            split_points=[spl_onnx.segments[e].begin for e in exts[1:-1]])
        return res, more
    return res