# Source code for onnx_array_api.npx.npx_core_api

from inspect import _empty, signature
from typing import Any, Callable, Dict, Sequence, Union

import numpy as np
from onnx import FunctionProto, ModelProto, NodeProto

from .npx_tensors import EagerTensor
from .npx_types import ElemType, OptParType, ParType, TupleType
from .npx_var import Cst, Input, ManyIdentity, Par, Var


def cst(*args, **kwargs):
    """
    Creates a constant by forwarding every positional and keyword
    argument, unchanged, to the :class:`Cst` constructor.
    """
    constant = Cst(*args, **kwargs)
    return constant
def tuple_var(*args: Sequence[Var]) -> Var:
    """
    Groups several results into a single :class:`ManyIdentity` node so
    that a function can return all of them at once.

    :param args: the variables to tie together
    :return: the :class:`ManyIdentity` built from *args*
    """
    grouped = ManyIdentity(*args)
    return grouped
def make_tuple(
    n_elements_or_first_variable: Union[int, Var],
    *args: Sequence[Var],
    **kwargs: Dict[str, Any],
) -> Var:
    """
    Builds a result made of several variables.

    * If *n_elements_or_first_variable* is an integer, it is the number of
      outputs: the call is forwarded to :class:`Var` with
      ``n_var_outputs`` set to that integer, together with *args* and
      *kwargs*.
    * Otherwise it is the first variable of the tuple: it and *args* are
      tied together with :func:`tuple_var`.

    :param n_elements_or_first_variable: number of outputs or first variable
    :param args: remaining variables
    :param kwargs: extra arguments for :class:`Var` (integer case only)
    :return: the resulting :class:`Var`
    :raises TypeError: if keyword arguments are given when the first
        argument is a variable, :func:`tuple_var` only accepts
        positional variables
    """
    if isinstance(n_elements_or_first_variable, int):
        return Var(*args, n_var_outputs=n_elements_or_first_variable, **kwargs)
    # tuple_var has no keyword parameters; forwarding kwargs to it used to
    # fail with an obscure "unexpected keyword argument" error.
    if kwargs:
        raise TypeError(
            f"make_tuple does not accept keyword arguments when the first "
            f"argument is a variable, got {sorted(kwargs)}."
        )
    return tuple_var(n_elements_or_first_variable, *args)
def var(*args: Sequence[Var], **kwargs: Dict[str, Any]) -> Var:
    """
    Shortcut creating a :class:`Var`: every positional and keyword
    argument is handed to its constructor unchanged.
    """
    variable = Var(*args, **kwargs)
    return variable
def _process_parameter(fn, sig, k, v, new_pars, inline):
    """
    Converts one named (non-tensor) argument of a decorated function into
    an entry of *new_pars*.

    :param fn: decorated function, used to build the ``parent_op`` of the
        :class:`Par` instances created here
    :param sig: signature of *fn*
    :param k: parameter name
    :param v: value given by the caller
    :param new_pars: dictionary updated in place (``new_pars[k]``)
    :param inline: if True, plain values are stored, otherwise they are
        wrapped into :class:`Par`
    :raises NotImplementedError: if *v* is None, *k* is not in the
        signature and no parameter was processed yet (ambiguous between
        an optional input and a parameter)
    :raises TypeError: if *v* is a tensor or has an unexpected type
    """
    annotation = sig.parameters[k].annotation if k in sig.parameters else None
    if v is None and len(new_pars) == 0 and annotation is None:
        # It could be an optional input or a parameter.
        raise NotImplementedError(
            f"Unable to decide between an optional input or a "
            f"parameter for name={k!r}."
        )
    if isinstance(v, Par):
        if inline:
            new_pars[k] = v.value
        else:
            new_pars[k] = v
        return
    if isinstance(v, type) and k == "dtype":
        vto = ElemType.numpy_map[v]
        if inline:
            new_pars[k] = vto
        else:
            new_pars[k] = Par(
                k,
                dtype=ParType[int],
                value=vto,
                parent_op=(fn.__module__, fn.__name__, 0),
            )
        return
    if isinstance(v, (int, float, str, tuple)):
        if inline:
            new_pars[k] = v
        else:
            new_pars[k] = Par(
                k,
                dtype=ParType[type(v)],
                value=v,
                parent_op=(fn.__module__, fn.__name__, 0),
            )
        return
    if isinstance(v, (Cst, Var)):
        raise TypeError(
            f"Parameter {k!r} is a tensor ({type(v)}), it is not "
            f"supported for a named parameter."
        )
    if isinstance(v, (FunctionProto, NodeProto, ModelProto)):
        new_pars[k] = v
        return
    # annotation is None when k is not in the signature, and may be a
    # typing construct: issubclass only accepts classes, so guard first
    # and fall through to the explicit error below otherwise.
    if (
        v is None
        and isinstance(annotation, type)
        and issubclass(annotation, OptParType)
    ):
        return
    raise TypeError(
        f"Unexpected type for parameter {k!r}, type={type(v)}, "
        f"annotation={annotation}."
    )


