"""
Shape object.
:githublink:`%|py|5`
"""
import numpy
[docs]class BaseDimensionShape:
"""
Base class to :class:`DimensionObject <mlprodict.onnxrt.shape_object.DimensionObject>`,
:class:`ShapeOperator <mlprodict.onnxrt.shape_object.ShapeOperator>`, :class:`ShapeObject`.
:githublink:`%|py|12`
"""
[docs] def to_string(self, use_x=True):
"""
Converts the object into a string.
:githublink:`%|py|17`
"""
raise NotImplementedError()
[docs] def evaluate(self, **kwargs):
"""
Evaluates the object, reduces the expression
to a number or a string.
:githublink:`%|py|24`
"""
raise NotImplementedError() # pragma: no cover
[docs]class ShapeOperator(BaseDimensionShape):
"""
Base class for all shapes operator.
:githublink:`%|py|31`
"""
[docs] def __init__(self, name, fct, fct_string, *args):
"""
:param name: display name of the operator
:param fct: function doing the operator
if argument are numeric
:param fct_string: function represented as a string
:param args: argument of the operator
:githublink:`%|py|40`
"""
self._name = name
self._fct = fct
self._fct_string = fct_string
self._args = args
for a in self._args:
if not isinstance(a, DimensionObject):
raise TypeError(
"All arguments must be of type DimensionObject not '{}'."
"".format(type(a)))
[docs] def __repr__(self):
"""
usual
:githublink:`%|py|54`
"""
return "{0}('{1}', {2}, '{2}', {3})".format(
self.__class__.__name__, self._name,
self._fct_string, self._args)
[docs] def to_string(self, use_x=True):
"""
Displays as a string.
:return: a string
:githublink:`%|py|64`
"""
raise NotImplementedError( # pragma: no cover
"Operator '{}' does not implement 'to_string': {}.".format(
self.__class__.__name__, repr(self)))
[docs] def evaluate(self, **kwargs):
"""
Evalutes the operator.
:param kwargs: value for the variables.
:return: string or integer
:githublink:`%|py|75`
"""
args = []
has_string = False
for a in self._args:
a = DimensionObject._same_(a)
v = a.evaluate(**kwargs)
if isinstance(v, str):
has_string = True
args.append(v)
if has_string:
res = self._evaluate_string_(args, **kwargs)
else:
try:
res = self._fct(*args)
except TypeError as e:
raise RuntimeError(
"Unable to evaluate operator {} due to {}".format(repr(self), e)) from e
return res
[docs] def _evaluate_string_(self, args, **kwargs):
"""
Evalutes the operator assuming some of them are still strings.
:param args: arguments extracted by method *evaluate*
:param kwargs: value for the variables.
:return: string or integer
:githublink:`%|py|101`
"""
raise NotImplementedError(
"This function must be overwritten.") # pragma: no cover
[docs]class ShapeBinaryOperator(ShapeOperator):
"""
Base class for shape binary operator.
:githublink:`%|py|109`
"""
[docs] def __init__(self, name, fct, fct_string, x, y):
"""
:param name: display name of the operator
:param fct: function doing the operator
if argument are numeric
:param fct_string: function represented as a string
:param x: first argument
:param y: second argument
:githublink:`%|py|119`
"""
ShapeOperator.__init__(self, name, fct, fct_string, x, y)
if isinstance(x, tuple):
raise TypeError('x cannot be a tuple') # pragma: no cover
if isinstance(y, tuple):
raise TypeError('y cannot be a tuple') # pragma: no cover
[docs] def _to_string1(self, x, y):
return DimensionObject(self._fct(x._dim, y._dim)).to_string()
[docs] def _to_string2(self, x, y):
return DimensionObject("{}{}{}".format(x._dim, self._name, y._dim)).to_string()
[docs] def _to_string2b(self, x, y):
return DimensionObject("({}){}({})".format(x._dim, self._name, y._dim)).to_string()
[docs] def _to_string3(self, x):
return DimensionObject("{}{}x".format(x._dim, self._name)).to_string()
[docs] def to_string(self, use_x=True):
"""
Applies binary operator to a dimension.
:param use_x: use `'x'` if dimension is unknown
:return: a string
:githublink:`%|py|144`
"""
x, y = self._args # pylint: disable=W0632
if isinstance(x._dim, int):
if isinstance(y, DimensionObject):
if isinstance(y._dim, int):
return self._to_string1(x, y)
if isinstance(y._dim, str):
return self._to_string2(x, y)
if y._dim is None:
if use_x:
return self._to_string3(x)
return DimensionObject("{}{}DimensionObject()".format(
x._dim, self._name)).to_string()
raise TypeError( # pragma: no cover
"Unable to handle type '{}'.".format(type(y._dim)))
raise TypeError( # pragma: no cover
"Unable to handle type '{}'.".format(type(y)))
elif isinstance(x._dim, str):
if isinstance(y._dim, int):
return self._to_string2(x, y)
if isinstance(y._dim, str):
return self._to_string2b(x, y)
raise TypeError( # pragma: no cover
"Unable to handle type '{}'.".format(type(y._dim)))
raise TypeError( # pragma: no cover
"Unable to handle type '{}'.".format(type(x._dim)))
[docs] def _evaluate_string_(self, args, **kwargs):
"""
Evalutes the operator assuming some of them are still strings.
:param args: arguments extracted by method *evaluate*
:param kwargs: value for the variables.
:return: string or integer
:githublink:`%|py|178`
"""
return self._name.join(map(lambda s: '({})'.format(s), args))
[docs]class ShapeBinaryFctOperator(ShapeBinaryOperator):
"""
Base class for shape binary operator defined by a function.
:githublink:`%|py|185`
"""
[docs] def _to_string2(self, x, y):
return DimensionObject("{}({},{})".format(self._name, x._dim, y._dim)).to_string()
[docs] def _to_string2b(self, x, y):
return DimensionObject("{}({},{})".format(self._name, x._dim, y._dim)).to_string()
[docs] def _to_string3(self, x):
return DimensionObject("{}({},x)".format(self._name, x._dim)).to_string()
[docs] def _evaluate_string_(self, args, **kwargs):
"""
Evalutes the operator assuming some of them are still strings.
:param args: arguments extracted by method *evaluate*
:param kwargs: value for the variables.
:return: string or integer
:githublink:`%|py|203`
"""
return "{}({})".format(self._name, ",".join(map(str, args)))
[docs]class ShapeOperatorAdd(ShapeBinaryOperator):
"""
Shape addition.
:githublink:`%|py|210`
"""
[docs] def __init__(self, x, y):
ShapeBinaryOperator.__init__(
self, '+', lambda a, b: a + b, 'lambda a, b: a + b', x, y)
[docs] def __repr__(self):
"""
Displays a string.
:return: a string
:githublink:`%|py|221`
"""
return "{0}({1}, {2})".format(
self.__class__.__name__, repr(self._args[0]), repr(self._args[1]))
[docs]class ShapeOperatorMul(ShapeBinaryOperator):
"""
Shape multiplication.
:githublink:`%|py|229`
"""
[docs] def __init__(self, x, y):
ShapeBinaryOperator.__init__(
self, '*', lambda a, b: a * b, 'lambda a, b: a * b', x, y)
[docs] def __repr__(self):
"""
Displays a string.
:return: a string
:githublink:`%|py|240`
"""
return "{0}({1}, {2})".format(
self.__class__.__name__, repr(self._args[0]), repr(self._args[1]))
[docs]class ShapeOperatorGreater(ShapeBinaryOperator):
"""
Shape comparison.
:githublink:`%|py|248`
"""
[docs] def __init__(self, x, y):
ShapeBinaryOperator.__init__(
self, '>', lambda a, b: a > b, 'lambda a, b: a > b', x, y)
[docs] def __repr__(self):
"""
Displays a string.
:return: a string
:githublink:`%|py|259`
"""
return "{0}({1}, {2})".format(
self.__class__.__name__, repr(self._args[0]), repr(self._args[1]))
[docs]class ShapeOperatorMax(ShapeBinaryFctOperator):
"""
Best on each dimension.
:githublink:`%|py|267`
"""
[docs] def __init__(self, x, y):
ShapeBinaryFctOperator.__init__(
self, 'max', lambda a, b: max(a, b), 'max(a, b)', x, y)
[docs] def __repr__(self):
"""
Displays a string.
:return: a string
:githublink:`%|py|278`
"""
return "{0}({1}, {2})".format(
self.__class__.__name__, repr(self._args[0]), repr(self._args[1]))
[docs]class DimensionObject(BaseDimensionShape):
"""
One dimension of a shape.
:githublink:`%|py|286`
"""
[docs] def __init__(self, obj):
"""
:param obj: int or :class:`DimensionObject <mlprodict.onnxrt.shape_object.DimensionObject>` or None to
specify something unknown
:githublink:`%|py|292`
"""
if obj is None or obj == 0 or obj == '?':
self._dim = None
elif isinstance(obj, (int, str, ShapeOperator, DimensionObject,
numpy.int32, numpy.int64)):
self._dim = obj
else:
raise TypeError("Unexpected type for obj: {}".format(type(obj)))
@property
def dim(self):
"""
Returns the dimension.
:githublink:`%|py|305`
"""
return self._dim
[docs] def __repr__(self):
"""
usual
:githublink:`%|py|311`
"""
if isinstance(self._dim, int):
return "DimensionObject({})".format(self._dim)
if isinstance(self._dim, DimensionObject):
return repr(self._dim)
if isinstance(self._dim, ShapeOperator):
return "DimensionObject({})".format(repr(self._dim))
return "DimensionObject('{}')".format(self._dim)
[docs] @staticmethod
def _same_(obj):
"""
Returns *obj* if *obj* is :class:`DimensionObject <mlprodict.onnxrt.shape_object.DimensionObject>`
otherwise converts it.
:githublink:`%|py|325`
"""
if isinstance(obj, DimensionObject):
return obj
return DimensionObject(obj)
[docs] def to_string(self, use_x=True):
"""
Represents the dimension as a string.
:githublink:`%|py|333`
"""
if isinstance(self._dim, int):
return '{}'.format(self._dim)
if isinstance(self._dim, ShapeOperator):
return self._dim.to_string()
if isinstance(self._dim, str):
return self._dim
if self._dim is None:
return 'x' if use_x else '?'
raise NotImplementedError( # pragma: no cover
"Not implemented for '{}'.".format(repr(self)))
[docs] def evaluate(self, **kwargs):
"""
Evalutes the dimension.
:param kwargs: value for the variables.
:return: string or integer
:githublink:`%|py|351`
"""
if isinstance(self._dim, (int, ShapeOperator, DimensionObject)):
res = self._dim
elif isinstance(self._dim, str):
if self._dim in kwargs:
res = kwargs[self._dim]
else:
res = self._dim
elif self._dim is None:
pref = str(hex(id(self)))[2:]
res = "n{}".format(pref)
elif isinstance(self._dim, ):
res = self._dim.evaluate(**kwargs)
else:
raise NotImplementedError( # pragma: no cover
"Not implemented for '{}'.".format(repr(self)))
if isinstance(res, (ShapeOperator, DimensionObject)):
return res.evaluate(**kwargs)
return res
[docs] def __eq__(self, v):
"""
usual
:githublink:`%|py|374`
"""
if isinstance(v, (int, str)):
return self._dim == v
if isinstance(v, DimensionObject):
return v == self._dim
if isinstance(v, ShapeOperator):
ve = v.evaluate()
return ve == self._dim
if v is None:
return self._dim is None
raise TypeError( # pragma: no cover
"Unable to compare a DimensionObject to {}".format(type(v)))
[docs] def __add__(self, obj):
"""
usual
:githublink:`%|py|390`
"""
return DimensionObject(
ShapeOperatorAdd(self, DimensionObject._same_(obj)))
[docs] def __mul__(self, obj):
"""
usual
:githublink:`%|py|397`
"""
return DimensionObject(
ShapeOperatorMul(self, DimensionObject._same_(obj)))
[docs] def __gt__(self, obj):
"""
usual
:githublink:`%|py|404`
"""
if obj is None:
return not isinstance(self._dim, int)
if isinstance(self._dim, int) and isinstance(obj._dim, int):
return self._dim > obj._dim
return DimensionObject(
ShapeOperatorGreater(self, DimensionObject._same_(obj)))
[docs]class ShapeObject(BaseDimensionShape):
"""
Handles mathematical operations around shapes.
It stores a type (:epkg:`numpy` type),
and a name to somehow have an idea of where
the shape comes from in the :epkg:`ONNX` graph.
The shape itself is defined by a list of
:class:`DimensionObject <mlprodict.onnxrt.shape_object.DimensionObject>` or :class:`ShapeOperator`
or *None* if the shape is unknown. A dimension is an
integer or a variable encoded as a string. This variable
is a way to tell the dimension may vary.
.. runpython::
:showcode:
import numpy
from mlprodict.onnxrt.shape_object import ShapeObject
sh1 = ShapeObject((1, 2), dtype=numpy.float32)
sh2 = ShapeObject((45, 2), dtype=numpy.float32)
mx = max(sh1, sh2)
print(mx)
sh1 = ShapeObject((1, 2), dtype=numpy.float32)
sh2 = ShapeObject((None, 2), dtype=numpy.float32)
print(sh2)
mx = max(sh1, sh2)
print(mx.to_string())
sh1 = ShapeObject((1, 2), dtype=numpy.float32)
sh2 = ShapeObject(('n', 2), dtype=numpy.float32)
print(sh2)
mx = max(sh1, sh2)
print(mx.evaluate(n=4))
:githublink:`%|py|447`
"""
[docs] def __init__(self, shape, dtype=None, use_n1=False, name=None):
"""
:param shape: tuple or `numpy.array`
:param dtype: dtype
:param use_n1: use `'n'` if the first dimension is unknown
:param name: optional, for debugging purposes
:githublink:`%|py|455`
"""
self.name = name
if isinstance(shape, numpy.ndarray):
self._shape = [DimensionObject(s) for s in shape.shape]
self._dtype = shape.dtype
elif isinstance(shape, dict) and 'type' in shape:
tshape = shape['type']
if tshape['kind'] == 'tensor':
if tshape['shape'] == ('?', ):
self._shape = None
else:
self._shape = [DimensionObject(s) for s in tshape['shape']]
self._dtype = tshape['elem']
elif tshape['kind'] == 'map':
self._shape = []
self._dtype = 'map'
else:
raise ValueError( # pragma: no cover
"Wrong shape value {}".format(shape))
elif isinstance(shape, (tuple, list)):
self._shape = []
for s in shape:
self._shape.append(DimensionObject(s))
self._dtype = dtype
elif shape is None:
# shape is unknown
self._shape = None
self._dtype = dtype
else:
raise TypeError( # pragma: no cover
"Unexpected type for shape: {}".format(type(shape)))
if self._dtype is None:
raise ValueError(
"dtype cannot be None, shape type is {}\n{}".format(
type(shape), shape))
if self._dtype in (float, 'double'):
self._dtype = numpy.float64
elif self._dtype in ('float32', 'float'):
self._dtype = numpy.float32
elif self._dtype in (numpy.float16, 'float16'):
self._dtype = numpy.float16
elif self._dtype in ('int32', ):
self._dtype = numpy.int32
elif self._dtype in (int, 'int', 'int64'):
self._dtype = numpy.int64
elif self._dtype in (str, 'str'):
self._dtype = numpy.str
elif (hasattr(self._dtype, 'type') and self._dtype.type is numpy.string_):
pass
elif self._dtype in (bool, 'bool'):
self._dtype = numpy.bool
elif self._dtype in (object, numpy.object_):
pass
elif self._dtype in (numpy.int8, 'int8', ):
self._dtype = numpy.int8
elif self._dtype in (numpy.uint8, 'uint8', ):
self._dtype = numpy.uint8
elif self._dtype not in {
numpy.float32, numpy.float64, numpy.int32, numpy.int64,
numpy.str, numpy.bool, numpy.float16, None,
'map'}:
raise ValueError( # pragma: no cover
"dtype has an unexpected value: '{}'.".format(self._dtype))
if self._shape is not None:
for i, a in enumerate(self._shape):
if not isinstance(a, DimensionObject):
raise TypeError( # pragma: no cover
'Dimension {} has a wrong type {}'.format(
i, type(a)))
if use_n1:
sh = self._shape[0] if self._shape else None
if isinstance(sh, DimensionObject) and sh._dim is None:
sh._dim = 'n'
[docs] def reshape(self, shape):
"""
Creates a new shape, checks the number of elements is the same.
:githublink:`%|py|533`
"""
sh = ShapeObject(shape, self.dtype, getattr(self, '_dim', None),
self.name)
p1 = self.product().evaluate()
p2 = sh.product().evaluate()
if isinstance(p1, int) and p1 != p2:
raise ValueError("Shape {} cannot be reshaped into {} "
"(p1={}, p2={}).".format(sh, shape, p1, p2))
return sh
[docs] def copy(self, dtype=None, name=None):
"""
A copy not a deepcopy.
:param dtype: None or a value to rewrite the type.
:param name: overwrites the name
:return: :class:`ShapeObject <mlprodict.onnxrt.shape_object.ShapeObject>`
:githublink:`%|py|550`
"""
if self._shape is None:
return ShapeObject(None, dtype=self.dtype, name=name or self.name)
return ShapeObject(self._shape.copy(),
self.dtype if dtype is None else dtype,
name=name or self.name)
[docs] def __getitem__(self, index):
"""
Extracts a specific dimension.
:githublink:`%|py|560`
"""
if self._shape is None:
return None
if isinstance(index, int) and index >= len(self._shape):
return 1
return self._shape[index]
[docs] def __setitem__(self, index, value):
"""
Changes a specific dimension.
:githublink:`%|py|570`
"""
if self._shape is None:
return
while len(self._shape) <= index:
self._shape.append(DimensionObject(1))
self._shape[index] = value
@property
def shape(self):
"""
Returns the stored shape.
:githublink:`%|py|581`
"""
if self._shape is None:
return None
return tuple(self._shape)
[docs] def __len__(self):
"""
Returns the number of dimensions.
:githublink:`%|py|589`
"""
if self._shape is None:
return 0
return len(self._shape)
@property
def dtype(self):
"""
Returns the stored *dtype*.
:githublink:`%|py|598`
"""
return self._dtype
[docs] def reduce(self, axis=1, keepdims=False, dtype=None):
"""
Reduces the matrix. Removes one dimension.
:param axis: axis
:param keepdims: keep dimensions, replaces the removed
dimension by 1
:param dtype: if not None, changes the type
:return: new dimension
:githublink:`%|py|610`
"""
if self._shape is None:
if self.name is None:
return self.copy()
return self.copy(name="{}-RD".format(self.name))
if axis is None:
return ShapeObject((1, ), self._dtype if dtype is None else dtype,
name="{}-RDN".format(self.name))
if isinstance(axis, ShapeObject):
def drop_axis(shape, a):
c = list(shape)
del c[a[0]]
return c
return ShapeObjectFct(
drop_axis, self, axis, name="DropAxis", dtype=self.dtype)
if 0 <= axis < len(self._shape):
cp = self._shape.copy()
if keepdims:
cp[axis] = DimensionObject(1)
else:
del cp[axis]
return ShapeObject(cp, self._dtype if dtype is None else dtype,
name="{}-RD".format(self.name))
raise IndexError("axis={} is wrong, shape is {}-tuple and equal to "
"{}".format(axis, len(self._shape), self))
[docs] def __repr__(self):
"""
usual
:githublink:`%|py|643`
"""
st = str(self.dtype)
if "'" in st:
st = st.split("'")[1]
if self.shape is None:
if self.name is None:
return "ShapeObject(None, dtype={})".format(st)
return "ShapeObject(None, dtype={}, name='{}')".format(st, self.name)
st_shape = []
for s in self.shape:
if isinstance(getattr(s, "_dim", None), (int, str)):
st_shape.append(str(s._dim))
else:
st_shape.append(repr(s))
if len(st_shape) == 1:
st_shape.append('')
st_shape = '({})'.format(", ".join(st_shape))
if self.name is None:
return "ShapeObject({}, dtype={})".format(st_shape, st)
return "ShapeObject({}, dtype={}, name='{}')".format(
st_shape, st, self.name)
[docs] def __iter__(self):
"""
Iterators over dimensions.
:githublink:`%|py|670`
"""
if self._shape is not None:
for d in self._shape:
yield d
[docs] def __gt__(self, a):
"""
Compares shapes. Operator ``>``.
:githublink:`%|py|678`
"""
if isinstance(a, tuple):
a = ShapeObject(a, dtype=self._dtype)
if self._shape is None and a._shape is None:
return False
if self._shape is None:
return True
if a._shape is None:
return False
if len(self) > len(a):
return True
if len(self) < len(a):
return False
for d1, d2 in zip(self, a):
if d1 > d2:
return True
if d1 < d2:
return False
return False
[docs] def __eq__(self, a):
"""
Tests equality between two shapes.
:githublink:`%|py|701`
"""
if isinstance(a, tuple):
a = ShapeObject(a, dtype=self._dtype)
if self._shape is None and a._shape is None:
return True
if self._shape is None or a._shape is None:
return False
if len(self) != len(a):
return False
for d1, d2 in zip(self, a):
if d1 == d2:
continue
return False
return True
[docs] def evaluate(self, **kwargs):
"""
Evaluates the shape.
:githublink:`%|py|719`
"""
vs = []
for v in self:
d = v.evaluate(**kwargs)
vs.append(d)
return ShapeObject(tuple(vs), self._dtype, name="{}-EV".format(self.name))
[docs] def to_string(self, use_x=False):
"""
Converts shapes into a string.
:githublink:`%|py|729`
"""
shapes = []
for a in self._shape:
shapes.append(a.to_string(use_x=use_x))
return '({})'.format(', '.join(shapes))
[docs] def product(self):
"""
Multiplies all the dimension.
:return: :class:`DimensionObject <mlprodict.onnxrt.shape_object.DimensionObject>`
:githublink:`%|py|740`
"""
cl = self[0]
for i in range(1, len(self)):
cl = cl * self[i]
return cl
[docs] def append(self, dim):
"""
Appends a dimension.
:githublink:`%|py|749`
"""
if self._shape is None:
return
if isinstance(dim, DimensionObject):
self._shape.append(dim)
else:
self._shape.append(DimensionObject(dim))
[docs] def insert(self, dim, pos=0):
"""
Inserts a dimension at position *pos*.
:githublink:`%|py|760`
"""
if self._shape is None:
return
if isinstance(dim, DimensionObject):
self._shape.insert(pos, dim)
else:
self._shape.insert(pos, DimensionObject(dim))
[docs] def squeeze(self, axis):
"""
Removes one dimension.
:githublink:`%|py|771`
"""
cp = self.copy(name='{}-SZ'.format(self.name))
cp.drop_axis(axis)
return cp
[docs] def unsqueeze(self, axes):
"""
Adds dimensions.
:githublink:`%|py|779`
"""
cp = self
name = '{}-USZ'.format(self.name)
for ax in axes[::-1]:
cp = cp.copy(name=name)
cp.insert(ax, 1)
return cp
[docs] def transpose(self, perm):
"""
Removes one dimension.
:githublink:`%|py|790`
"""
if self.shape is None:
return self.copy(name='{}-TR'.format(self.name))
cp = ShapeObject([None for p in perm], dtype=self.dtype,
name="{}-TR".format(self.name))
for i, p in enumerate(perm):
if p >= len(self):
# This should not happen.
cp._shape[i] = None
else:
cp._shape[i] = self._shape[p]
return cp
[docs] def drop_axis(self, axis):
"""
Drops an axis.
:githublink:`%|py|806`
"""
if self._shape is not None:
if isinstance(axis, (tuple, list)):
for i in sorted(axis, reverse=True):
del self._shape[i]
else:
del self._shape[axis]
[docs] def broadcast(self, a):
"""
Computes the shape after a broadcast.
:githublink:`%|py|817`
"""
if a is None:
raise ValueError("a should not be None") # pragma: no cover
if a._shape is None:
return a.copy()
if self._shape is None:
return self.copy()
mx = max(len(self._shape), len(a._shape))
res = []
for i in range(mx):
if i < len(self._shape):
if i < len(a._shape):
res.append(ShapeOperatorMax(self[i], a[i]))
else:
res.append(self[i])
else:
res.append(a[i])
return ShapeObject(tuple(res), self.dtype, False,
name="broadcast-{}-{}".format(self.name, a.name))
[docs] @staticmethod
def _infer_merged_type(*args):
tys = set(a.dtype for a in args)
if len(tys) == 1:
return list(tys)[0]
if any(tys & {numpy.float64, numpy.int64,
numpy.float32, numpy.int32,
numpy.float16}):
return numpy.float64
raise RuntimeError( # pragma: no cover
"Unable to infer types based on {}.".format(tys))
[docs] def concat_columns(self, axis, *shapes):
"""
Concatenates columns from *shapes* to this one
along one axis.
:githublink:`%|py|853`
"""
args = [self] + list(shapes)
dtype = self._infer_merged_type(*args)
dim_axis = args[0][axis]
if dim_axis is None:
return ShapeObject(None, dtype=dtype)
for a in shapes:
if a[axis] is None:
return ShapeObject(None, dtype=dtype)
dim_axis = dim_axis + a[axis]
a0 = args[0].copy(dtype=dtype)
a0[axis] = dim_axis
return a0
[docs] @staticmethod
def einsum_shape(equation, *inputs):
"""
Computes :epkg:`einsum` shapes.
Not the most efficient one as it creates variables
of the given shapes.
:githublink:`%|py|873`
"""
for inp in inputs:
if inp.shape is None:
return inp
inp, out = [_.strip() for _ in equation.split(b"->")]
inps = [_.strip() for _ in inp.split(b',')]
if len(inputs) != len(inps):
raise RuntimeError( # pragma: no cover
"Input mismatch between '{}' and {}.".format(equation, inps))
shs = {}
for a, b in zip(inps, inputs):
if len(a) != len(b):
raise RuntimeError( # pragma: no cover
"Input mismatch '{}' (in '{}') and {}.".format(a, equation, b))
for c, s in zip(a, b):
if c not in shs:
shs[c] = s
elif shs[c] != s:
raise RuntimeError( # pragma: no cover
"Equation '{}'. Dimension mismatch '{}' != {}.".format(
equation, s, shs[c]))
new_shape = [shs[i] for i in out]
return ShapeObject(new_shape, dtype=ShapeObject._infer_merged_type(*inputs))
[docs] @staticmethod
def gather_shape(input, indices, axis):
"""
Computes Gather shapes.
:githublink:`%|py|901`
"""
input_rank = len(input)
if input_rank is None:
return ShapeObject(None, dtype=input._dtype)
index_rank = len(indices)
if index_rank is None:
return ShapeObject(None, dtype=input._dtype)
if axis < 0:
axis = input_rank + axis
shape = []
for i in range(axis):
shape.append(input[i])
for dim in indices:
shape.append(dim)
for i in range(axis + 1, input_rank):
shape.append(input[i])
return ShapeObject(shape, dtype=input._dtype)
[docs]class ShapeObjectFct(ShapeObject):
"""
Computes a shape depending on a user defined function.
See :class:`Conv <mlprodict.onnxrt.ops_cpu.op_conv.Conv>` for an example.
:githublink:`%|py|929`
"""
[docs] def __init__(self, fct, *shapes, dtype=None, name=None):
"""
:param fct: function
:param shapes: shapes sent to fct
:param dtype: dtype
:param name: optional, for debugging purposes
:githublink:`%|py|937`
"""
ShapeObject.__init__(self, None, dtype=dtype, name=name)
self._fct = fct
self._shapes = shapes
[docs] def evaluate(self, **kwargs):
"""
Evaluates the shape.
:githublink:`%|py|945`
"""
vs = []
for v in self._shapes:
d = v.evaluate(**kwargs)
vs.append(d)
res = self._fct(*vs)
if self.name is not None:
res.name = self.name
return res