Source code for onnx_array_api.ort.ort_tensors

from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
from onnx import ModelProto, TensorProto
from onnx.defs import onnx_opset_version
from onnxruntime import InferenceSession, RunOptions, get_available_providers
from onnxruntime.capi._pybind_state import OrtDevice as C_OrtDevice
from onnxruntime.capi._pybind_state import OrtMemType
from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument

from ..npx.npx_tensors import EagerTensor, JitTensor
from ..npx.npx_types import TensorType


class OrtTensor:
    """
    Default backend based on :class:`onnxruntime.InferenceSession`.
    Data is not copied.

    :param tensor: value to wrap, either a :class:`C_OrtValue` or another
        :class:`OrtTensor` (in which case the underlying value is shared)
    """

    CPU = C_OrtDevice(C_OrtDevice.cpu(), OrtMemType.DEFAULT, 0)
    CUDA0 = C_OrtDevice(C_OrtDevice.cuda(), OrtMemType.DEFAULT, 0)
    # Preference order is CUDA first, then CPU; only the providers reported
    # by onnxruntime as available are kept.
    providers = [
        c
        for c in ["CUDAExecutionProvider", "CPUExecutionProvider"]
        if c in get_available_providers()
    ]

    @staticmethod
    def from_array(
        value: np.ndarray, device: Optional[C_OrtDevice] = None
    ) -> "OrtTensor":
        """
        Creates an instance of :class:`OrtTensor` from a numpy array.
        Relies on `ortvalue_from_numpy`.
        A copy of the data in the Numpy object is held by the
        :class:`C_OrtValue` only if the device is **not cpu**.
        Any expression such as `from_array(x.copy())`, or
        `from_array(x.astype(np.float32))`, ... creates an intermediate
        variable scheduled to be deleted by the garbage collector
        as soon as the function returns. In that case, the buffer
        holding the values is deleted and the instance `OrtTensor`
        is no longer equal to the original value:
        `assert_allclose(value, tensor.numpy())` is false.
        `value` must remain alive as long as the `OrtTensor` is.

        :param value: value
        :param device: CPU, GPU, value such as `OrtTensor.CPU`,
            `OrtTensor.CUDA0`, defaults to `OrtTensor.CPU` when None
        :return: instance of :class:`OrtTensor`
        """
        if device is None:
            device = OrtTensor.CPU
        return OrtTensor(C_OrtValue.ortvalue_from_numpy(value, device))

    def numpy(self) -> np.ndarray:
        """
        Converts the :class:`OrtValue` into numpy array.
        """
        return self._tensor.numpy()

    class Evaluator:
        """
        Wraps class :class:`onnxruntime.InferenceSession`
        to have a signature closer to python function.

        :param tensor_class: tensor class, its attribute `providers` selects
            the execution providers and it is used to wrap the outputs
            when :meth:`run` receives no input
        :param input_names: input names, in the order expected by :meth:`run`
        :param onx: onnx model
        """

        def __init__(self, tensor_class: type, input_names: List[str], onx: ModelProto):
            try:
                self.ref = InferenceSession(
                    onx.SerializeToString(),
                    providers=tensor_class.providers,
                )
            except InvalidArgument as e:
                single_untyped_output = (
                    len(onx.graph.output) == 1
                    and onx.graph.output[0].type.tensor_type.elem_type
                    == TensorProto.UNDEFINED
                )
                # The fallback copies the type of the first input, so it only
                # applies when the graph actually declares one; otherwise the
                # original error is reported instead of an IndexError.
                if single_untyped_output and len(onx.graph.input) > 0:
                    # ShapeInference cannot use python function for unknown node type.
                    # Let's give the only output the same type as the first
                    # input.
                    # NOTE: this modifies the model given by the caller.
                    onx.graph.output[0].type.tensor_type.elem_type = onx.graph.input[
                        0
                    ].type.tensor_type.elem_type
                    self.ref = InferenceSession(
                        onx.SerializeToString(),
                        providers=tensor_class.providers,
                    )
                elif len(onx.graph.node) <= 3:
                    # Small models are included in the message to ease debugging.
                    raise RuntimeError(
                        f"Unable to create an InferenceSession with model {onx}."
                    ) from e
                else:
                    raise
            self.input_names = input_names
            self.tensor_class = tensor_class
            self.output_names = [output.name for output in self.ref._outputs_meta]
            self.run_options = RunOptions()

        def run(self, *inputs: "OrtTensor") -> List["OrtTensor"]:
            """
            Executes the function.

            :param inputs: function inputs, one tensor per input name
            :return: outputs, wrapped with the class of the first input,
                or with `tensor_class` if the function has no input
            :raises ValueError: if the number of inputs differs from
                the number of input names
            """
            if len(inputs) != len(self.input_names):
                raise ValueError(
                    f"Expected {len(self.input_names)} inputs but got "
                    f"len(inputs)={len(inputs)}."
                )
            feeds = {name: inp.value for name, inp in zip(self.input_names, inputs)}
            res = self.ref._sess.run_with_ort_values(
                feeds, self.output_names, self.run_options
            )
            # A model without any input has no tensor to take the class from.
            wrapper = inputs[0].__class__ if inputs else self.tensor_class
            return [wrapper(r) for r in res]

    def __init__(self, tensor: Union[C_OrtValue, "OrtTensor"]):
        if isinstance(tensor, C_OrtValue):
            self._tensor = tensor
        elif isinstance(tensor, OrtTensor):
            # Shares the underlying value, nothing is copied.
            self._tensor = tensor._tensor
        else:
            raise ValueError(f"An OrtValue is expected not {type(tensor)}.")

    @property
    def shape(self) -> Tuple[int, ...]:
        "Returns the shape of the tensor."
        return self._tensor.shape()

    @property
    def dtype(self) -> Any:
        "Returns the element type of this tensor."
        return self._tensor.element_type()

    @property
    def key(self) -> Any:
        "Unique key for a tensor of the same type: element type and rank."
        return (self.dtype, len(self.shape))

    @property
    def value(self) -> C_OrtValue:
        "Returns the underlying :class:`C_OrtValue` wrapped by this tensor."
        return self._tensor

    @property
    def tensor_type(self) -> TensorType:
        "Returns the tensor type of this tensor."
        return TensorType[self.dtype]

    @property
    def dims(self):
        """
        Returns the dimensions of the tensor.
        First dimension is the batch dimension if the tensor
        has more than one dimension. It is replaced by None in that case.
        A tensor with an empty shape returns `(0,)`.
        """
        shape = self.shape
        if len(shape) == 0:
            return (0,)
        if len(shape) == 1:
            return tuple(shape)
        return (None, *tuple(shape[1:]))

    @property
    def tensor_type_dims(self) -> TensorType:
        """
        Returns the tensor type of this tensor.
        This property is used to define a key used to cache a jitted function.
        Same keys means same ONNX graph.
        Different keys usually means same ONNX graph but different
        input shapes.
        """
        return TensorType[self.dtype, self.dims]

    @classmethod
    def create_function(cls: Any, input_names: List[str], onx: ModelProto) -> Callable:
        """
        Creates an evaluator calling the onnx backend used by this class.
        The returned object is an instance of `cls.Evaluator`,
        the model is executed through its method `run`.

        :param input_names: input names
        :param onx: onnx model
        :return: instance of `cls.Evaluator`
        """
        return cls.Evaluator(cls, input_names, onx)