def _signature_rst(sig) -> str:
    """
    Renders *sig* as the reStructuredText ``Signature:`` section appended
    to the docstring of decorated functions.

    :param sig: signature of the decorated function
    :return: the rendered section
    :raises TypeError: if an annotation has ``__args__`` but is not of the
        form ``Optional[X]``
    :raises AttributeError: if an annotation has no ``type_name()`` method
    """
    rows = ["", "", "Signature:", "", "::", "", " ("]
    for p in sig.parameters.values():
        if p.annotation == _empty:
            rows.append(f" {p.name},")
            continue
        annot = p.annotation
        if hasattr(annot, "__args__"):
            args = annot.__args__
            # Only Optional[X] (i.e. Union[X, None]) is supported; X is shown.
            if not (
                isinstance(args, tuple)
                and len(args) == 2
                and isinstance(None, args[1])
            ):
                raise TypeError(
                    f"Unable to interpret annotation for parameter "
                    f"{p.name!r} with {p.annotation} and args={args}."
                )
            annot = args[0]
        try:
            a_name = annot.type_name()
        except AttributeError as e:
            raise AttributeError(
                f"Unexpected annotation type {p.annotation!r}."
            ) from e
        rows.append(f" {p.name}: {a_name},")
    if sig.return_annotation == _empty:
        rows.append(" ):")
    else:
        rows.append(f" ) -> {sig.return_annotation.type_name()}:")
    return "\n".join(rows)


def _xapi(fn: Callable, inline: bool):
    """
    Decorator to use before any function using part of the numpy API.
    The function inspects the input and decides which version of the
    function to call.

    If any input is an :class:`EagerTensor`, an eager implementation
    (built once per tensor class) is executed. Otherwise a :class:`Var`
    with ``op=fn`` is returned.

    :param fn: function
    :param inline: inline the function instead of creating a function
    :return: the wrapper; its docstring is *fn*'s docstring followed by a
        ``Signature:`` section
    """
    sig = signature(fn)
    # Eager implementations, built lazily, one per EagerTensor subclass.
    eager_onnx_tensor_classes = {}

    # Positional index -> parameter name. It only depends on the signature,
    # so it is computed once here. Positional parameters always come first
    # in a signature, so stopping at the first other kind is enough.
    parnames = {}
    for name, par in sig.parameters.items():
        if par.kind not in (par.POSITIONAL_ONLY, par.POSITIONAL_OR_KEYWORD):
            break
        parnames[len(parnames)] = name

    # It has the same signature
    def wrapper(*inputs, **kwargs):
        tensor_class = next(
            (x.__class__ for x in inputs if isinstance(x, EagerTensor)), None
        )
        if tensor_class is not None:
            if tensor_class not in eager_onnx_tensor_classes:
                from .npx_jit_eager import eager_onnx

                eager_onnx_tensor_classes[tensor_class] = eager_onnx(
                    fn, tensor_class
                )
            eag = eager_onnx_tensor_classes[tensor_class]
            res = eag(*inputs, already_eager=True, **kwargs)
            if not isinstance(res, tuple):
                raise TypeError(f"Return of the eager must be a tuple not {type(res)}.")
            return res if len(res) > 1 else res[0]

        # conversion to onnx
        new_inputs = []
        new_pars = {}
        last_input = -1
        for ind, i in enumerate(inputs):
            annotation = (
                sig.parameters[parnames[ind]].annotation if ind in parnames else None
            )
            if (
                annotation is not None
                and isinstance(annotation, type)
                and issubclass(annotation, ParType)
            ):
                # no more inputs: the remaining positional values are parameters
                break
            last_input = ind
            if isinstance(i, (Var, np.ndarray)):
                new_inputs.append(i)
            elif isinstance(i, (int, float)):
                new_inputs.append(
                    np.array([i], dtype=np.int64 if isinstance(i, int) else np.float32)
                )
            elif isinstance(i, str):
                new_inputs.append(Input(i))
            elif i is None:
                # optional input
                new_inputs.append(None)
            else:
                raise TypeError(
                    f"Unexpected type for input {ind}, type={type(i)}. "
                    f"Did you forget to wrap the constant with 'cst(.)'?"
                )

        # Positional values following the first ParType parameter.
        for ind in range(last_input + 1, len(inputs)):
            if ind not in parnames:
                # previously surfaced as a bare KeyError
                raise TypeError(
                    f"{fn.__name__}() got too many positional arguments "
                    f"({len(inputs)} given)."
                )
            k = parnames[ind]
            if k in kwargs:
                break
            _process_parameter(fn, sig, k, inputs[ind], new_pars, inline)
        for k, v in kwargs.items():
            _process_parameter(fn, sig, k, v, new_pars, inline)

        if issubclass(sig.return_annotation, TupleType):
            n_var_outputs = sig.return_annotation.len()
            return Var(
                *new_inputs,
                op=fn,
                inline=inline,
                n_var_outputs=n_var_outputs,
                **new_pars,
            )
        return Var(*new_inputs, op=fn, inline=inline, **new_pars)

    wrapper.__doc__ = (fn.__doc__ or "") + "\n" + _signature_rst(sig)
    return wrapper
def npxapi_function(fn):
    """
    Decorator for functions written with the numpy-like API.

    Each call inspects the inputs. If one of them is an
    :class:`EagerTensor`, the function runs eagerly. Otherwise a
    :class:`Var` referencing *fn* is built.

    *fn* is kept as a separate function rather than being inlined
    (``inline=False``).
    """
    decorated = _xapi(fn, inline=False)
    return decorated
def npxapi_inline(fn):
    """
    Decorator for functions written with the numpy-like API.

    Each call inspects the inputs. If one of them is an
    :class:`EagerTensor`, the function runs eagerly. Otherwise a
    :class:`Var` referencing *fn* is built.

    *fn* is inlined instead of being kept as a separate function
    (``inline=True``).
    """
    decorated = _xapi(fn, inline=True)
    return decorated