Source code for onnx_array_api.npx.npx_types

from typing import Any, Tuple, Union

import numpy as np
from onnx import AttributeProto


class WrapperType:
    """
    WrapperType.
    """

    pass


class ElemTypeCstInner:
    """
    Defines all possible types and tensor element type.
    """

    __slots__ = []

    undefined = 0
    bool_ = 9
    int8 = 3
    int16 = 5
    int32 = 6
    int64 = 7
    uint8 = 2
    uint16 = 4
    uint32 = 12
    uint64 = 13
    float16 = 10
    float32 = 1
    float64 = 11
    bfloat16 = 16
    complex64 = 14
    complex128 = 15


class ElemTypeCstSet(ElemTypeCstInner):
    """
    Sets of element types.
    """

    allowed = set(range(1, 17))

    ints = {
        ElemTypeCstInner.int8,
        ElemTypeCstInner.int16,
        ElemTypeCstInner.int32,
        ElemTypeCstInner.int64,
        ElemTypeCstInner.uint8,
        ElemTypeCstInner.uint16,
        ElemTypeCstInner.uint32,
        ElemTypeCstInner.uint64,
    }

    floats = {
        ElemTypeCstInner.float16,
        ElemTypeCstInner.bfloat16,
        ElemTypeCstInner.float32,
        ElemTypeCstInner.float64,
    }

    numerics = {
        ElemTypeCstInner.int8,
        ElemTypeCstInner.int16,
        ElemTypeCstInner.int32,
        ElemTypeCstInner.int64,
        ElemTypeCstInner.uint8,
        ElemTypeCstInner.uint16,
        ElemTypeCstInner.uint32,
        ElemTypeCstInner.uint64,
        ElemTypeCstInner.float16,
        ElemTypeCstInner.bfloat16,
        ElemTypeCstInner.float32,
        ElemTypeCstInner.float64,
    }

    @staticmethod
    def combined(type_set):
        "Combines all types into a single integer by using power of 2."
        s = 0
        for dt in type_set:
            s += 1 << dt
        return s


class ElemTypeCst(ElemTypeCstSet):
    """
    Combination of element types.
    """

    Undefined = 0
    Bool = 1 << ElemTypeCstInner.bool_
    Int8 = 1 << ElemTypeCstInner.int8
    Int16 = 1 << ElemTypeCstInner.int16
    Int32 = 1 << ElemTypeCstInner.int32
    Int64 = 1 << ElemTypeCstInner.int64
    UInt8 = 1 << ElemTypeCstInner.uint8
    UInt16 = 1 << ElemTypeCstInner.uint16
    UInt32 = 1 << ElemTypeCstInner.uint32
    UInt64 = 1 << ElemTypeCstInner.uint64
    BFloat16 = 1 << ElemTypeCstInner.bfloat16
    Float16 = 1 << ElemTypeCstInner.float16
    Float32 = 1 << ElemTypeCstInner.float32
    Float64 = 1 << ElemTypeCstInner.float64
    Complex64 = 1 << ElemTypeCstInner.complex64
    Complex128 = 1 << ElemTypeCstInner.complex128

    Numerics = ElemTypeCstSet.combined(ElemTypeCstSet.numerics)
    Floats = ElemTypeCstSet.combined(ElemTypeCstSet.floats)
    Ints = ElemTypeCstSet.combined(ElemTypeCstSet.ints)


