Coverage for onnxcustom/utils/onnx_split.py: 92%

298 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 01:42 +0100

1""" 

2@file 

3@brief Helpers to split an ONNX models. 

4""" 

5import textwrap 

6import numpy 

7from onnx import ( # pylint: disable=E0611 

8 ModelProto, shape_inference, TensorProto, ValueInfoProto) 

9from onnx.helper import make_graph, make_model 

10 

11 

class OnnxSegment:
    """
    A segments of an onnx graph assuming
    it is the concatenation of all segments.

    :param parent: an instance of OnnxSplitting
    :param begin: names of the first extremity,
        None for the inputs of the main graph
    :param end: names of the second extremity,
        None for the outputs of the main graph
    :param size: total size of the segment
    :param involved: list of result names involved in this segment
    :param nodes: involved nodes, list of tuple `(int, NodeProt)`
    """

    def __init__(self, parent, begin, end, size=0, involved=None, nodes=None):
        # Both extremities must be result names (or None for the graph
        # inputs/outputs), validated with the same message for each side.
        for label, value in (("begin", begin), ("end", end)):
            if value is not None and not isinstance(value, str):
                raise ValueError(
                    f"{label}={value!r} must be a string or None.")
        if begin is None and end is None:
            # A segment spanning the whole model is meaningless here.
            raise ValueError(
                "A segment cannot contain this whole model, "
                "begin and end are both None.")
        if nodes is not None and len(nodes) == 0:
            raise ValueError(
                f"A segment has no node, begin={begin!r}, "
                f"end={end!r}, involved={involved!r}.")
        self.parent = parent
        self.begin = begin
        self.end = end
        self.involved = involved
        self.size = size
        self.nodes = nodes

    def __repr__(self):
        # Wrapped one-line summary: extremities, size and involved names.
        body = (f"{self.begin!r}, {self.end!r}, size={self.size!r}, "
                f"{self.involved!r})")
        wrapped = textwrap.wrap(body, subsequent_indent=" ")
        return f"{self.__class__.__name__}(...,\n " + "\n".join(wrapped)

52 

53 