class OrtCommon:
    """
    Helpers shared by the jit and the eager tensor classes.
    """

    @classmethod
    def get_opsets(cls, opsets):
        """
        Returns the opsets completed with domain `com.microsoft`.

        :param opsets: mapping domain -> version or None
        :return: the default opsets if *opsets* is None, *opsets* itself
            if it already defines `com.microsoft`, otherwise a copy
            with `com.microsoft` set to 1
        """
        if opsets is None:
            # Main domain is capped at 18.
            return {"": min(onnx_opset_version(), 18), "com.microsoft": 1}
        if "com.microsoft" not in opsets:
            # The caller's mapping is left untouched.
            completed = opsets.copy()
            completed["com.microsoft"] = 1
            return completed
        return opsets

    @classmethod
    def get_ir_version(cls, ir_version):
        """
        Returns the IR version to use, never above 8.

        :param ir_version: requested version or None
        :return: 8 if *ir_version* is None, else `min(ir_version, 8)`
        """
        return 8 if ir_version is None else min(ir_version, 8)
class EagerOrtTensor(OrtTensor, OrtCommon, EagerTensor):
    """
    Eager tensor for this backend, it combines :class:`OrtTensor`,
    :class:`OrtCommon` and :class:`EagerTensor` without adding any member.
    """
class JitOrtTensor(OrtTensor, OrtCommon, JitTensor):
    """
    Jit tensor for this backend, it combines :class:`OrtTensor`,
    :class:`OrtCommon` and :class:`JitTensor` without adding any member.
    """