"""
Tools to test models from the :epkg:`ONNX Zoo`.

.. versionadded:: 0.6
"""

import os
import urllib.request
from collections import OrderedDict
import numpy
from onnx import TensorProto, numpy_helper

def short_list_zoo_models():
    """
    Returns a short list from :epkg:`ONNX Zoo`.

    :return: list of dictionaries, every dictionary has keys ``name``
        and ``model``, some of them also have key ``folder``

    .. runpython::
        :showcode:

        import pprint
        # NOTE(review): the module path was missing from the extracted
        # source, ``mlprodict.tools.zoo`` is an assumption -- TODO confirm.
        from mlprodict.tools.zoo import short_list_zoo_models
        pprint.pprint(short_list_zoo_models())
    """
    # NOTE(review): in the extracted source every ``model`` value starts with
    # an empty string literal, which suggests the base URL of the model
    # repository was lost. It is kept empty here so the returned values are
    # unchanged -- restore the real prefix in this single place.
    base = ""
    return [
        dict(name="mobilenet",
             model=base + (
                 "classification/mobilenet/model/mobilenetv2-7.tar.gz")),
        dict(name="resnet18",
             model=base + (
                 "classification/resnet/model/resnet18-v1-7.tar.gz")),
        dict(name="squeezenet",
             model=base + (
                 "classification/squeezenet/model/squeezenet1.0-9.tar.gz"),
             folder="squeezenet"),
        dict(name="densenet121",
             model=base + (
                 "classification/densenet-121/model/densenet-9.tar.gz"),
             folder="densenet121"),
        dict(name="inception2",
             model=base + (
                 "classification/inception_and_googlenet/inception_v2/"
                 "model/inception-v2-9.tar.gz")),
        dict(name="shufflenet",
             model=base + (
                 "classification/shufflenet/model/shufflenet-9.tar.gz")),
        dict(name="efficientnet-lite4",
             model=base + (
                 "classification/efficientnet-lite4/model/"
                 "efficientnet-lite4-11.tar.gz")),
    ]
def _download_url(url, output_path, name, verbose=False):
    """
    Downloads *url* and stores the content into file *output_path*.

    :param url: url to download
    :param output_path: local file receiving the downloaded content
    :param name: label displayed next to the progress bar
    :param verbose: if True, displays a progress bar
        (``tqdm`` is only imported in that case)
    """
    if not verbose:
        urllib.request.urlretrieve(url, filename=output_path)
        return

    # Imported here so that tqdm is only required when a progress bar
    # is requested.
    from tqdm import tqdm

    class DownloadProgressBar(tqdm):
        "progress bar hook"

        def update_to(self, b=1, bsize=1, tsize=None):
            """
            Hook called by ``urlretrieve`` with the number of blocks
            transferred so far (*b*), the block size (*bsize*)
            and the total size (*tsize*).
            """
            if tsize is not None:
                self.total = tsize
            # update expects an increment, self.n is what was already
            # reported.
            self.update(b * bsize - self.n)

    with DownloadProgressBar(unit='B', unit_scale=True,
                             miniters=1, desc=name) as t:
        urllib.request.urlretrieve(
            url, filename=output_path, reporthook=t.update_to)
def load_data(folder):
    """
    Restores protobuf data stored in a folder.

    Only files with extension ``.pb`` are considered, each of them is
    parsed as a ``TensorProto`` and converted into an array.

    :param folder: folder
    :return: dictionary with two keys, ``'in'`` for the files whose name
        starts with ``input``, ``'out'`` for those whose name starts with
        ``output``, both are ordered dictionaries
        ``{file name without extension: array}``
    :raises ValueError: if the name of a ``.pb`` file starts with
        neither ``input`` nor ``output``
    """
    res = OrderedDict()
    res['in'] = OrderedDict()
    res['out'] = OrderedDict()
    # os.listdir returns names in arbitrary order and verify_model pairs
    # these arrays with the model inputs by position, sorting makes the
    # order reproducible.
    for name in sorted(os.listdir(folder)):
        noext, ext = os.path.splitext(name)
        if ext != '.pb':
            continue
        data = TensorProto()
        with open(os.path.join(folder, name), 'rb') as f:
            data.ParseFromString(f.read())
        if noext.startswith('input'):
            res['in'][noext] = numpy_helper.to_array(data)
        elif noext.startswith('output'):
            res['out'][noext] = numpy_helper.to_array(data)
        else:
            raise ValueError(  # pragma: no cover
                "Unable to guess anything about %r." % noext)
    return res
