Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

"""
@file
@brief Shape object: symbolic dimensions, operators combining
dimensions, and shapes built from them.
"""
import numpy

6 

7 

class BaseDimensionShape:
    """
    Common base of @see cl DimensionObject,
    @see cl ShapeOperator and @see cl ShapeObject.
    Both methods below must be overridden by subclasses.
    """

    def to_string(self, use_x=True):
        """
        Returns a string representation of the object.

        @param      use_x   rendering option for unknown dimensions,
                            interpreted by subclasses
        @raise      NotImplementedError always, in this base class
        """
        raise NotImplementedError()

    def evaluate(self, **kwargs):
        """
        Reduces the expression to a number or a string.

        @param      kwargs  values for the variables
        @raise      NotImplementedError always, in this base class
        """
        raise NotImplementedError()  # pragma: no cover

26 

27 

class ShapeOperator(BaseDimensionShape):
    """
    Base class for all shape operators. An operator combines
    @see cl DimensionObject arguments, numerically with *fct*
    when every argument evaluates to a number, symbolically
    (see method *_evaluate_string_*) otherwise.
    """

    def __init__(self, name, fct, fct_string, *args):
        """
        @param      name        display name of the operator
        @param      fct         function doing the operator
                                if arguments are numeric
        @param      fct_string  function represented as a string
        @param      args        arguments of the operator, all of them
                                must be @see cl DimensionObject
        @raise      TypeError   if one argument is not a @see cl DimensionObject
        """
        self._name = name
        self._fct = fct
        self._fct_string = fct_string
        self._args = args
        for a in self._args:
            if not isinstance(a, DimensionObject):
                raise TypeError(
                    "All arguments must be of type DimensionObject not '{}'."
                    "".format(type(a)))

    def __repr__(self):
        """
        usual, the output mirrors the constructor signature
        ``cls(name, fct, fct_string, *args)``
        """
        # fct itself cannot be displayed, its string form stands for it,
        # then appears a second time, quoted, as fct_string.
        parts = ["'{}'".format(self._name), str(self._fct_string),
                 "'{}'".format(self._fct_string)]
        # args is unpacked so that each argument appears as its own
        # positional parameter instead of a single tuple.
        parts.extend(repr(a) for a in self._args)
        return "{}({})".format(self.__class__.__name__, ", ".join(parts))

    def to_string(self, use_x=True):
        """
        Displays as a string.

        @param      use_x   see @see cl DimensionObject
        @return             a string
        @raise      NotImplementedError unless overridden
        """
        raise NotImplementedError(  # pragma: no cover
            "Operator '{}' does not implement 'to_string': {}.".format(
                self.__class__.__name__, repr(self)))

    def evaluate(self, **kwargs):
        """
        Evaluates the operator.

        @param      kwargs  value for the variables.
        @return             string or integer
        @raise      RuntimeError if *fct* fails with a TypeError
        """
        args = []
        has_string = False
        for a in self._args:
            a = DimensionObject._same_(a)
            v = a.evaluate(**kwargs)
            if isinstance(v, str):
                # at least one variable remains unresolved
                has_string = True
            args.append(v)
        if has_string:
            res = self._evaluate_string_(args, **kwargs)
        else:
            try:
                res = self._fct(*args)
            except TypeError as e:
                raise RuntimeError(
                    "Unable to evaluate operator {} due to {}".format(repr(self), e)) from e
        return res

    def _evaluate_string_(self, args, **kwargs):
        """
        Evaluates the operator assuming some of them are still strings.

        @param      args    arguments extracted by method *evaluate*
        @param      kwargs  value for the variables.
        @return             string or integer
        @raise      NotImplementedError unless overridden
        """
        raise NotImplementedError(
            "This function must be overwritten.")  # pragma: no cover

104 

105 