[docs]class ElemType(ElemTypeCst): """ Allowed element type based on numpy dtypes. :param dtype: integer or a string """ names_int = { att: getattr(ElemTypeCstInner, att) for att in dir(ElemTypeCstInner) if isinstance(getattr(ElemTypeCstInner, att), int) } int_names = { getattr(ElemTypeCstInner, att): att for att in dir(ElemTypeCstInner) if isinstance(getattr(ElemTypeCstInner, att), int) } set_names = { getattr(ElemTypeCst, att): att for att in dir(ElemTypeCst) if isinstance(getattr(ElemTypeCst, att), int) and "A" <= att[0] <= "Z" } numpy_map = { **{ getattr(np, att): getattr(ElemTypeCst, att) for att in dir(ElemTypeCst) if isinstance(getattr(ElemTypeCst, att), int) and hasattr(np, att) }, **{ np.dtype(att): getattr(ElemTypeCst, att) for att in dir(ElemTypeCst) if isinstance(getattr(ElemTypeCst, att), int) and hasattr(np, att) }, } __slots__ = ["dtype"] @classmethod def __class_getitem__(cls, dtype: Union[str, int]): if isinstance(dtype, str): dtype = ElemType.names_int[dtype] elif dtype in ElemType.numpy_map: dtype = ElemType.numpy_map[dtype] elif dtype == 0: pass elif dtype not in ElemType.allowed: raise ValueError(f"Unexpected dtype {dtype} not in {ElemType.allowed}.") newt = type(f"{cls.__name__}{dtype}", (cls,), dict(dtype=dtype)) if "<" in newt.__name__: raise NameError(f"Name is wrong {newt.__name__!r}.") return newt def __eq__(self, t): "Compares types." return self.dtype == t.dtype
[docs] @classmethod def type_name(cls) -> str: "Returns its fullname." s = ElemType.int_names[cls.dtype] return s
[docs] @classmethod def get_set_name(cls, dtypes): "Returns the set name." tt = [] for dt in dtypes: if isinstance(dt, int): tt.append(dt) else: tt.append(dt.dtype) dtypes = set(tt) for d in dir(cls): if dtypes == getattr(cls, d): return d return None
[docs]class ParType: """ Defines a parameter type. :param dtype: parameter type :param optional: is optional or not """ map_names = {int: "int", float: "float", str: "str"} @classmethod def __class_getitem__(cls, dtype): if isinstance(dtype, (int, float)): msg = str(dtype) else: msg = getattr(dtype, "__name__", str(dtype)) newt = type(f"{cls.__name__}{msg}", (cls,), dict(dtype=dtype)) if "<" in newt.__name__: raise NameError(f"Name is wrong {newt.__name__!r}.") return newt
[docs] @classmethod def type_name(cls) -> str: "Returns its full name." if cls.dtype in ParType.map_names: newt = f"ParType[{ParType.map_names[cls.dtype]}]" else: newt = f"ParType[{cls.dtype}]" if "<" in newt or "{" in newt: raise NameError(f"Name is wrong {newt!r}.") return newt
[docs] @classmethod def onnx_type(cls): "Returns the onnx corresponding type." if cls.dtype == int: return AttributeProto.INT if cls.dtype == float: return AttributeProto.FLOAT if cls.dtype == str: return AttributeProto.STRING raise RuntimeError( f"Unsupported attribute type {cls.dtype!r} " f"for parameter {cls!r}." )
[docs]class OptParType(ParType): """ Defines an optional parameter type. :param dtype: parameter type """ @classmethod def __class_getitem__(cls, dtype): if isinstance(dtype, (int, float)): msg = str(dtype) else: msg = dtype.__name__ newt = type(f"{cls.__name__}{msg}", (cls,), dict(dtype=dtype)) if "<" in newt.__name__: raise NameError(f"Name is wrong {newt.__name__!r}.") return newt
[docs] @classmethod def type_name(cls) -> str: "Returns its full name." newt = f"OptParType[{ParType.map_names[cls.dtype]}]" if "<" in newt or "{" in newt: raise NameError(f"Name is wrong {newt!r}.") return newt
class ShapeType(Tuple[int, ...]): """ Defines a shape type. """ @classmethod def __class_getitem__(cls, *args): if any(map(lambda t: t is not None and not isinstance(t, (int, str)), args)): raise TypeError( f"Unexpected value for args={args}, every element should int or str." ) ext = "_".join(map(str, args)) newt = type(f"{cls.__name__}{ext}", (cls,), dict(shape=args)) if "<" in newt.__name__: raise NameError(f"Name is wrong {newt.__name__!r}.") return newt def __repr__(self) -> str: "usual" return f"{self.__class__.__name__}[{self.shape}]" def __str__(self) -> str: "usual" return f"{self.__class__.__name__}[{self.shape}]"
[docs]class TensorType: """ Used to annotate functions. :param dtypes: tuple of :class:`ElemType` :param shape: tuple of integer or strings or None :param name: name of the type """ @classmethod def __class_getitem__(cls, *args): if isinstance(args, tuple) and len(args) == 1 and isinstance(args[0], tuple): args = args[0] name = None dtypes = None shape = None for a in args: if isinstance(a, str): if hasattr(ElemType, a): if dtypes is not None: raise TypeError(f"Unexpected type {type(a)} in {args}.") v = getattr(ElemType, a) dtypes = tuple(v) if isinstance(v, set) else (v,) else: name = a continue if isinstance(a, set): dtypes = tuple(a) continue if isinstance(a, tuple): shape = a continue if isinstance(a, int): if dtypes is not None: raise TypeError(f"Unexpected type {type(a)} in {args}.") dtypes = (a,) continue if a is None: continue if a in ElemType.numpy_map: if dtypes is not None: raise TypeError(f"Unexpected type {type(a)} in {args}.") dtypes = (ElemType.numpy_map[a],) continue raise TypeError(f"Unexpected type {type(a)} in {args}.") if isinstance(dtypes, ElemType): dtypes = (dtypes,) elif ( isinstance(dtypes, str) or dtypes in ElemType.allowed or dtypes in ElemType.numpy_map ): dtypes = (ElemType[dtypes],) if not isinstance(dtypes, tuple): raise TypeError(f"dtypes must be a tuple not {type(dtypes)}, args={args}.") check = [] for dt in dtypes: if isinstance(dt, ElemType): check.append(dt) elif dt in ElemType.allowed: check.append(ElemType[dt]) elif isinstance(dt, int): check.append(ElemType[dt]) else: raise TypeError(f"Unexpected type {type(dt)} in {dtypes}, args={args}.") dtypes = tuple(check) if isinstance(shape, int): shape = (shape,) msg = [] if name: msg.append(name) if dtypes is not None: msg.append("_".join(map(lambda t: str(t.dtype), dtypes))) if shape is not None: msg.append("_".join(map(str, shape))) final = "__".join(msg) if final: final = "_" + final newt = type( f"{cls.__name__}{final}", (cls,), dict(name=name, dtypes=dtypes, shape=shape), ) if "<" in newt.__name__: raise NameError(f"Name is wrong {newt.__name__!r}.") return newt
[docs] @classmethod def type_name(cls) -> str: "Returns its full name." set_name = ElemType.get_set_name(cls.dtypes) if not set_name: st = ( cls.dtypes[0].type_name() if len(cls.dtypes) == 1 else set(t.type_name() for t in cls.dtypes) ) set_name = repr(st) if cls.shape: if cls.name: newt = f"TensorType[{set_name}, {cls.shape!r}, {cls.name!r}]" else: newt = f"TensorType[{set_name}, {cls.shape!r}]" elif cls.name: newt = f"TensorType[{set_name}, {cls.name!r}]" else: newt = f"TensorType[{set_name}]" if "<" in newt or "{" in newt: raise NameError(f"Name is wrong {newt!r}.") return newt
def _name_set(self): s = 0 for dt in self.dtypes: s += 1 << dt.dtype try: return ElemType.set_names[s] except KeyError: raise RuntimeError( f"Unable to guess element type name for {s}: " f"{repr(self)} in {ElemType.set_names}." )
[docs] @classmethod def issuperset(cls, tensor_type: type) -> bool: """ Tells if *cls* is a superset of *tensor_type*. """ set1 = set(t.dtype for t in cls.dtypes) set2 = set(t.dtype for t in tensor_type.dtypes) if not set1.issuperset(set2): return False if cls.shape is None: return True if tensor_type.shape is None: return False if len(cls.shape) != len(tensor_type.shape): return False for a, b in zip(cls.shape, tensor_type.shape): if isinstance(a, int): if a != b: return False return True
[docs]class SequenceType: """ Defines a sequence of tensors. """ @classmethod def __class_getitem__(cls, elem_type: Any, *args) -> "SequenceType": name = None if len(args) == 1: name = args[0] elif len(args) > 1: raise ValueError(f"Unexected value {args}.") if name: newt = type( f"{cls.__name__}_{name}_{elem_type.__name__}", (cls,), dict(name=name, elem_type=elem_type), ) else: newt = type( f"{cls.__name__}{elem_type.__name__}", (cls,), dict(name=name, elem_type=elem_type), ) if "<" in newt.__name__: raise NameError(f"Name is wrong {newt.__name__!r}.") return newt
[docs] @classmethod def type_name(cls) -> str: "Returns its full name." if cls.name: newt = f"SequenceType[{cls.elem_type.type_name()}], {cls.name!r})" else: newt = f"SequenceType[{cls.elem_type.type_name()!r}]" if "<" in newt or "{" in newt: raise NameError(f"Name is wrong {newt!r}.") return newt
[docs]class TupleType: """ Defines a sequence of tensors. """ @classmethod def __class_getitem__(cls, *args) -> "TupleType": if len(args) == 1 and isinstance(args[0], int): return cls.elem_types[args[0]] if isinstance(args, tuple) and len(args) == 1 and isinstance(args[0], tuple): args = args[0] name = None elem_types = [] for a in args: if isinstance(a, str): name = a elif isinstance(a, type) and issubclass(a, TensorType): elem_types.append(a) elif a in (int, float, str): elem_types.append(a) else: raise TypeError( f"Unexpected value type={type(a)}, value={a} in {args}." ) msg = [] if name: msg.append(name) for t in elem_types: msg.append(t.__name__) final = "_".join(msg) newt = type( f"{cls.__name__}_{final}", (cls,), dict(name=name, elem_types=tuple(elem_types)), ) if "<" in newt.__name__: raise NameError(f"Name is wrong {newt.__name__!r}.") return newt
[docs] @classmethod def len(cls): "Returns the number of types." return len(cls.elem_types)
[docs] @classmethod def type_name(cls) -> str: "Returns its full name." dts = ", ".join(map(lambda s: s.type_name(), cls.elem_types)) if cls.name: newt = f"TupleType[{dts}, {cls.name!r}]" else: newt = f"TupleType[{dts}]" if "<" in newt or "{" in newt: raise NameError(f"Name is wrong {newt!r}.") return newt
def _make_type(name: str, elem_type: int): def class_getitem(cls, shape: Union[int, ShapeType]) -> TensorType: if isinstance(shape, int): shape = (shape,) return TensorType[elem_type, shape] new_type = type(name, tuple(), {}) new_type.__class_getitem__ = classmethod(class_getitem) return new_type Bool = _make_type("Bool", ElemType.bool_) BFloat16 = _make_type("BFloat16", ElemType.bfloat16) Float16 = _make_type("Float16", ElemType.float16) Float32 = _make_type("Float32", ElemType.float32) Float64 = _make_type("Float32", ElemType.float64) Int8 = _make_type("int8", ElemType.int8) Int16 = _make_type("int16", ElemType.int16) Int32 = _make_type("int32", ElemType.int32) Int64 = _make_type("int64", ElemType.int64) UInt8 = _make_type("uint8", ElemType.uint8) UInt16 = _make_type("uint16", ElemType.uint16) UInt32 = _make_type("uint32", ElemType.uint32) UInt64 = _make_type("uint64", ElemType.uint64)