Source code for onnx_array_api.npx.npx_numpy_tensors

from typing import Any, Callable, List, Tuple

import numpy as np
from onnx import ModelProto
from onnx.reference import ReferenceEvaluator

from .npx_tensors import EagerTensor, JitTensor
from .npx_types import TensorType


[docs]class NumpyTensor: """ Default backend based on :func:`onnx.reference.ReferenceEvaluator`. :param input_names: input names :param onx: onnx model """
[docs] class Evaluator: """ Wraps class :class:`onnx.reference.ReferenceEvaluator` to have a signature closer to python function. """ def __init__(self, tensor_class: type, input_names: List[str], onx: ModelProto): self.ref = ReferenceEvaluator(onx) self.input_names = input_names self.tensor_class = tensor_class
[docs] def run(self, *inputs: List["NumpyTensor"]) -> List["NumpyTensor"]: """ Executes the function. :param inputs: function inputs :return: outputs """ if len(inputs) != len(self.input_names): raise ValueError( f"Expected {len(self.input_names)} inputs but got {len(inputs)}, " f"self.input_names={self.input_names}, inputs={inputs}." ) feeds = {} for name, inp in zip(self.input_names, inputs): feeds[name] = inp.value res = self.ref.run(None, feeds) return list(map(self.tensor_class, res))
def __init__(self, tensor: np.ndarray): if isinstance(tensor, np.ndarray): self._tensor = tensor elif isinstance(tensor, NumpyTensor): self._tensor = tensor._tensor elif isinstance( tensor, ( np.int64, np.float32, np.float64, np.int32, np.float16, np.int8, np.int16, np.uint8, np.uint16, np.uint32, np.uint64, ), ): self._tensor = np.array(tensor) else: raise TypeError(f"A numpy array is expected not {type(tensor)}.") def __repr__(self) -> str: "usual" return f"{self.__class__.__name__}({self._tensor!r})"
[docs] def numpy(self): "Returns the array converted into a numpy array." return self._tensor
@property def dtype(self) -> Any: "Returns the element type of this tensor." return self._tensor.dtype @property def key(self) -> Any: "Unique key for a tensor of the same type." return (self.dtype, len(self._tensor.shape)) @property def value(self) -> np.ndarray: "Returns the value of this tensor as a numpy array." return self._tensor @property def tensor_type(self) -> TensorType: "Returns the tensor type of this tensor." return TensorType[self.dtype] @property def dims(self): """ Returns the dimensions of the tensor. First dimension is the batch dimension if the tensor has more than one dimension. """ if len(self._tensor.shape) == 0: return (0,) if len(self._tensor.shape) == 1: return self._tensor.shape return (None,) + self._tensor.shape[1:] @property def ndim(self): "Returns the number of dimensions (rank)." return len(self.shape) @property def shape(self) -> Tuple[int, ...]: "Returns the shape of the tensor." return self._tensor.shape @property def tensor_type_dims(self) -> TensorType: """ Returns the tensor type of this tensor. This property is used to define a key used to cache a jitted function. Same keys keys means same ONNX graph. Different keys usually means same ONNX graph but different input shapes. """ return TensorType[self.dtype, self.dims]
[docs] @classmethod def create_function(cls: Any, input_names: List[str], onx: ModelProto) -> Callable: """ Creates a python function calling the onnx backend used by this class. :param onx: onnx model :return: python function """ return cls.Evaluator(cls, input_names, onx)
[docs] @classmethod def get_opsets(cls, opsets): """ Updates the opsets for a given backend. This method should be overloaded. By default, it returns opsets. """ return opsets
[docs] @classmethod def get_ir_version(cls, ir_version): """ Updates the IR version. This method should be overloaded. By default, it returns ir_version. """ return ir_version
[docs] def const_cast(self, to: Any = None) -> "EagerTensor": """ Casts a constant without any ONNX conversion. """ return self.__class__(self._tensor.astype(to))
# The class should support whatever Var supports. # This part is not yet complete.
[docs]class EagerNumpyTensor(NumpyTensor, EagerTensor): """ Defines a value for a specific backend. """ pass
[docs]class JitNumpyTensor(NumpyTensor, JitTensor): """ Defines a value for a specific backend. """ pass