class ShapeBinaryOperator(ShapeOperator):
    """
    Base class for binary shape operators displayed with an
    infix notation (``x+y``, ``x*y``, ...).
    """

    def __init__(self, name, fct, fct_string, x, y):
        """
        @param      name        display name of the operator
        @param      fct         function doing the operator
                                if arguments are numeric
        @param      fct_string  function represented as a string
        @param      x           first argument (@see cl DimensionObject)
        @param      y           second argument (@see cl DimensionObject)
        """
        ShapeOperator.__init__(self, name, fct, fct_string, x, y)
        if isinstance(x, tuple):
            raise TypeError('x cannot be a tuple')  # pragma: no cover
        if isinstance(y, tuple):
            raise TypeError('y cannot be a tuple')  # pragma: no cover

    def _to_string1(self, x, y):
        # Both dimensions are integers: the value is computed.
        value = self._fct(x._dim, y._dim)
        return DimensionObject(value).to_string()

    def _to_string2(self, x, y):
        # One integer and one variable.
        text = "{}{}{}".format(x._dim, self._name, y._dim)
        return DimensionObject(text).to_string()

    def _to_string2b(self, x, y):
        # Two variables, each one is wrapped into parenthesis.
        text = "({}){}({})".format(x._dim, self._name, y._dim)
        return DimensionObject(text).to_string()

    def _to_string3(self, x):
        # Known first dimension, unknown second one displayed as 'x'.
        text = "{}{}x".format(x._dim, self._name)
        return DimensionObject(text).to_string()

    def to_string(self, use_x=True):
        """
        Renders the operator applied to its two dimensions.

        @param      use_x   use `'x'` if the second dimension is unknown
        @return             a string
        @raise      TypeError for unsupported combinations of dimensions
        """
        left, right = self._args  # pylint: disable=W0632
        left_dim = left._dim

        if isinstance(left_dim, str):
            right_dim = right._dim
            if isinstance(right_dim, int):
                return self._to_string2(left, right)
            if isinstance(right_dim, str):
                return self._to_string2b(left, right)
            raise TypeError(  # pragma: no cover
                "Unable to handle type '{}'.".format(type(right_dim)))

        if not isinstance(left_dim, int):
            raise TypeError(  # pragma: no cover
                "Unable to handle type '{}'.".format(type(left_dim)))
        if not isinstance(right, DimensionObject):
            raise TypeError(  # pragma: no cover
                "Unable to handle type '{}'.".format(type(right)))

        right_dim = right._dim
        if isinstance(right_dim, int):
            return self._to_string1(left, right)
        if isinstance(right_dim, str):
            return self._to_string2(left, right)
        if right_dim is None:
            if not use_x:
                return DimensionObject("{}{}DimensionObject()".format(
                    left_dim, self._name)).to_string()
            return self._to_string3(left)
        raise TypeError(  # pragma: no cover
            "Unable to handle type '{}'.".format(type(right_dim)))

    def _evaluate_string_(self, args, **kwargs):
        """
        Builds the symbolic result when some arguments are still strings:
        every argument is put into parenthesis and joined by the operator.

        @param      args    arguments extracted by method *evaluate*
        @param      kwargs  value for the variables (unused)
        @return             a string
        """
        wrapped = ['({})'.format(arg) for arg in args]
        return self._name.join(wrapped)

180 

181 

class ShapeBinaryFctOperator(ShapeBinaryOperator):
    """
    Binary shape operator displayed with a function call
    notation, ``name(x,y)``, instead of an infix one.
    """

    def _call_string(self, *operands):
        # Renders 'name(op1,op2,...)' as a dimension string.
        text = "{}({})".format(self._name, ",".join(str(o) for o in operands))
        return DimensionObject(text).to_string()

    def _to_string2(self, x, y):
        return self._call_string(x._dim, y._dim)

    def _to_string2b(self, x, y):
        return self._call_string(x._dim, y._dim)

    def _to_string3(self, x):
        # The unknown second dimension is displayed as 'x'.
        return self._call_string(x._dim, 'x')

    def _evaluate_string_(self, args, **kwargs):
        """
        Builds the symbolic result ``name(arg1,arg2)`` when some
        arguments are still strings.

        @param      args    arguments extracted by method *evaluate*
        @param      kwargs  value for the variables (unused)
        @return             a string
        """
        joined = ",".join(map(str, args))
        return "{}({})".format(self._name, joined)

205 

206 

class ShapeOperatorAdd(ShapeBinaryOperator):
    """
    Addition of two dimensions, displayed as ``x+y``.
    """

    def __init__(self, x, y):
        """
        @param      x   first operand (@see cl DimensionObject)
        @param      y   second operand (@see cl DimensionObject)
        """
        ShapeBinaryOperator.__init__(
            self, '+', lambda u, v: u + v, 'lambda a, b: a + b', x, y)

    def __repr__(self):
        """
        Returns ``ShapeOperatorAdd(<repr of x>, <repr of y>)``.
        """
        first, second = self._args[0], self._args[1]
        return "%s(%r, %r)" % (self.__class__.__name__, first, second)

224 

225 

class ShapeOperatorMul(ShapeBinaryOperator):
    """
    Multiplication of two dimensions, displayed as ``x*y``.
    """

    def __init__(self, x, y):
        """
        @param      x   first operand (@see cl DimensionObject)
        @param      y   second operand (@see cl DimensionObject)
        """
        ShapeBinaryOperator.__init__(
            self, '*', lambda u, v: u * v, 'lambda a, b: a * b', x, y)

    def __repr__(self):
        """
        Returns ``ShapeOperatorMul(<repr of x>, <repr of y>)``.
        """
        first, second = self._args[0], self._args[1]
        return "%s(%r, %r)" % (self.__class__.__name__, first, second)