def download_model_data(name, model=None, cache=None, verbose=False):
    """
    Downloads a model and returns a link to the local :epkg:`ONNX` file
    and data which can be used as inputs.

    :param name: model name (see :func:`short_list_zoo_models`)
    :param model: url or None to get the default value
        returned by :func:`short_list_zoo_models`
    :param cache: folder to cache the downloaded data,
        the current folder if None
    :param verbose: display a progress bar
    :return: local onnx file, input data (an ordered dictionary
        ``{subfolder name: result of load_data}``)
    :raises ValueError: if *model* is None and *name* is not part of the
        list returned by :func:`short_list_zoo_models`
    :raises RuntimeError: if the downloaded archive is smaller than 1 MiB,
        the file is removed in that case
    :raises FileNotFoundError: if the extracted folder cannot be identified
        without ambiguity or if it does not contain exactly one
        ``.onnx`` file
    """
    suggested_folder = None
    if model is None:
        for mod in short_list_zoo_models():
            if mod['name'] == name:
                model = mod['model']
                suggested_folder = mod.get('folder')
                break
        if model is None:
            raise ValueError(
                "Unable to find a default value for name=%r." % name)

    # downloads
    last_name = model.split('/')[-1]
    if cache is None:
        cache = os.path.abspath('.')
    dest = os.path.join(cache, last_name)
    if not os.path.exists(dest):
        _download_url(model, dest, name, verbose=verbose)
    size = os.stat(dest).st_size
    if size < 2 ** 20:  # pragma: no cover
        # too small to be a model, the file is removed so that
        # the next call tries to download it again
        os.remove(dest)
        raise RuntimeError(
            "Unable to download model from %r." % model)

    # .tar.gz -> .tar
    outtar = os.path.splitext(dest)[0]
    if not os.path.exists(outtar):
        from pyquickhelper.filehelper.compression_helper import (
            ungzip_files)
        ungzip_files(dest, unzip=False, where_to=cache, remove_space=False)

    # .tar -> folder
    onnx_file = os.path.splitext(outtar)[0]
    if not os.path.exists(onnx_file):
        from pyquickhelper.filehelper.compression_helper import (
            untar_files)
        untar_files(outtar, where_to=cache)

    # The name of the extracted folder is not known, a few candidates
    # derived from the archive name are tried.
    if suggested_folder is not None:
        fold_onnx = [suggested_folder]
        if not os.path.exists(suggested_folder):
            # The archive is extracted into *cache*, the folder is looked
            # for there when it is not found as given.
            fold_onnx.append(os.path.join(cache, suggested_folder))
    else:
        # Only the last component is split on '-', a hyphen in the cache
        # path must not alter the candidates.
        parent, stem = os.path.split(onnx_file)
        pieces = stem.split('-')
        short = '-'.join(pieces[:-1])
        guesses = (pieces[0], short, short.replace('-', '_'))
        fold_onnx = [onnx_file]
        fold_onnx.extend(os.path.join(parent, g) for g in guesses if g)
    # Several candidates may be the same path (e.g. ``shufflenet-9``),
    # duplicates are removed so that one existing folder counts once.
    fold_onnx = list(OrderedDict.fromkeys(fold_onnx))
    fold_onnx_ok = [_ for _ in fold_onnx if os.path.exists(_)]
    if len(fold_onnx_ok) != 1:
        raise FileNotFoundError(  # pragma: no cover
            "Unable to find an existing folder among %r." % fold_onnx)
    onnx_file = fold_onnx_ok[0]

    onnx_files = [_ for _ in os.listdir(onnx_file) if _.endswith(".onnx")]
    if len(onnx_files) != 1:
        raise FileNotFoundError(  # pragma: no cover
            "Unable to find any onnx file in %r." % onnx_files)
    final_onnx = os.path.join(onnx_file, onnx_files[0])

    # data: every subfolder is one example, sorted to get a stable order
    data = [_ for _ in sorted(os.listdir(onnx_file))
            if os.path.isdir(os.path.join(onnx_file, _))]
    examples = OrderedDict()
    for f in data:
        examples[f] = load_data(os.path.join(onnx_file, f))
    return final_onnx, examples
