onnx-array-api: (Numpy) Array API for ONNX


onnx-array-api implements a numpy API for ONNX. It lets users write functions with the numpy API, convert those functions into ONNX, and execute them.

Sources are available on github/onnx-array-api; see also the code coverage report.

<<<

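import numpy as np
from onnx_array_api.npx import absolute, jit_onnx
from onnx_array_api.plotting.dot_plot import to_dot


def l1_loss(x, y):
    # sum of absolute differences
    return absolute(x - y).sum()


def l2_loss(x, y):
    # sum of squared differences
    return ((x - y) ** 2).sum()


def myloss(x, y):
    # L1 loss on the first column plus L2 loss on the second column
    return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])


# jit_onnx wraps myloss: when called, the function is converted
# into ONNX and the resulting ONNX graph is executed
jitted_myloss = jit_onnx(myloss)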

x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
res = jitted_myloss(x, y)
print(res)

>>>

    [0.042]
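As a sanity check, the same loss can be written with plain numpy functions, using `np.abs` in place of `absolute`. This is only a sketch for comparison, assuming the jitted ONNX version and the numpy version should agree up to float32 rounding; it does not use onnx-array-api at all. By hand: the L1 part is 0.01 + 0.03 = 0.04 and the L2 part is 0.0004 + 0.0016 = 0.002, hence 0.042.

```python
import numpy as np


def l1_loss(x, y):
    # sum of absolute differences, computed eagerly by numpy
    return np.abs(x - y).sum()


def l2_loss(x, y):
    # sum of squared differences, computed eagerly by numpy
    return ((x - y) ** 2).sum()


def myloss(x, y):
    # L1 loss on the first column plus L2 loss on the second column
    return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])


x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
expected = myloss(x, y)
print(expected)  # approximately 0.042 (float32 rounding may show extra digits)
```

With `res` from the jitted example, `np.testing.assert_allclose(res, expected, rtol=1e-5)` is one way to check that both versions match.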