243 

244 

class ShapeOperatorGreater(ShapeBinaryOperator):
    """
    Comparison of two dimensions, displayed as ``x>y``.
    """

    def __init__(self, x, y):
        """
        @param      x   first operand (@see cl DimensionObject)
        @param      y   second operand (@see cl DimensionObject)
        """
        ShapeBinaryOperator.__init__(
            self, '>', lambda u, v: u > v, 'lambda a, b: a > b', x, y)

    def __repr__(self):
        """
        Returns ``ShapeOperatorGreater(<repr of x>, <repr of y>)``.
        """
        first, second = self._args[0], self._args[1]
        return "%s(%r, %r)" % (self.__class__.__name__, first, second)

262 

263 

class ShapeOperatorMax(ShapeBinaryFctOperator):
    """
    Largest of two dimensions, displayed as ``max(x,y)``.
    """

    def __init__(self, x, y):
        """
        @param      x   first operand (@see cl DimensionObject)
        @param      y   second operand (@see cl DimensionObject)
        """
        ShapeBinaryFctOperator.__init__(
            self, 'max', lambda u, v: max(u, v), 'max(a, b)', x, y)

    def __repr__(self):
        """
        Returns ``ShapeOperatorMax(<repr of x>, <repr of y>)``.
        """
        first, second = self._args[0], self._args[1]
        return "%s(%r, %r)" % (self.__class__.__name__, first, second)

281 

282 

class DimensionObject(BaseDimensionShape):
    """
    One dimension of a shape. The stored value (attribute ``_dim``)
    is an integer, a variable name (string), a @see cl ShapeOperator,
    another @see cl DimensionObject or *None* when it is unknown.
    """

    def __init__(self, obj):
        """
        @param      obj     int (python or numpy integer), str,
                            @see cl ShapeOperator, @see cl DimensionObject,
                            or None, 0, ``'?'`` to specify something unknown
        @raise      TypeError   for any other type
        """
        if isinstance(obj, numpy.integer):
            # numpy integers are stored as python integers so that every
            # isinstance(..., int) test (to_string, evaluate, __gt__,
            # binary operators) treats them as known dimensions.
            obj = int(obj)
        if obj is None or obj == 0 or obj == '?':
            self._dim = None
        elif isinstance(obj, (int, str, ShapeOperator, DimensionObject)):
            self._dim = obj
        else:
            raise TypeError("Unexpected type for obj: {}".format(type(obj)))

    @property
    def dim(self):
        """
        Returns the dimension.
        """
        return self._dim

    def __repr__(self):
        """
        usual
        """
        if isinstance(self._dim, int):
            return "DimensionObject({})".format(self._dim)
        if isinstance(self._dim, DimensionObject):
            return repr(self._dim)
        if isinstance(self._dim, ShapeOperator):
            return "DimensionObject({})".format(repr(self._dim))
        return "DimensionObject('{}')".format(self._dim)

    @staticmethod
    def _same_(obj):
        """
        Returns *obj* if *obj* is @see cl DimensionObject
        otherwise converts it.
        """
        if isinstance(obj, DimensionObject):
            return obj
        return DimensionObject(obj)

    def to_string(self, use_x=True):
        """
        Represents the dimension as a string.

        @param      use_x   an unknown dimension is displayed as ``'x'``
                            if True, as ``'?'`` otherwise
        @return             a string
        """
        if isinstance(self._dim, int):
            return '{}'.format(self._dim)
        if isinstance(self._dim, ShapeOperator):
            return self._dim.to_string()
        if isinstance(self._dim, DimensionObject):
            # nested dimension, delegates to it
            return self._dim.to_string(use_x=use_x)
        if isinstance(self._dim, str):
            return self._dim
        if self._dim is None:
            return 'x' if use_x else '?'
        raise NotImplementedError(  # pragma: no cover
            "Not implemented for '{}'.".format(repr(self)))

    def evaluate(self, **kwargs):
        """
        Evaluates the dimension.

        @param      kwargs  value for the variables.
        @return             string or integer, an unknown dimension
                            becomes a unique variable name ``n<hex id>``
        """
        dim = self._dim
        if isinstance(dim, (int, ShapeOperator, DimensionObject)):
            res = dim
        elif isinstance(dim, str):
            # a variable stays a string unless a value is given
            res = kwargs.get(dim, dim)
        elif dim is None:
            pref = str(hex(id(self)))[2:]
            res = "n{}".format(pref)
        else:
            raise NotImplementedError(  # pragma: no cover
                "Not implemented for '{}'.".format(repr(self)))
        if isinstance(res, (ShapeOperator, DimensionObject)):
            return res.evaluate(**kwargs)
        return res

    def __eq__(self, v):
        """
        usual, *v* may be an int (python or numpy), a str,
        a @see cl DimensionObject, a @see cl ShapeOperator
        (evaluated without variables) or None
        """
        if isinstance(v, numpy.integer):
            v = int(v)
        if isinstance(v, (int, str)):
            return self._dim == v
        if isinstance(v, DimensionObject):
            return v == self._dim
        if isinstance(v, ShapeOperator):
            ve = v.evaluate()
            return ve == self._dim
        if v is None:
            return self._dim is None
        raise TypeError(  # pragma: no cover
            "Unable to compare a DimensionObject to {}".format(type(v)))

    def __add__(self, obj):
        """
        usual
        """
        return DimensionObject(
            ShapeOperatorAdd(self, DimensionObject._same_(obj)))

    def __mul__(self, obj):
        """
        usual
        """
        return DimensionObject(
            ShapeOperatorMul(self, DimensionObject._same_(obj)))

    def __gt__(self, obj):
        """
        usual, returns a boolean when both dimensions are known
        integers, a symbolic @see cl DimensionObject otherwise
        """
        if obj is None:
            return not isinstance(self._dim, int)
        # obj may be a plain integer as well as a DimensionObject
        other = obj._dim if isinstance(obj, DimensionObject) else obj
        if isinstance(other, numpy.integer):
            other = int(other)
        if isinstance(self._dim, int) and isinstance(other, int):
            return self._dim > other
        return DimensionObject(
            ShapeOperatorGreater(self, DimensionObject._same_(obj)))