class OnnxSplitting:
    """
    The final goal is to split an onnx model into
    equivalent pieces.

    :param onnx_model: onnx_model
    :param verbose: displays information during the split
    :param fLOG: logging function
    """

    def __init__(self, onnx_model, verbose=0, fLOG=None):
        self.onnx_model = onnx_model
        self.verbose = verbose
        self.fLOG = fLOG or print
        self._init()

    @staticmethod
    def _key(idn, node):
        # Unique identifier for a node: node names may be empty or
        # duplicated, so the index in the graph is appended.
        return f"{node.op_type}-{node.name}-{idn}"

    def _init(self):
        """
        Precomputes everything needed by the split: serialized sizes of
        initializers and nodes, the list of cutting points and the
        segments delimited by consecutive cutting points.
        """
        onnx_model = self.onnx_model
        if not isinstance(onnx_model, ModelProto):
            raise TypeError(
                f"onnx_model must a ModelProto not a {type(onnx_model)}.")
        node_list = list(enumerate(onnx_model.graph.node))

        # sizes: serialized byte size of every initializer and node,
        # used later to balance the segments
        sizes = {}
        for init in onnx_model.graph.initializer:
            sizes[init.name] = len(init.SerializeToString())
        for init in onnx_model.graph.sparse_initializer:
            sizes[init.name] = len(init.SerializeToString())

        for idn, node in node_list:
            sizes[self._key(idn, node)] = len(node.SerializeToString())
        self.sizes = sizes

        # only working for standard domain (supporting shape inference)
        for node in onnx_model.graph.node:
            if node.domain not in {'', 'ai.onnx', 'ai.onnx.ml'}:
                raise NotImplementedError(
                    f"Node {node.op_type!r} from domain {node.domain!r} "
                    f"is not supported yet.")

        # cut points: results breaking the connexity of the graph
        if self.verbose > 0:
            self.fLOG(
                f"[OnnxSplitting] look for cutting points in {len(node_list)} nodes.")

        self.cutting_points = self._get_cutting_points(node_list)

        if self.verbose:
            self.fLOG(
                f"[OnnxSplitting] # cuttings points: {len(self.cutting_points)}")

        # segments: one segment between two consecutive cutting points,
        # plus the leading (inputs -> first cut) and trailing
        # (last cut -> outputs) segments
        if self.verbose > 1:
            import tqdm  # pylint: disable=C0415
            loop = tqdm.tqdm(range(len(self.cutting_points)))
        else:
            loop = range(len(self.cutting_points))
        segments = []
        for i in loop:  # pylint: disable=C0200
            segments.append(
                self._make_segment(
                    None if i == 0 else self.cutting_points[i - 1],
                    self.cutting_points[i]))
        segments.append(self._make_segment(self.cutting_points[-1], None))
        self.segments = segments
        if self.verbose > 0:
            # BUG FIX: previously logged len(sizes), i.e. the number of
            # serialized objects (initializers + nodes), not the number
            # of segments.
            self.fLOG(f"[OnnxSplitting] # segments = {len(segments)}")
            self.fLOG("[OnnxSplitting] run shape_inference")
        self.shapes = shape_inference.infer_shapes(onnx_model)

        if self.verbose > 0:
            sizes = [seg.size for seg in self.segments]
            self.fLOG(f"[OnnxSplitting] # segments = {len(sizes)}, "
                      f"min,avg,max size=[{min(sizes)}, "
                      f"{sum(sizes) / len(sizes)}, {max(sizes)}]")

    @staticmethod
    def _connex_components(vertices, adja):
        """
        Computes connected components by label propagation.

        :param vertices: iterable of vertex identifiers
        :param adja: dictionary `{(a, b): 0 or 1}`, 1 meaning an
            active edge between vertices *a* and *b*
        :return: dictionary `{vertex: component index}`, two vertices
            joined by active edges share the same (minimal) index
        """
        vert = {v: i for i, v in enumerate(vertices)}
        more = True
        while more:
            more = False
            for k, v in adja.items():
                if v == 0:
                    continue
                a, b = k
                if vert[a] == vert[b]:
                    continue
                more = True
                # propagate the smaller component index
                if vert[a] < vert[b]:
                    vert[b] = vert[a]
                else:
                    vert[a] = vert[b]
        return vert

    @staticmethod
    def is_small(tensor):
        """
        Tells if a tensor is small. In that case, all edges to this
        constant are ignored when looking for cutting points.
        The algorithm assumes it can be duplicated in multiple parts.
        It is usually single float constant or shapes.
        """
        if isinstance(tensor, TensorProto):
            if tensor.HasField("segment"):
                raise ValueError(  # pragma: no cover
                    "Currently not supporting loading segments.")
            if tensor.data_type == TensorProto.UNDEFINED:  # pylint: disable=E1101
                raise TypeError(  # pragma: no cover
                    "The element type in the input tensor is not defined.")
            dims = tensor.dims
        elif isinstance(tensor, ValueInfoProto):
            dim = tensor.type.tensor_type.shape.dim
            dims = [d.dim_value for d in dim]
            # a non-integer dimension means a dynamic shape,
            # the tensor cannot be assumed small
            if any(map(lambda x: not isinstance(x, int), dims)):
                return False
        else:
            raise TypeError(  # pragma: no cover
                f"Unexpected type {type(tensor)}.")
        total = numpy.prod(dims)
        if total < 32:
            # Covers small constants, reshaping...
            return True
        return False

    def _get_cutting_points(self, node_list):
        """
        Finds the results whose removal splits the graph into exactly
        two connected components. Those results are the only places
        where the model can be cut into two consecutive pieces.

        :param node_list: list of tuples `(index, NodeProto)`
        :return: list of result names, in graph order
        """
        # let's avoid adding small constant
        small_tensors = {
            i.name: self.is_small(i)
            for i in self.onnx_model.graph.initializer}
        small_tensors.update({
            i.name: self.is_small(i)
            for i in self.onnx_model.graph.sparse_initializer})
        small_tensors.update({
            i.name: self.is_small(i)
            for i in self.onnx_model.graph.input})
        set_small = set(k for k, v in small_tensors.items() if v)
        for idn, node in node_list:
            # small constant nodes (no input, tiny serialization) can be
            # duplicated as well, so they and their outputs are ignored
            if len(node.input) == 0 and len(node.SerializeToString()) < 128:
                key = self._key(idn, node)
                set_small.add(key)
                set_small |= set(node.output)

        # adjacency matrix
        no_cutting = (
            set(small_tensors) |
            set(o.name for o in self.onnx_model.graph.output))
        constant_type = {'Constant', 'ConstantOfShape'}
        adja = {}
        vertices = set()
        ordered_names = []
        for idn, node in node_list:
            key = self._key(idn, node)
            if key in set_small:
                continue
            if (node.op_type not in constant_type and
                    len(node.output) == 1 and
                    len(node.input) > 0):
                # only single output can be cutting points
                ordered_names.extend(
                    o for o in node.output if o not in no_cutting)
            vertices.add(key)
            vertices |= set(i for i in node.input if i not in set_small)
            vertices |= set(o for o in node.output if o not in set_small)
            for i in node.input:
                if i in set_small:
                    continue
                adja[i, key] = 1
            for o in node.output:
                if o in set_small:
                    continue
                adja[key, o] = 1

        # checking the connexity: a candidate is a cutting point if
        # disabling its incoming edges leaves exactly two components
        if self.verbose > 1:
            import tqdm  # pylint: disable=C0415
            loop = tqdm.tqdm(ordered_names)
        else:
            loop = ordered_names
        cutting_points = []
        for name in loop:
            keys = []
            for a, b in adja:
                if b == name:
                    keys.append((a, b))

            # remove the links
            for a, b in keys:
                adja[a, b] = 0

            connex = self._connex_components(vertices, adja)
            connex_id = set(connex.values())
            if len(connex_id) == 2:
                cutting_points.append(name)

            # put back the links
            for a, b in keys:
                adja[a, b] = 1

        return cutting_points

    def _make_segment(self, name1, name2):
        """
        Builds the :class:`OnnxSegment` going from result *name1*
        (None for the graph inputs) to result *name2* (None for the
        graph outputs), walking backward from *name2* to collect the
        involved nodes and accumulate their sizes.
        """
        nodes = []
        for idn, node in enumerate(self.onnx_model.graph.node):
            nodes.append((idn, node))
            if name2 is not None and name2 in node.output:
                break

        if name2 is None:
            names = set(i.name for i in self.onnx_model.graph.output)
        else:
            names = {name2}

        size = 0
        subset = []
        # backward pass: keep only nodes producing a needed result
        for idn, node in reversed(nodes):
            if set(node.output) & names:
                size += self.sizes[self._key(idn, node)]
                if len(node.output) == 1 and node.output[0] == name1:
                    # this node belongs to the previous segment
                    continue
                subset.append((idn, node))
                if len(node.input) == 1 and node.input[0] == name1:
                    # reached the beginning of the segment
                    continue
                for i in node.input:
                    if i in self.sizes:
                        size += self.sizes[i]
                names |= set(node.input)
        subset.sort()  # original order must be kept
        involved = names if name2 is None else names - {name2}
        return OnnxSegment(self, begin=name1, end=name2, involved=involved,
                           size=size, nodes=subset)

    def _split_2(self, a, b):
        """
        Splits the segments into two groups of the same size.

        :param a: first segment (included)
        :param b: second segment (excluded)
        :return: split index
        """
        if a >= b - 1:
            raise RuntimeError(f"a={a}, b={b}, unable to split.")
        if a == b - 2:
            return a + 1
        sizes = numpy.array([s.size for s in self.segments[a:b]])
        # sizes_for[i]: total size before the cut, sizes_bck[i]: after;
        # the best cut minimizes the imbalance between the two sides
        sizes_for = numpy.cumsum(sizes)
        sizes_bck = numpy.cumsum(sizes[::-1])[::-1].copy()
        diff = numpy.abs(sizes_bck - sizes_for)
        pos = numpy.argmin(diff)
        # pos is the beginning of the interval
        pos += 1
        pos += a
        # keep the cut strictly inside ]a, b[
        if pos == a:
            pos = a + 1
        elif pos == b:
            pos = b - 1
        return pos

    def split_segment(self, n_parts=None, cut_points=None):
        """
        Splits the segments into `n_parts` segments

        :param n_parts: number of parts to get
        :param cut_points: uses this particular cut points
        :return: list of segments indices
        """
        if n_parts is None and cut_points is None:
            raise ValueError("n_parts or cut_points must be specified.")
        if n_parts is not None and cut_points is not None:
            raise ValueError("n_parts and cut_points cannot "
                             "be specified at the same time.")
        if cut_points is not None:
            # user-defined cuts: map each name to the segment it begins
            possible = set(self.cutting_points)
            for name in cut_points:
                if name not in possible:
                    text = "\n".join(textwrap.wrap(str(self.cutting_points)))
                    raise ValueError(
                        f"Cut point {name!r} is not considered as a cutting "
                        f"points. Possible canditates:\n{text}")
            memo = {s.begin: i for i, s in enumerate(
                self.segments) if s.begin is not None}
            extremities = [0]
            for name in cut_points:
                extremities.append(memo[name])
            extremities.append(len(self.segments))
            return extremities

        if self.verbose > 10:
            self.fLOG("[OnnxSplitting] cutting points")
            self.fLOG("\n".join(textwrap.wrap(
                ", ".join(map(str, self.cutting_points)))))
        # automatic split: halve every interval until n_parts is reached,
        # n_parts must be a power of 2
        extremities = [0, len(self.segments)]
        n = n_parts
        while n > 1:
            if n % 2 != 0:
                raise NotImplementedError(
                    f"n_parts={n_parts} is not a power of 2.")
            new_ext = [extremities[0]]
            for i in range(1, len(extremities)):
                a, b = extremities[i - 1:i + 1]
                if self.verbose > 1:
                    size = sum(s.size for s in self.segments[a:b])
                    names = self.segments[a].begin, self.segments[b - 1].end
                    self.fLOG(f"[OnnxSplitting] split into n={n}, from a={a} to b={b}, "
                              f"size={size}, {names[0]!r} -> {names[1]!r}")
                pos = self._split_2(a, b)
                if self.verbose > 1:
                    size_a = sum(s.size for s in self.segments[a:pos])
                    size_b = sum(s.size for s in self.segments[pos:b])
                    self.fLOG(f"[OnnxSplitting] found pos={pos}, size_1={size_a}, "
                              f"size_2={size_b}={size_b/size:1.2f}, "
                              f"split={self.segments[pos].begin!r}")
                new_ext.extend([pos, b])
            extremities = new_ext
            n = n // 2
        return extremities

    def make_onnx(self, extremities):
        """
        Builds onnx subparts based on the segmentation
        defined by extremities.

        :param extremities: example, `[0, 3, 5]`,
            first onnx part contains segments `0:3=[0, 1, 2]`,
            second onnx part contains segments `3:5=[3, 4]`
        :return: list of onnx subgraphs (:epkg:`ModelProto`)
        """
        res = []
        for i in range(1, len(extremities)):
            a, b = extremities[i - 1:i + 1]
            onx = self._make_onnx(a, b, i - 1)
            res.append(onx)
            if self.verbose > 0:
                n_nodes = len(self.onnx_model.graph.node)
                total = sum(s.size for s in self.segments)
                size = sum(self.segments[i].size for i in range(a, b))
                self.fLOG(f"[OnnxSplitting] part {i}: "
                          f"#nodes={len(onx.graph.node)}"  # pylint: disable=E1101
                          f"/{n_nodes}, size={size}/{total}={size/total:1.2f}")
        return res

    def _make_onnx(self, a, b, index=None):
        """
        Builds one onnx subpart including segments from a to b (excluded).
        """
        if index is None:
            index = a

        # common parts: shapes inferred for every intermediate result,
        # needed to type the new inputs/outputs at the cut points
        value_info = {o.name: o for o in self.onnx_model.graph.output}
        value_info.update({
            info.name: info
            for info in self.shapes.graph.value_info})  # pylint: disable=E1101

        segs = self.segments[a:b]
        involved = set()
        for seg in segs:
            involved |= seg.involved

        # initiliazers
        new_inits = [init for init in self.onnx_model.graph.initializer
                     if init.name in involved]
        new_sp_inits = [init for init in self.onnx_model.graph.sparse_initializer
                        if init.name in involved]

        # nodes
        nodes = []
        for seg in segs:
            for _, node in seg.nodes:
                nodes.append(node)

        # inputs, outputs: the first part keeps the model inputs,
        # the others take the cutting point as extra input; the last
        # part keeps the model outputs, the others expose the cut
        existing_inputs = [i for i in self.onnx_model.graph.input
                           if i.name in involved]
        if a == 0:
            new_inputs = existing_inputs
        else:
            new_inputs = [value_info[segs[0].begin]] + existing_inputs

        if b == len(self.segments):
            new_outputs = [i for i in self.onnx_model.graph.output
                           if i.name in involved]
        else:
            new_outputs = [value_info[segs[-1].end]]

        model = self.onnx_model
        graph = make_graph(
            nodes, f"{model.graph.name}-{index}",
            new_inputs, new_outputs,
            new_inits, doc_string=model.graph.doc_string,
            sparse_initializer=new_sp_inits,
            value_info=model.graph.value_info)
        new_model = make_model(graph, opset_imports=model.opset_import)
        new_model.ir_version = model.ir_version
        new_model.producer_name = model.producer_name
        new_model.producer_version = model.producer_version
        new_model.domain = model.domain
        new_model.model_version = model.model_version
        new_model.doc_string = model.doc_string
        return new_model

