Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2@file
3@brief Shape object.
4"""
5import numpy
class BaseDimensionShape:
    """
    Common base of @see cl DimensionObject, @see cl ShapeOperator
    and @see cl ShapeObject. It only declares the interface every
    subclass must provide.
    """

    def to_string(self, use_x=True):
        """
        Returns a textual representation of the object.
        Subclasses must override it.

        @param      use_x       how an unknown dimension is rendered,
                                the exact meaning is left to subclasses
        @return                 a string
        @raise                  NotImplementedError always at this level
        """
        raise NotImplementedError()

    def evaluate(self, **kwargs):
        """
        Simplifies the object: the expression is reduced
        to a number or a string. Subclasses must override it.

        @param      kwargs      values for the variables
        @return                 string or integer
        @raise                  NotImplementedError always at this level
        """
        raise NotImplementedError()  # pragma: no cover
class ShapeOperator(BaseDimensionShape):
    """
    Base class for all shape operators.
    """

    def __init__(self, name, fct, fct_string, *args):
        """
        @param      name        display name of the operator
        @param      fct         function doing the operator
                                if arguments are numeric
        @param      fct_string  function represented as a string
        @param      args        arguments of the operator, each of them
                                must be a @see cl DimensionObject
        @raise                  TypeError if one argument is not a
                                @see cl DimensionObject
        """
        self._name = name
        self._fct = fct
        self._fct_string = fct_string
        self._args = args
        for a in self._args:
            if not isinstance(a, DimensionObject):
                raise TypeError(
                    "All arguments must be of type DimensionObject not '{}'."
                    "".format(type(a)))

    def __repr__(self):
        """
        usual, mirrors the constructor signature
        ``(name, fct, fct_string, *args)``
        """
        # fct (usually a lambda) has no useful repr, its source string is
        # displayed in its place and then repeated as the quoted fct_string;
        # operands are listed one by one as *args expects them
        parts = ["'{}'".format(self._name), str(self._fct_string),
                 "'{}'".format(self._fct_string)]
        parts.extend(repr(a) for a in self._args)
        return "{0}({1})".format(self.__class__.__name__, ", ".join(parts))

    def to_string(self, use_x=True):
        """
        Displays as a string.

        @param      use_x       how an unknown dimension is rendered
        @return                 a string
        @raise                  NotImplementedError, subclasses must
                                override this method
        """
        raise NotImplementedError(  # pragma: no cover
            "Operator '{}' does not implement 'to_string': {}.".format(
                self.__class__.__name__, repr(self)))

    def evaluate(self, **kwargs):
        """
        Evaluates the operator.

        @param      kwargs      value for the variables.
        @return                 string or integer
        @raise                  RuntimeError if *fct* fails with a
                                TypeError on the evaluated operands
        """
        args = []
        has_string = False
        for a in self._args:
            a = DimensionObject._same_(a)
            v = a.evaluate(**kwargs)
            if isinstance(v, str):
                # at least one operand is still symbolic
                has_string = True
            args.append(v)
        if has_string:
            return self._evaluate_string_(args, **kwargs)
        try:
            return self._fct(*args)
        except TypeError as e:
            raise RuntimeError(
                "Unable to evaluate operator {} due to {}".format(repr(self), e)) from e

    def _evaluate_string_(self, args, **kwargs):
        """
        Evaluates the operator assuming some of the arguments are still strings.

        @param      args        arguments extracted by method *evaluate*
        @param      kwargs      value for the variables.
        @return                 string or integer
        @raise                  NotImplementedError, subclasses must
                                override this method
        """
        raise NotImplementedError(
            "This function must be overwritten.")  # pragma: no cover
class ShapeBinaryOperator(ShapeOperator):
    """
    Base class for binary shape operators such as ``x + y``.
    """

    def __init__(self, name, fct, fct_string, x, y):
        """
        @param      name        display name of the operator
        @param      fct         function doing the operator
                                if arguments are numeric
        @param      fct_string  function represented as a string
        @param      x           first argument
        @param      y           second argument
        """
        super().__init__(name, fct, fct_string, x, y)
        for label, value in (('x', x), ('y', y)):
            if isinstance(value, tuple):
                raise TypeError(  # pragma: no cover
                    '{} cannot be a tuple'.format(label))

    def _to_string1(self, x, y):
        # both dimensions are integers: the result is computed
        value = self._fct(x._dim, y._dim)
        return DimensionObject(value).to_string()

    def _to_string2(self, x, y):
        # infix form without parenthesis
        text = "{left}{op}{right}".format(
            left=x._dim, op=self._name, right=y._dim)
        return DimensionObject(text).to_string()

    def _to_string2b(self, x, y):
        # infix form, both operands are symbolic and get parenthesis
        text = "({left}){op}({right})".format(
            left=x._dim, op=self._name, right=y._dim)
        return DimensionObject(text).to_string()

    def _to_string3(self, x):
        # the second operand is unknown and displayed as x
        text = "{left}{op}x".format(left=x._dim, op=self._name)
        return DimensionObject(text).to_string()

    def to_string(self, use_x=True):
        """
        Renders the binary operator applied to two dimensions.

        @param      use_x       use `'x'` if the second dimension is unknown
        @return                 a string
        @raise                  TypeError for a combination of dimension
                                types which cannot be rendered
        """
        left, right = self._args  # pylint: disable=W0632
        left_dim = left._dim
        if not isinstance(left_dim, (int, str)):
            raise TypeError(  # pragma: no cover
                "Unable to handle type '{}'.".format(type(left_dim)))

        if isinstance(left_dim, str):
            # symbolic first operand, the second must be int or str
            right_dim = right._dim
            if isinstance(right_dim, int):
                return self._to_string2(left, right)
            if isinstance(right_dim, str):
                return self._to_string2b(left, right)
            raise TypeError(  # pragma: no cover
                "Unable to handle type '{}'.".format(type(right_dim)))

        # numeric first operand
        if not isinstance(right, DimensionObject):
            raise TypeError(  # pragma: no cover
                "Unable to handle type '{}'.".format(type(right)))
        right_dim = right._dim
        if isinstance(right_dim, int):
            return self._to_string1(left, right)
        if isinstance(right_dim, str):
            return self._to_string2(left, right)
        if right_dim is not None:
            raise TypeError(  # pragma: no cover
                "Unable to handle type '{}'.".format(type(right_dim)))
        if use_x:
            return self._to_string3(left)
        return DimensionObject("{}{}DimensionObject()".format(
            left_dim, self._name)).to_string()

    def _evaluate_string_(self, args, **kwargs):
        """
        Evaluates the operator assuming some of the arguments are still strings.
        Every argument is put between parenthesis and joined with the
        operator name.

        @param      args        arguments extracted by method *evaluate*
        @param      kwargs      value for the variables.
        @return                 string or integer
        """
        wrapped = ('({})'.format(arg) for arg in args)
        return self._name.join(wrapped)
class ShapeBinaryFctOperator(ShapeBinaryOperator):
    """
    Base class for binary shape operators displayed as a function
    call ``name(a,b)`` instead of an infix expression.
    """

    def _call_string(self, *values):
        # renders name(v1,v2,...) and normalizes it through DimensionObject
        text = "{}({})".format(self._name, ",".join(str(v) for v in values))
        return DimensionObject(text).to_string()

    def _to_string2(self, x, y):
        return self._call_string(x._dim, y._dim)

    def _to_string2b(self, x, y):
        return self._call_string(x._dim, y._dim)

    def _to_string3(self, x):
        return self._call_string(x._dim, 'x')

    def _evaluate_string_(self, args, **kwargs):
        """
        Evaluates the operator assuming some of the arguments are still strings.
        The result is the function call ``name(arg1,arg2)``.

        @param      args        arguments extracted by method *evaluate*
        @param      kwargs      value for the variables.
        @return                 string or integer
        """
        joined = ",".join(str(arg) for arg in args)
        return "{}({})".format(self._name, joined)
class ShapeOperatorAdd(ShapeBinaryOperator):
    """
    Addition of two dimensions (``x + y``).
    """

    def __init__(self, x, y):
        """
        @param      x       first operand (@see cl DimensionObject)
        @param      y       second operand (@see cl DimensionObject)
        """
        super().__init__('+', lambda a, b: a + b, 'lambda a, b: a + b', x, y)

    def __repr__(self):
        """
        Displays a string.

        @return     a string
        """
        first = self._args[0]
        second = self._args[1]
        return "%s(%r, %r)" % (self.__class__.__name__, first, second)
class ShapeOperatorMul(ShapeBinaryOperator):
    """
    Multiplication of two dimensions (``x * y``).
    """

    def __init__(self, x, y):
        """
        @param      x       first operand (@see cl DimensionObject)
        @param      y       second operand (@see cl DimensionObject)
        """
        super().__init__('*', lambda a, b: a * b, 'lambda a, b: a * b', x, y)

    def __repr__(self):
        """
        Displays a string.

        @return     a string
        """
        first = self._args[0]
        second = self._args[1]
        return "%s(%r, %r)" % (self.__class__.__name__, first, second)
class ShapeOperatorGreater(ShapeBinaryOperator):
    """
    Comparison of two dimensions (``x > y``).
    """

    def __init__(self, x, y):
        """
        @param      x       first operand (@see cl DimensionObject)
        @param      y       second operand (@see cl DimensionObject)
        """
        super().__init__('>', lambda a, b: a > b, 'lambda a, b: a > b', x, y)

    def __repr__(self):
        """
        Displays a string.

        @return     a string
        """
        first = self._args[0]
        second = self._args[1]
        return "%s(%r, %r)" % (self.__class__.__name__, first, second)
class ShapeOperatorMax(ShapeBinaryFctOperator):
    """
    Largest of two dimensions, displayed as ``max(x,y)``.
    """

    def __init__(self, x, y):
        """
        @param      x       first operand (@see cl DimensionObject)
        @param      y       second operand (@see cl DimensionObject)
        """
        super().__init__('max', lambda a, b: max(a, b), 'max(a, b)', x, y)

    def __repr__(self):
        """
        Displays a string.

        @return     a string
        """
        first = self._args[0]
        second = self._args[1]
        return "%s(%r, %r)" % (self.__class__.__name__, first, second)
class DimensionObject(BaseDimensionShape):
    """
    One dimension of a shape.
    """

    def __init__(self, obj):
        """
        @param      obj     int (numpy integers are converted into int),
                            str (a variable), @see cl ShapeOperator,
                            @see cl DimensionObject, or None (also ``0``
                            or ``'?'``) to specify something unknown
        @raise              TypeError for any other type
        """
        if obj is None or obj == 0 or obj == '?':
            self._dim = None
        elif isinstance(obj, numpy.integer):
            # numpy integers are not instances of int, they are converted
            # so that every isinstance(..., int) check of this module
            # (evaluate, to_string, __repr__, __gt__, ...) handles them
            self._dim = int(obj)
        elif isinstance(obj, DimensionObject) and isinstance(obj._dim, (int, str)):
            # a known dimension is unwrapped, otherwise copied shapes hold
            # nested objects that __gt__ and operators cannot compare;
            # unknown dimensions stay nested to keep the same variable name
            # when evaluated (the name depends on the object id)
            self._dim = obj._dim
        elif isinstance(obj, (int, str, ShapeOperator, DimensionObject)):
            self._dim = obj
        else:
            raise TypeError("Unexpected type for obj: {}".format(type(obj)))

    @property
    def dim(self):
        """
        Returns the dimension.
        """
        return self._dim

    def __repr__(self):
        """
        usual
        """
        if isinstance(self._dim, (int, numpy.integer)):
            return "DimensionObject({})".format(self._dim)
        if isinstance(self._dim, DimensionObject):
            return repr(self._dim)
        if isinstance(self._dim, ShapeOperator):
            return "DimensionObject({})".format(repr(self._dim))
        return "DimensionObject('{}')".format(self._dim)

    @staticmethod
    def _same_(obj):
        """
        Returns *obj* if *obj* is @see cl DimensionObject
        otherwise converts it.
        """
        if isinstance(obj, DimensionObject):
            return obj
        return DimensionObject(obj)

    def to_string(self, use_x=True):
        """
        Represents the dimension as a string.

        @param      use_x   returns ``'x'`` for an unknown dimension
                            if True, ``'?'`` otherwise
        @return             a string
        """
        if isinstance(self._dim, (int, numpy.integer)):
            return '{}'.format(self._dim)
        if isinstance(self._dim, ShapeOperator):
            return self._dim.to_string()
        if isinstance(self._dim, DimensionObject):
            return self._dim.to_string(use_x=use_x)
        if isinstance(self._dim, str):
            return self._dim
        if self._dim is None:
            return 'x' if use_x else '?'
        raise NotImplementedError(  # pragma: no cover
            "Not implemented for '{}'.".format(repr(self)))

    def evaluate(self, **kwargs):
        """
        Evaluates the dimension.

        @param      kwargs      value for the variables.
        @return                 string or integer, an unknown dimension
                                becomes a string built from the object id
        """
        if isinstance(self._dim, (int, ShapeOperator, DimensionObject)):
            res = self._dim
        elif isinstance(self._dim, str):
            res = kwargs[self._dim] if self._dim in kwargs else self._dim
        elif self._dim is None:
            pref = str(hex(id(self)))[2:]
            res = "n{}".format(pref)
        elif isinstance(self._dim, numpy.integer):
            # _dim may be assigned directly without going through __init__
            res = int(self._dim)
        else:
            raise NotImplementedError(  # pragma: no cover
                "Not implemented for '{}'.".format(repr(self)))
        if isinstance(res, (ShapeOperator, DimensionObject)):
            return res.evaluate(**kwargs)
        return res

    def __eq__(self, v):
        """
        usual
        """
        if isinstance(v, (int, str, numpy.integer)):
            return self._dim == v
        if isinstance(v, DimensionObject):
            return v == self._dim
        if isinstance(v, ShapeOperator):
            ve = v.evaluate()
            return ve == self._dim
        if v is None:
            return self._dim is None
        raise TypeError(  # pragma: no cover
            "Unable to compare a DimensionObject to {}".format(type(v)))

    def __add__(self, obj):
        """
        usual
        """
        return DimensionObject(
            ShapeOperatorAdd(self, DimensionObject._same_(obj)))

    def __mul__(self, obj):
        """
        usual
        """
        return DimensionObject(
            ShapeOperatorMul(self, DimensionObject._same_(obj)))

    def __gt__(self, obj):
        """
        usual, returns a boolean if both dimensions are integers,
        a symbolic @see cl DimensionObject otherwise

        @param      obj     None, int, str or @see cl DimensionObject
        """
        if obj is None:
            return not isinstance(self._dim, int)
        # plain values (int, str) are accepted like the other operators do
        obj = DimensionObject._same_(obj)
        if isinstance(self._dim, int) and isinstance(obj._dim, int):
            return self._dim > obj._dim
        return DimensionObject(ShapeOperatorGreater(self, obj))
class ShapeObject(BaseDimensionShape):
    """
    Handles mathematical operations around shapes.
    It stores a type (:epkg:`numpy` type),
    and a name to somehow have an idea of where
    the shape comes from in the :epkg:`ONNX` graph.
    The shape itself is defined by a list of
    @see cl DimensionObject or @see cl ShapeOperator
    or *None* if the shape is unknown. A dimension is an
    integer or a variable encoded as a string. This variable
    is a way to tell the dimension may vary.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import numpy
        from mlprodict.onnxrt.shape_object import ShapeObject

        sh1 = ShapeObject((1, 2), dtype=numpy.float32)
        sh2 = ShapeObject((45, 2), dtype=numpy.float32)
        mx = max(sh1, sh2)
        print(mx)

        sh1 = ShapeObject((1, 2), dtype=numpy.float32)
        sh2 = ShapeObject((None, 2), dtype=numpy.float32)
        print(sh2)
        mx = max(sh1, sh2)
        print(mx.to_string())

        sh1 = ShapeObject((1, 2), dtype=numpy.float32)
        sh2 = ShapeObject(('n', 2), dtype=numpy.float32)
        print(sh2)
        mx = max(sh1, sh2)
        print(mx.evaluate(n=4))
    """

    def __init__(self, shape, dtype=None, use_n1=False, name=None):
        """
        @param      shape       tuple, list, `numpy.array`, a dictionary
                                ``{'type': {'kind': ..., ...}}``, or None
                                if the shape is unknown
        @param      dtype       dtype
        @param      use_n1      use `'n'` if the first dimension is unknown
        @param      name        optional, for debugging purposes
        @raise                  TypeError or ValueError if *shape* or
                                *dtype* are not supported
        """
        self.name = name
        if isinstance(shape, numpy.ndarray):
            self._shape = [DimensionObject(s) for s in shape.shape]
            self._dtype = shape.dtype
        elif isinstance(shape, dict) and 'type' in shape:
            tshape = shape['type']
            if tshape['kind'] == 'tensor':
                if tshape['shape'] == ('?', ):
                    self._shape = None
                else:
                    self._shape = [DimensionObject(s) for s in tshape['shape']]
                self._dtype = tshape['elem']
            elif tshape['kind'] == 'map':
                self._shape = []
                self._dtype = 'map'
            elif tshape['kind'] == 'sequence':
                self._shape = []
                self._dtype = 'sequence'
            else:
                raise ValueError(  # pragma: no cover
                    "Wrong shape value {}".format(shape))
        elif isinstance(shape, (tuple, list)):
            self._shape = []
            for s in shape:
                self._shape.append(DimensionObject(s))
            self._dtype = dtype
        elif shape is None:
            # shape is unknown
            self._shape = None
            self._dtype = dtype
        else:
            raise TypeError(  # pragma: no cover
                "Unexpected type for shape: {}, shape={}".format(
                    type(shape), shape))

        def _dtype_again():
            # normalizes the many spellings of a type into a numpy type
            if self._dtype is None:
                raise ValueError(
                    "dtype cannot be None, shape type is {}\n{}".format(
                        type(shape), shape))
            if isinstance(self._dtype, numpy.dtype):
                # no need to go further
                return
            if self._dtype in (float, 'double', 'tensor(double)'):
                self._dtype = numpy.float64
            elif self._dtype in ('float32', 'float', 'tensor(float)'):
                self._dtype = numpy.float32
            elif self._dtype in (numpy.float16, 'float16', 'tensor(float16)'):
                self._dtype = numpy.float16
            elif self._dtype in ('int32', 'tensor(int32)'):
                self._dtype = numpy.int32
            elif self._dtype in (int, 'int', 'int64', 'tensor(int64)'):
                self._dtype = numpy.int64
            elif self._dtype in (str, 'str', numpy.str_, 'tensor(str)'):
                self._dtype = numpy.str_
            elif (hasattr(self._dtype, 'type') and
                    self._dtype.type is numpy.bytes_):
                # numpy.string_ was an alias of numpy.bytes_,
                # it was removed in numpy 2.0
                pass
            elif self._dtype in (bool, 'bool', numpy.bool_):
                self._dtype = numpy.bool_
            elif self._dtype in (object, numpy.object_):
                pass
            elif self._dtype in (numpy.int8, 'int8', ):
                self._dtype = numpy.int8
            elif self._dtype in (numpy.uint8, 'uint8', ):
                self._dtype = numpy.uint8
            elif self._dtype in (numpy.int16, 'int16', ):
                self._dtype = numpy.int16
            elif self._dtype in (numpy.uint16, 'uint16', ):
                self._dtype = numpy.uint16
            elif self._dtype in (numpy.uint32, 'uint32', ):
                self._dtype = numpy.uint32
            elif self._dtype in (numpy.uint64, 'uint64', ):
                self._dtype = numpy.uint64
            elif self._dtype in (numpy.complex64, 'complex64', ):
                self._dtype = numpy.complex64
            elif self._dtype in (numpy.complex128, 'complex128', ):
                self._dtype = numpy.complex128
            elif self._dtype == "tensor({'kind': 'tensor', 'elem': 'float', 'shape': })":
                self._dtype = numpy.float32
            elif self._dtype not in {
                    numpy.float32, numpy.float64, numpy.int32, numpy.int64,
                    numpy.str_, numpy.bool_, numpy.float16, None,
                    numpy.complex64, numpy.complex128,
                    'map', 'sequence'}:
                raise ValueError(  # pragma: no cover
                    "dtype has an unexpected value: '{}'.".format(self._dtype))
        try:
            _dtype_again()
        except TypeError as e:
            raise TypeError(  # pragma: no cover
                "Unexpected error with %r of type %r." % (
                    (self._dtype, type(self._dtype)))) from e

        def _shape_again():
            # checks every dimension is a DimensionObject
            if self._shape is not None:
                for i, a in enumerate(self._shape):
                    if not isinstance(a, DimensionObject):
                        raise TypeError(  # pragma: no cover
                            'Dimension {} has a wrong type {}'.format(
                                i, type(a)))
                if use_n1:
                    sh = self._shape[0] if self._shape else None
                    if isinstance(sh, DimensionObject) and sh._dim is None:
                        sh._dim = 'n'
            if self._shape is not None:
                for s in self._shape:
                    if isinstance(s, int):
                        raise TypeError(  # pragma: no cover
                            "Unexpected type int in shape %r." % self)
        _shape_again()

    def reshape(self, shape):
        """
        Creates a new shape, checks the number of elements is the same.

        @param      shape       new shape (any value the constructor accepts)
        @return                 @see cl ShapeObject
        @raise                  ValueError if both numbers of elements
                                are known integers and differ
        """
        sh = ShapeObject(shape, self.dtype, False, self.name)
        prod1 = self.product()
        prod2 = sh.product()
        if prod1 is None or prod2 is None:
            # one shape is unknown, nothing can be checked
            return sh
        p1 = prod1.evaluate()
        p2 = prod2.evaluate()
        # a symbolic dimension evaluates into a string,
        # the check is only meaningful between two integers
        if isinstance(p1, int) and isinstance(p2, int) and p1 != p2:
            raise ValueError("Shape {} cannot be reshaped into {} "
                             "(p1={}, p2={}).".format(self, shape, p1, p2))
        return sh

    def copy(self, dtype=None, name=None):
        """
        A copy not a deepcopy.

        @param      dtype       None or a value to rewrite the type.
        @param      name        overwrites the name
        @return                 @see cl ShapeObject
        """
        new_dtype = self.dtype if dtype is None else dtype
        new_name = name or self.name
        if self._shape is None:
            return ShapeObject(None, dtype=new_dtype, name=new_name)
        return ShapeObject(self._shape.copy(), new_dtype, name=new_name)

    def __getitem__(self, index):
        """
        Extracts a specific dimension. Returns None if the shape
        is unknown and 1 for an integer index beyond the last dimension.
        """
        if self._shape is None:
            return None
        if isinstance(index, int) and index >= len(self._shape):
            return 1
        return self._shape[index]

    def __setitem__(self, index, value):
        """
        Changes a specific dimension, the shape is extended with
        dimensions equal to 1 if *index* is beyond the last dimension.
        Does nothing if the shape is unknown.
        """
        if self._shape is None:
            return
        while len(self._shape) <= index:
            self._shape.append(DimensionObject(1))
        self._shape[index] = value

    @property
    def shape(self):
        """
        Returns the stored shape.
        """
        if self._shape is None:
            return None
        return tuple(self._shape)

    def __len__(self):
        """
        Returns the number of dimensions, 0 if the shape is unknown.
        """
        if self._shape is None:
            return 0
        return len(self._shape)

    @property
    def dtype(self):
        """
        Returns the stored *dtype*.
        """
        return self._dtype

    def reduce(self, axis=1, keepdims=False, dtype=None):
        """
        Reduces the matrix. Removes one dimension.

        @param      axis        axis, a negative value counts from
                                the last dimension
        @param      keepdims    keep dimensions, replaces the removed
                                dimension by 1
        @param      dtype       if not None, changes the type
        @return                 new dimension
        @raise                  IndexError if *axis* is out of range
        """
        if self._shape is None:
            if self.name is None:
                return self.copy()
            return self.copy(name="{}-RD".format(self.name))
        if axis is None:
            return ShapeObject((1, ), self._dtype if dtype is None else dtype,
                               name="{}-RDN".format(self.name))

        if isinstance(axis, ShapeObject):

            # NOTE(review): *a* is an evaluated ShapeObject, a[0] is a
            # DimensionObject which a list cannot be indexed with — confirm
            # how this path is expected to work
            def drop_axis(shape, a):
                c = list(shape)
                del c[a[0]]
                return c

            return ShapeObjectFct(
                drop_axis, self, axis, name="DropAxis", dtype=self.dtype)

        # negative axes count from the end as numpy and ONNX do
        pos = axis + len(self._shape) if axis < 0 else axis
        if 0 <= pos < len(self._shape):
            cp = self._shape.copy()
            if keepdims:
                cp[pos] = DimensionObject(1)
            else:
                del cp[pos]
            return ShapeObject(cp, self._dtype if dtype is None else dtype,
                               name="{}-RD".format(self.name))
        raise IndexError("axis={} is wrong, shape is {}-tuple and equal to "
                         "{}".format(axis, len(self._shape), self))

    def __repr__(self):
        """
        usual
        """
        st = str(self.dtype)
        if "'" in st:
            st = st.split("'")[1]

        if self.shape is None:
            if self.name is None:
                return "ShapeObject(None, dtype={})".format(st)
            return "ShapeObject(None, dtype={}, name='{}')".format(st, self.name)

        st_shape = []
        for s in self.shape:
            if isinstance(getattr(s, "_dim", None), (int, str)):
                st_shape.append(str(s._dim))
            else:
                st_shape.append(repr(s))
        if len(st_shape) == 1:
            # keeps the trailing comma of a one-element tuple
            st_shape.append('')
        st_shape = '({})'.format(", ".join(st_shape))
        if self.name is None:
            return "ShapeObject({}, dtype={})".format(st_shape, st)
        return "ShapeObject({}, dtype={}, name='{}')".format(
            st_shape, st, self.name)

    def __iter__(self):
        """
        Iterators over dimensions.
        """
        if self._shape is not None:
            for d in self._shape:
                yield d

    def __gt__(self, a):
        """
        Compares shapes. Operator ``>``.
        """
        if isinstance(a, tuple):
            a = ShapeObject(a, dtype=self._dtype)
        if self._shape is None and a._shape is None:
            return False
        if self._shape is None:
            return True
        if a._shape is None:
            return False
        if len(self) > len(a):
            return True
        if len(self) < len(a):
            return False
        for d1, d2 in zip(self, a):
            if d1 > d2:
                return True
            if d1 < d2:
                return False
        return False

    def __eq__(self, a):
        """
        Tests equality between two shapes.
        """
        if isinstance(a, tuple):
            a = ShapeObject(a, dtype=self._dtype)
        if self._shape is None and a._shape is None:
            return True
        if self._shape is None or a._shape is None:
            return False
        if len(self) != len(a):
            return False
        for d1, d2 in zip(self, a):
            if d1 == d2:
                continue
            return False
        return True

    def evaluate(self, **kwargs):
        """
        Evaluates the shape.

        @param      kwargs      value for the variables
        @return                 @see cl ShapeObject, still unknown
                                if this shape is unknown
        """
        name = "{}-EV".format(self.name)
        if self._shape is None:
            # an unknown shape must not become the scalar shape ()
            return ShapeObject(None, self._dtype, name=name)
        vs = [v.evaluate(**kwargs) for v in self]
        return ShapeObject(tuple(vs), self._dtype, name=name)

    def to_string(self, use_x=False):
        """
        Converts shapes into a string.
        """
        shapes = []
        for a in self._shape:
            shapes.append(a.to_string(use_x=use_x))
        return '({})'.format(', '.join(shapes))

    def product(self):
        """
        Multiplies all the dimensions.

        @return     @see cl DimensionObject, ``DimensionObject(1)`` for
                    a scalar shape, None if the shape is unknown
        """
        if self._shape is None:
            return None
        if not self._shape:
            # empty product
            return DimensionObject(1)
        cl = self._shape[0]
        for d in self._shape[1:]:
            cl = cl * d
        return cl

    def append(self, dim):
        """
        Appends a dimension.
        """
        if self._shape is None:
            return
        if isinstance(dim, DimensionObject):
            self._shape.append(dim)
        else:
            self._shape.append(DimensionObject(dim))

    def insert(self, dim, pos=0):
        """
        Inserts a dimension at position *pos*.
        """
        if self._shape is None:
            return
        if isinstance(dim, DimensionObject):
            self._shape.insert(pos, dim)
        else:
            self._shape.insert(pos, DimensionObject(dim))

    def squeeze(self, axis):
        """
        Removes one dimension.
        """
        cp = self.copy(name='{}-SZ'.format(self.name))
        cp.drop_axis(axis)
        return cp

    def unsqueeze(self, axes):
        """
        Adds dimensions equal to 1.

        @param      axes    positions of the new dimensions in the
                            resulting shape, negative values count from
                            the end of the resulting shape
        @return             @see cl ShapeObject
        """
        cp = self.copy(name='{}-USZ'.format(self.name))
        if cp._shape is None:
            return cp
        rank = len(cp) + len(axes)
        positions = sorted(ax + rank if ax < 0 else ax for ax in axes)
        # inserting in increasing order means every position already
        # refers to the final shape
        for pos in positions:
            cp.insert(1, pos)
        return cp

    def transpose(self, perm):
        """
        Permutes the dimensions.

        @param      perm    permutation, dimension *i* of the result is
                            dimension *perm[i]* of this shape
        @return             @see cl ShapeObject
        """
        if self.shape is None:
            return self.copy(name='{}-TR'.format(self.name))
        cp = ShapeObject([None for p in perm], dtype=self.dtype,
                         name="{}-TR".format(self.name))
        for i, p in enumerate(perm):
            if p >= len(self):
                # This should not happen, the dimension is left unknown.
                cp._shape[i] = DimensionObject(None)
            else:
                cp._shape[i] = self._shape[p]
        return cp

    def drop_axis(self, axis):
        """
        Drops an axis.
        """
        if self._shape is not None:
            if isinstance(axis, (tuple, list)):
                for i in sorted(axis, reverse=True):
                    del self._shape[i]
            else:
                del self._shape[axis]

    def broadcast(self, a):
        """
        Computes the shape after a broadcast. Shapes are aligned on their
        last dimension (numpy broadcasting rules), every pair of aligned
        dimensions becomes ``max(d1, d2)``.

        @param      a       other @see cl ShapeObject
        @return             @see cl ShapeObject
        @raise              ValueError if *a* is None
        """
        if a is None:
            raise ValueError("a should not be None")  # pragma: no cover
        if a._shape is None:
            return a.copy()
        if self._shape is None:
            return self.copy()
        n_self = len(self._shape)
        n_a = len(a._shape)
        mx = max(n_self, n_a)
        res = []
        for i in range(mx):
            # aligned index in each shape, negative when the shorter shape
            # has no dimension at this position
            i_self = i - (mx - n_self)
            i_a = i - (mx - n_a)
            if i_self >= 0 and i_a >= 0:
                res.append(ShapeOperatorMax(self._shape[i_self], a._shape[i_a]))
            elif i_self >= 0:
                res.append(self._shape[i_self])
            else:
                res.append(a._shape[i_a])
        return ShapeObject(tuple(res), self.dtype, False,
                           name="broadcast-{}-{}".format(self.name, a.name))

    @staticmethod
    def _infer_merged_type(*args, use_dtype=True):
        """
        Infers the type of a result built from several inputs,
        numeric types are merged into float64.
        """
        if use_dtype:
            tys = set(a.dtype for a in args)
        else:
            tys = set(args)
        if len(tys) == 1:
            return list(tys)[0]
        if any(tys & {numpy.float64, numpy.int64,
                      numpy.float32, numpy.int32,
                      numpy.float16}):
            return numpy.float64
        raise RuntimeError(  # pragma: no cover
            "Unable to infer types based on {} ({}).".format(
                tys, len(tys)))

    def concat_columns(self, axis, *shapes):
        """
        Concatenates columns from *shapes* to this one
        along one axis.

        @param      axis        concatenation axis
        @param      shapes      other shapes
        @return                 @see cl ShapeObject, unknown if one
                                dimension along *axis* is unknown
        """
        args = [self] + list(shapes)
        dtype = self._infer_merged_type(*args)
        dim_axis = self[axis]
        if isinstance(dim_axis, int):
            # __getitem__ returns 1 beyond the last dimension
            dim_axis = DimensionObject(dim_axis)
        if dim_axis is None:
            return ShapeObject(None, dtype=dtype)
        for a in shapes:
            if a[axis] is None:
                return ShapeObject(None, dtype=dtype)
            dim_axis = dim_axis + a[axis]
        a0 = args[0].copy(dtype=dtype)
        a0[axis] = dim_axis
        return a0

    @staticmethod
    def einsum_shape(equation, *inputs):
        """
        Computes :epkg:`einsum` shapes.
        Not the most efficient one as it creates variables
        of the given shapes.

        @param      equation    equation such as ``'ij,jk->ik'``,
                                str or bytes
        @param      inputs      input shapes
        @return                 @see cl ShapeObject
        @raise                  RuntimeError if inputs and equation mismatch
        """
        for inp in inputs:
            if inp.shape is None:
                return inp
        if isinstance(equation, str):
            arrow, comma = "->", ","
        else:
            arrow, comma = b"->", b","
        inp, out = [_.strip() for _ in equation.split(arrow)]
        inps = [_.strip() for _ in inp.split(comma)]
        if len(inputs) != len(inps):
            raise RuntimeError(  # pragma: no cover
                "Input mismatch between '{}' and {}.".format(equation, inps))
        shs = {}
        for a, b in zip(inps, inputs):
            if len(a) != len(b):
                raise RuntimeError(  # pragma: no cover
                    "Input mismatch '{}' (in '{}') and {}.".format(a, equation, b))
            for c, s in zip(a, b):
                if c not in shs:
                    shs[c] = s
                elif shs[c] != s:
                    raise RuntimeError(  # pragma: no cover
                        "Equation '{}'. Dimension mismatch '{}' != {}.".format(
                            equation, s, shs[c]))
        new_shape = [shs[i] for i in out]
        return ShapeObject(new_shape, dtype=ShapeObject._infer_merged_type(*inputs))

    @staticmethod
    def gather_shape(input, indices, axis):  # pylint: disable=W0622
        """
        Computes Gather shapes:
        ``input.shape[:axis] + indices.shape + input.shape[axis + 1:]``.

        @param      input       shape of the gathered tensor
        @param      indices     shape of the indices
        @param      axis        gather axis, negative counts from the end
        @return                 @see cl ShapeObject, unknown if one of the
                                inputs is unknown
        """
        if input.shape is None or indices.shape is None:
            # len() returns 0 for an unknown shape, the test must be explicit
            return ShapeObject(None, dtype=input._dtype)
        input_rank = len(input)
        if axis < 0:
            axis = input_rank + axis

        shape = [input[i] for i in range(axis)]
        shape.extend(indices)
        shape.extend(input[i] for i in range(axis + 1, input_rank))
        return ShapeObject(shape, dtype=input._dtype)
class ShapeObjectFct(ShapeObject):
    """
    Shape computed lazily by a user defined function applied to
    other shapes once they are evaluated.
    See @see cl Conv for an example.
    """

    def __init__(self, fct, *shapes, dtype=None, name=None):
        """
        @param      fct         function receiving the evaluated shapes
        @param      shapes      shapes sent to fct
        @param      dtype       dtype
        @param      name        optional, for debugging purposes
        """
        ShapeObject.__init__(self, None, dtype=dtype, name=name)
        self._fct = fct
        self._shapes = shapes

    def evaluate(self, **kwargs):
        """
        Evaluates every stored shape, then calls the function on them.

        @param      kwargs      value for the variables
        @return                 whatever the function returns, its name is
                                replaced by this object's name if any
        """
        evaluated = [shape.evaluate(**kwargs) for shape in self._shapes]
        result = self._fct(*evaluated)
        if self.name is not None:
            result.name = self.name
        return result