411 

412 

class ShapeObject(BaseDimensionShape):
    """
    Handles mathematical operations around shapes.
    It stores a type (:epkg:`numpy` type),
    and a name to somehow have an idea of where
    the shape comes from in the :epkg:`ONNX` graph.
    The shape itself is defined by a list of
    @see cl DimensionObject or @see cl ShapeOperator
    or *None* if the shape is unknown. A dimension is an
    integer or a variable encoded as a string. This variable
    is a way to tell the dimension may vary.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from mlprodict.onnxrt.shape_object import ShapeObject

        sh1 = ShapeObject((1, 2), dtype=numpy.float32)
        sh2 = ShapeObject((45, 2), dtype=numpy.float32)
        mx = max(sh1, sh2)
        print(mx)

        sh1 = ShapeObject((1, 2), dtype=numpy.float32)
        sh2 = ShapeObject((None, 2), dtype=numpy.float32)
        print(sh2)
        mx = max(sh1, sh2)
        print(mx.to_string())

        sh1 = ShapeObject((1, 2), dtype=numpy.float32)
        sh2 = ShapeObject(('n', 2), dtype=numpy.float32)
        print(sh2)
        mx = max(sh1, sh2)
        print(mx.evaluate(n=4))
    """

    def __init__(self, shape, dtype=None, use_n1=False, name=None):
        """
        @param      shape       tuple, list, `numpy.array`, None (unknown)
                                or a dictionary ``{'type': {'kind': ...}}``
                                with kind ``'tensor'``, ``'map'``
                                or ``'sequence'``
        @param      dtype       dtype (ignored for arrays and dictionaries
                                which carry their own type)
        @param      use_n1      use `'n'` if the first dimension is unknown
        @param      name        optional, for debugging purposes
        @raise      ValueError  if the dtype is None or unexpected
        @raise      TypeError   if *shape* has an unexpected type
        """
        self.name = name
        if isinstance(shape, numpy.ndarray):
            self._shape = [DimensionObject(s) for s in shape.shape]
            self._dtype = shape.dtype
        elif isinstance(shape, dict) and 'type' in shape:
            tshape = shape['type']
            if tshape['kind'] == 'tensor':
                if tshape['shape'] == ('?', ):
                    self._shape = None
                else:
                    self._shape = [DimensionObject(s) for s in tshape['shape']]
                self._dtype = tshape['elem']
            elif tshape['kind'] == 'map':
                self._shape = []
                self._dtype = 'map'
            elif tshape['kind'] == 'sequence':
                self._shape = []
                self._dtype = 'sequence'
            else:
                raise ValueError(  # pragma: no cover
                    "Wrong shape value {}".format(shape))
        elif isinstance(shape, (tuple, list)):
            self._shape = []
            for s in shape:
                self._shape.append(DimensionObject(s))
            self._dtype = dtype
        elif shape is None:
            # shape is unknown
            self._shape = None
            self._dtype = dtype
        else:
            raise TypeError(  # pragma: no cover
                "Unexpected type for shape: {}, shape={}".format(
                    type(shape), shape))

        def _dtype_again():
            # Normalizes the many accepted spellings of a type
            # into a numpy type.
            if self._dtype is None:
                raise ValueError(
                    "dtype cannot be None, shape type is {}\n{}".format(
                        type(shape), shape))
            if isinstance(self._dtype, numpy.dtype):
                # no need to go further
                return
            if self._dtype in (float, 'double', 'tensor(double)'):
                self._dtype = numpy.float64
            elif self._dtype in ('float32', 'float', 'tensor(float)'):
                self._dtype = numpy.float32
            elif self._dtype in (numpy.float16, 'float16', 'tensor(float16)'):
                self._dtype = numpy.float16
            elif self._dtype in ('int32', 'tensor(int32)'):
                self._dtype = numpy.int32
            elif self._dtype in (int, 'int', 'int64', 'tensor(int64)'):
                self._dtype = numpy.int64
            elif self._dtype in (str, 'str', numpy.str_, 'tensor(str)'):
                self._dtype = numpy.str_
            elif (hasattr(self._dtype, 'type') and self._dtype.type is numpy.bytes_):
                # numpy.bytes_ is the name available in every numpy version,
                # numpy.string_ (its former alias) was removed in numpy 2.0.
                pass
            elif self._dtype in (bool, 'bool', numpy.bool_):
                self._dtype = numpy.bool_
            elif self._dtype in (object, numpy.object_):
                pass
            elif self._dtype in (numpy.int8, 'int8', ):
                self._dtype = numpy.int8
            elif self._dtype in (numpy.uint8, 'uint8', ):
                self._dtype = numpy.uint8
            elif self._dtype in (numpy.int16, 'int16', ):
                self._dtype = numpy.int16
            elif self._dtype in (numpy.uint16, 'uint16', ):
                self._dtype = numpy.uint16
            elif self._dtype in (numpy.uint32, 'uint32', ):
                self._dtype = numpy.uint32
            elif self._dtype in (numpy.uint64, 'uint64', ):
                self._dtype = numpy.uint64
            elif self._dtype in (numpy.complex64, 'complex64', ):
                self._dtype = numpy.complex64
            elif self._dtype in (numpy.complex128, 'complex128', ):
                self._dtype = numpy.complex128
            elif self._dtype == "tensor({'kind': 'tensor', 'elem': 'float', 'shape': })":
                self._dtype = numpy.float32
            elif self._dtype not in {
                    numpy.float32, numpy.float64, numpy.int32, numpy.int64,
                    numpy.str_, numpy.bool_, numpy.float16, None,
                    numpy.complex64, numpy.complex128,
                    'map', 'sequence'}:
                raise ValueError(  # pragma: no cover
                    "dtype has an unexpected value: '{}'.".format(self._dtype))
        try:
            _dtype_again()
        except TypeError as e:
            raise TypeError(  # pragma: no cover
                "Unexpected error with %r of type %r." % (
                    (self._dtype, type(self._dtype)))) from e

        def _shape_again():
            # Checks every dimension is a DimensionObject.
            if self._shape is not None:
                for i, a in enumerate(self._shape):
                    if not isinstance(a, DimensionObject):
                        raise TypeError(  # pragma: no cover
                            'Dimension {} has a wrong type {}'.format(
                                i, type(a)))
                if use_n1:
                    sh = self._shape[0] if self._shape else None
                    if isinstance(sh, DimensionObject) and sh._dim is None:
                        sh._dim = 'n'
                if self._shape is not None:
                    for s in self._shape:
                        if isinstance(s, int):
                            raise TypeError(  # pragma: no cover
                                "Unexpected type int in shape %r." % self)
        _shape_again()

    def reshape(self, shape):
        """
        Creates a new shape, checks the number of elements is the same
        when both numbers of elements are known integers.

        @param      shape   new shape
        @return             @see cl ShapeObject
        @raise      ValueError if both numbers of elements are known
                    and differ
        """
        sh = ShapeObject(shape, self.dtype, name=self.name)
        prod1 = self.product()
        prod2 = sh.product()
        # product() is None when the shape itself is unknown.
        p1 = None if prod1 is None else prod1.evaluate()
        p2 = None if prod2 is None else prod2.evaluate()
        # A symbolic product evaluates into a string and cannot be compared.
        if isinstance(p1, int) and isinstance(p2, int) and p1 != p2:
            raise ValueError("Shape {} cannot be reshaped into {} "
                             "(p1={}, p2={}).".format(self, shape, p1, p2))
        return sh

    def copy(self, dtype=None, name=None):
        """
        A copy not a deepcopy.

        @param      dtype   None or a value to rewrite the type.
        @param      name    overwrites the name
        @return             @see cl ShapeObject
        """
        new_dtype = self.dtype if dtype is None else dtype
        if self._shape is None:
            return ShapeObject(None, dtype=new_dtype, name=name or self.name)
        return ShapeObject(self._shape.copy(), new_dtype,
                           name=name or self.name)

    def __getitem__(self, index):
        """
        Extracts a specific dimension. Returns None if the shape
        is unknown and 1 for an integer index beyond the last dimension.
        """
        if self._shape is None:
            return None
        if isinstance(index, int) and index >= len(self._shape):
            return 1
        return self._shape[index]

    def __setitem__(self, index, value):
        """
        Changes a specific dimension, the shape is extended with
        dimensions equal to 1 if *index* is beyond the last one.
        Nothing happens if the shape is unknown.
        """
        if self._shape is None:
            return
        while len(self._shape) <= index:
            self._shape.append(DimensionObject(1))
        if not isinstance(value, DimensionObject):
            # keeps the invariant: every dimension is a DimensionObject
            value = DimensionObject(value)
        self._shape[index] = value

    @property
    def shape(self):
        """
        Returns the stored shape.
        """
        if self._shape is None:
            return None
        return tuple(self._shape)

    def __len__(self):
        """
        Returns the number of dimensions (0 if the shape is unknown).
        """
        if self._shape is None:
            return 0
        return len(self._shape)

    @property
    def dtype(self):
        """
        Returns the stored *dtype*.
        """
        return self._dtype

    def reduce(self, axis=1, keepdims=False, dtype=None):
        """
        Reduces the matrix. Removes one dimension.

        @param      axis        axis (negative values count from the end)
        @param      keepdims    keep dimensions, replaces the removed
                                dimension by 1
        @param      dtype       if not None, changes the type
        @return                 new dimension
        @raise      IndexError  if *axis* is out of range
        """
        if self._shape is None:
            if self.name is None:
                return self.copy()
            return self.copy(name="{}-RD".format(self.name))
        if axis is None:
            return ShapeObject((1, ), self._dtype if dtype is None else dtype,
                               name="{}-RDN".format(self.name))

        if isinstance(axis, ShapeObject):

            def drop_axis(shape, a):
                # NOTE(review): a[0] is a DimensionObject after evaluation,
                # list deletion expects an integer index; confirm intent.
                c = list(shape)
                del c[a[0]]
                return c

            return ShapeObjectFct(
                drop_axis, self, axis, name="DropAxis", dtype=self.dtype)

        if -len(self._shape) <= axis < 0:
            axis = axis + len(self._shape)
        if 0 <= axis < len(self._shape):
            cp = self._shape.copy()
            if keepdims:
                cp[axis] = DimensionObject(1)
            else:
                del cp[axis]
            return ShapeObject(cp, self._dtype if dtype is None else dtype,
                               name="{}-RD".format(self.name))
        raise IndexError("axis={} is wrong, shape is {}-tuple and equal to "
                         "{}".format(axis, len(self._shape), self))

    def __repr__(self):
        """
        usual
        """
        st = str(self.dtype)
        if "'" in st:
            # "<class 'numpy.float32'>" -> "numpy.float32"
            st = st.split("'")[1]

        if self.shape is None:
            if self.name is None:
                return "ShapeObject(None, dtype={})".format(st)
            return "ShapeObject(None, dtype={}, name='{}')".format(st, self.name)

        st_shape = []
        for s in self.shape:
            if isinstance(getattr(s, "_dim", None), (int, str)):
                st_shape.append(str(s._dim))
            else:
                st_shape.append(repr(s))
        if len(st_shape) == 1:
            # produces "(3, )", a one element tuple
            st_shape.append('')
        st_shape = '({})'.format(", ".join(st_shape))
        if self.name is None:
            return "ShapeObject({}, dtype={})".format(st_shape, st)
        return "ShapeObject({}, dtype={}, name='{}')".format(
            st_shape, st, self.name)

    def __iter__(self):
        """
        Iterators over dimensions (nothing if the shape is unknown).
        """
        if self._shape is not None:
            for d in self._shape:
                yield d

    def __gt__(self, a):
        """
        Compares shapes. Operator ``>``. An unknown shape is greater
        than a known one, then the number of dimensions is compared,
        then dimensions from left to right.
        """
        if isinstance(a, (tuple, list)):
            a = ShapeObject(a, dtype=self._dtype)
        if not isinstance(a, ShapeObject):
            return NotImplemented
        if self._shape is None and a._shape is None:
            return False
        if self._shape is None:
            return True
        if a._shape is None:
            return False
        if len(self) > len(a):
            return True
        if len(self) < len(a):
            return False
        for d1, d2 in zip(self, a):
            if d1 > d2:
                return True
            if d1 < d2:
                return False
        return False

    def __eq__(self, a):
        """
        Tests equality between two shapes.
        """
        if isinstance(a, (tuple, list)):
            a = ShapeObject(a, dtype=self._dtype)
        if not isinstance(a, ShapeObject):
            return NotImplemented
        if self._shape is None and a._shape is None:
            return True
        if self._shape is None or a._shape is None:
            return False
        if len(self) != len(a):
            return False
        return all(d1 == d2 for d1, d2 in zip(self, a))

    def evaluate(self, **kwargs):
        """
        Evaluates the shape, an unknown shape remains unknown.

        @param      kwargs  value for the variables
        @return             @see cl ShapeObject
        """
        name = "{}-EV".format(self.name)
        if self._shape is None:
            return ShapeObject(None, self._dtype, name=name)
        vs = [v.evaluate(**kwargs) for v in self._shape]
        return ShapeObject(tuple(vs), self._dtype, name=name)

    def to_string(self, use_x=False):
        """
        Converts shapes into a string, ``'?'`` if the shape is unknown.
        """
        if self._shape is None:
            return '?'
        shapes = [a.to_string(use_x=use_x) for a in self._shape]
        return '({})'.format(', '.join(shapes))

    def product(self):
        """
        Multiplies all the dimension.

        @return     @see cl DimensionObject, ``DimensionObject(1)``
                    for a scalar shape, None if the shape is unknown
        """
        if self._shape is None:
            return None
        if not self._shape:
            return DimensionObject(1)
        cl = self._shape[0]
        for d in self._shape[1:]:
            cl = cl * d
        return cl

    def append(self, dim):
        """
        Appends a dimension.
        """
        if self._shape is None:
            return
        if isinstance(dim, DimensionObject):
            self._shape.append(dim)
        else:
            self._shape.append(DimensionObject(dim))

    def insert(self, dim, pos=0):
        """
        Inserts a dimension at position *pos*.
        """
        if self._shape is None:
            return
        if isinstance(dim, DimensionObject):
            self._shape.insert(pos, dim)
        else:
            self._shape.insert(pos, DimensionObject(dim))

    def squeeze(self, axis):
        """
        Removes one dimension.
        """
        cp = self.copy(name='{}-SZ'.format(self.name))
        cp.drop_axis(axis)
        return cp

    def unsqueeze(self, axes):
        """
        Adds dimensions equal to 1.

        @param      axes    positions of the new dimensions in the
                            output shape, negative values count from
                            the end of the output shape
        @return             @see cl ShapeObject
        """
        cp = self.copy(name='{}-USZ'.format(self.name))
        if cp._shape is None:
            return cp
        out_rank = len(cp._shape) + len(axes)
        positions = sorted(ax + out_rank if ax < 0 else ax for ax in axes)
        # Inserting in increasing order keeps every position
        # valid with respect to the final shape.
        for pos in positions:
            cp.insert(1, pos)
        return cp

    def transpose(self, perm):
        """
        Permutes the dimensions.

        @param      perm    new order of the dimensions
        @return             @see cl ShapeObject
        """
        if self.shape is None:
            return self.copy(name='{}-TR'.format(self.name))
        dims = []
        for p in perm:
            if p >= len(self):
                # This should not happen, the dimension becomes unknown.
                dims.append(DimensionObject(None))
            else:
                dims.append(self._shape[p])
        return ShapeObject(dims, dtype=self.dtype,
                           name="{}-TR".format(self.name))

    def drop_axis(self, axis):
        """
        Drops an axis (or several if *axis* is a tuple or a list).
        """
        if self._shape is not None:
            if isinstance(axis, (tuple, list)):
                for i in sorted(axis, reverse=True):
                    del self._shape[i]
            else:
                del self._shape[axis]

    def broadcast(self, a):
        """
        Computes the shape after a broadcast. Dimensions are
        aligned on the right as numpy does.
        """
        if a is None:
            raise ValueError("a should not be None")  # pragma: no cover
        if a._shape is None:
            return a.copy()
        if self._shape is None:
            return self.copy()
        n_self, n_a = len(self._shape), len(a._shape)
        mx = max(n_self, n_a)
        res = []
        for i in range(mx):
            # index of the matching dimension in each shape, negative
            # when the shorter shape has no dimension at this position
            i_self = i - (mx - n_self)
            i_a = i - (mx - n_a)
            if i_self >= 0 and i_a >= 0:
                res.append(ShapeOperatorMax(self._shape[i_self], a._shape[i_a]))
            elif i_self >= 0:
                res.append(self._shape[i_self])
            else:
                res.append(a._shape[i_a])
        return ShapeObject(tuple(res), self.dtype, False,
                           name="broadcast-{}-{}".format(self.name, a.name))

    @staticmethod
    def _infer_merged_type(*args, use_dtype=True):
        # Returns the common type, float64 for mixed numeric types.
        if use_dtype:
            tys = set(a.dtype for a in args)
        else:
            tys = set(args)
        if len(tys) == 1:
            return list(tys)[0]
        if any(tys & {numpy.float64, numpy.int64,
                      numpy.float32, numpy.int32,
                      numpy.float16}):
            return numpy.float64
        raise RuntimeError(  # pragma: no cover
            "Unable to infer types based on {} ({}).".format(
                tys, len(tys)))

    def concat_columns(self, axis, *shapes):
        """
        Concatenates columns from *shapes* to this one
        along one axis.
        """
        args = [self] + list(shapes)
        dtype = self._infer_merged_type(*args)
        dim_axis = self[axis]
        if dim_axis is None:
            return ShapeObject(None, dtype=dtype)
        if isinstance(dim_axis, int):
            # __getitem__ returns 1 beyond the last dimension
            dim_axis = DimensionObject(dim_axis)
        for a in shapes:
            if a[axis] is None:
                return ShapeObject(None, dtype=dtype)
            dim_axis = dim_axis + a[axis]
        a0 = args[0].copy(dtype=dtype)
        a0[axis] = dim_axis
        return a0

    @staticmethod
    def einsum_shape(equation, *inputs):
        """
        Computes :epkg:`einsum` shapes.
        Not the most efficient one as it creates variables
        of the given shapes.

        @param      equation    equation as bytes or str, ``ij,jk->ik``
        @param      inputs      input shapes
        @return                 @see cl ShapeObject
        """
        for inp in inputs:
            if inp.shape is None:
                return inp
        # latin-1 maps every byte to one character, it never fails
        eq = equation.decode('latin-1') if isinstance(equation, bytes) else equation
        inp, out = [_.strip() for _ in eq.split("->")]
        inps = [_.strip() for _ in inp.split(',')]
        if len(inputs) != len(inps):
            raise RuntimeError(  # pragma: no cover
                "Input mismatch between '{}' and {}.".format(equation, inps))
        shs = {}
        for a, b in zip(inps, inputs):
            if len(a) != len(b):
                raise RuntimeError(  # pragma: no cover
                    "Input mismatch '{}' (in '{}') and {}.".format(a, equation, b))
            for c, s in zip(a, b):
                if c not in shs:
                    shs[c] = s
                elif shs[c] != s:
                    raise RuntimeError(  # pragma: no cover
                        "Equation '{}'. Dimension mismatch '{}' != {}.".format(
                            equation, s, shs[c]))
        new_shape = [shs[i] for i in out]
        return ShapeObject(new_shape, dtype=ShapeObject._infer_merged_type(*inputs))

    @staticmethod
    def gather_shape(input, indices, axis):
        """
        Computes Gather shapes.

        @param      input       shape of the gathered tensor
        @param      indices     shape of the indices
        @param      axis        gathered axis (negative values count
                                from the end)
        @return                 @see cl ShapeObject, unknown if one
                                input shape is unknown
        """
        # len() returns 0 for an unknown shape, it must be tested first
        if input.shape is None or indices.shape is None:
            return ShapeObject(None, dtype=input._dtype)
        input_rank = len(input)

        if axis < 0:
            axis = input_rank + axis

        shape = []
        for i in range(axis):
            shape.append(input[i])

        for dim in indices:
            shape.append(dim)

        for i in range(axis + 1, input_rank):
            shape.append(input[i])

        return ShapeObject(shape, dtype=input._dtype)

970 

971 

class ShapeObjectFct(ShapeObject):
    """
    Shape computed by a user defined function applied to other
    shapes once they are evaluated. See @see cl Conv for an example.
    The shape itself is stored as unknown until *evaluate* is called.
    """

    def __init__(self, fct, *shapes, dtype=None, name=None):
        """
        @param      fct     function receiving the evaluated shapes
        @param      shapes  shapes sent to fct
        @param      dtype   dtype
        @param      name    optional, for debugging purposes
        """
        ShapeObject.__init__(self, None, dtype=dtype, name=name)
        self._fct = fct
        self._shapes = shapes

    def evaluate(self, **kwargs):
        """
        Evaluates every input shape, then calls the function on them.

        @param      kwargs  value for the variables
        @return             what the function returns, renamed after
                            this object if it has a name
        """
        evaluated = [shape.evaluate(**kwargs) for shape in self._shapes]
        result = self._fct(*evaluated)
        if self.name is not None:
            result.name = self.name
        return result