459 

460 

def split_onnx(onnx_model, n_parts=None, cut_points=None,
               verbose=0, stats=False, fLOG=None):
    """
    Splits an ONNX model into *n_parts* consecutive subgraphs.
    Chained altogether, they are equivalent to the given model.

    :param onnx_model: onnx model
    :param n_parts: number of subgraphs, must be a power of 2
    :param cut_points: uses different cutting points than the ones
        the algorithm can find, every name must belong to the
        available set of cutting points
    :param verbose: display information related to the split
    :param stats: returns statistics as well, return of the
        function is a tuple
    :param fLOG: logging function
    :return: list of onnx model, or a tuple
        `(list of onnx model, statistics)` if *stats* is True
    :raises NotImplementedError: if the model contains local functions
    :raises TypeError: if *n_parts* or *cut_points* has the wrong type
    :raises RuntimeError: if the model has fewer cutting points
        than requested parts
    """
    log = fLOG or print
    if len(onnx_model.functions) > 0:
        # BUG FIX: the names were interpolated with a bare generator
        # expression inside the f-string, which is invalid syntax;
        # the names are now joined explicitly.
        names = ", ".join(f.name for f in onnx_model.functions)
        raise NotImplementedError(
            f"The function does not work if the model contains function: "
            f"{names}.")
    if n_parts is not None and not isinstance(n_parts, int):
        raise TypeError(
            f"n_parts must be None or an integer not {type(n_parts)}.")
    if cut_points is not None and not isinstance(cut_points, (list, tuple)):
        # BUG FIX: the message reported type(n_parts) instead of
        # type(cut_points).
        raise TypeError(
            f"cut_points must be None or a list not {type(cut_points)}.")
    if verbose > 0:
        log(f"[split_onnx] prepare splitting "
            f"{len(onnx_model.graph.node)} nodes in {n_parts} parts.")
    spl_onnx = OnnxSplitting(onnx_model, verbose=verbose, fLOG=log)
    if n_parts is not None and len(spl_onnx.cutting_points) < n_parts:
        raise RuntimeError(  # pragma: no cover
            f"Unable to split the onnx model, there are less cutting points "
            f"{len(spl_onnx.cutting_points)} than the number of requested "
            f"splits ({n_parts}).")
    if verbose > 0:
        log(f"[split_onnx] starts splitting "
            f"{len(onnx_model.graph.node)} nodes in {n_parts} parts.")
    exts = spl_onnx.split_segment(n_parts, cut_points=cut_points)
    if verbose > 0:
        names = [spl_onnx.segments[i].begin for i in exts[1:-1]]
        log(f"[split_onnx] splits: {exts}, names={names}")
    res = spl_onnx.make_onnx(exts)
    if stats:
        # statistics useful to inspect how the split was done
        more = dict(
            split=spl_onnx,
            segments=[dict(size=s.size, nodes=len(s.nodes),
                           involved=s.involved)
                      for s in spl_onnx.segments],
            cutting_points=spl_onnx.cutting_points,
            extremities=exts,
            split_points=[spl_onnx.segments[e].begin for e in exts[1:-1]])
        return res, more
    return res

517 return res