def verify_model(onnx_file, examples, runtime=None, abs_tol=5e-4,
                 verbose=0, fLOG=None):
    """
    Verifies a model.

    :param onnx_file: ONNX file
    :param examples: examples to verify, a mapping
        ``{example name: {'in': inputs, 'out': expected outputs}}``
        where inputs and expected outputs are mappings of arrays
    :param runtime: a runtime to use
    :param abs_tol: error tolerance when checking the output
    :param verbose: verbosity level for runtime other than
        `'onnxruntime'`
    :param fLOG: logging function when `verbose > 0`
    :return: errors for every sample, a list of dictionaries
        with keys ``name``, ``i``, ``abs``, ``rel``
    :raises RuntimeError: if the number of inputs or outputs does not
        match the model
    :raises ValueError: if an output has an unexpected shape or if the
        absolute error is above *abs_tol*
    """
    if runtime == 'onnxruntime':
        from onnxruntime import InferenceSession
        sess = InferenceSession(onnx_file)
        names = [p.name for p in sess.get_inputs()]
        # outputs are returned as a list, they are identified by position
        onames = list(range(len(sess.get_outputs())))

        def meth(data, s=sess):
            # None requests every output of the model
            return s.run(None, data)
    else:
        from ..onnxrt import OnnxInference
        sess = OnnxInference(onnx_file, runtime=runtime)
        names = sess.input_names
        onames = sess.output_names

        def meth(data, s=sess, ns=onames):
            # NOTE(review): the call was truncated in the extracted source,
            # ``run(data, verbose=..., fLOG=...)`` returning a mapping
            # indexed by output name is a reconstruction -- TODO confirm.
            r = s.run(data, verbose=verbose, fLOG=fLOG)
            return [r[n] for n in ns]

    rows = []
    for index, (name, data_inout) in enumerate(examples.items()):
        data = data_inout["in"]
        if len(data) != len(names):
            raise RuntimeError(
                "Mismathed number of inputs %d != %d\ninputs: %r\nmodel: %r."
                "" % (len(data), len(names), list(sorted(data)), names))
        # stored inputs are paired with the model inputs by position
        inputs = {n: data[v] for n, v in zip(names, data)}
        outputs = meth(inputs)
        expected = data_inout['out']
        if len(outputs) != len(onames):
            raise RuntimeError(
                "Number of outputs %d is != expected outputs %d." % (
                    len(outputs), len(onames)))
        for i, (output, expect) in enumerate(zip(outputs, expected.items())):
            if output.shape != expect[1].shape:
                raise ValueError(
                    "Shape mismatch got %r != expected %r." % (
                        output.shape, expect[1].shape))
            diff = numpy.abs(output - expect[1]).ravel()
            absolute = diff.max()
            # ratio between the largest and the median absolute error
            relative = absolute / numpy.median(diff) if absolute > 0 else 0.
            if absolute > abs_tol:
                raise ValueError(
                    "Example %d, inferred and expected resuls are different "
                    "for output %d: abs=%r rel=%r (runtime=%r)."
                    "" % (index, i, absolute, relative, runtime))
            rows.append(dict(name=name, i=i, abs=absolute, rel=relative))
    return rows