# Source code for the onnx_array_api.ext_test_case module.

import functools
import sys
import unittest
import warnings
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from timeit import Timer
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

import numpy
from numpy.testing import assert_allclose


def ignore_warnings(
    warns: Union[Type[Warning], Sequence[Type[Warning]]],
) -> Callable:
    """
    Decorator ignoring some warning categories while the decorated
    test method runs.

    :param warns: one warning category (e.g. ``DeprecationWarning``) or a
        list/tuple of categories to ignore
    :return: decorator
    :raises AssertionError: if *warns* is None (raised when the decorator
        is applied to a function)

    .. code-block:: python

        class TestSomething(ExtTestCase):
            @ignore_warnings(DeprecationWarning)
            def test_old_api(self):
                ...
    """

    def wrapper(fct):
        if warns is None:
            raise AssertionError(f"warns cannot be None for '{fct}'.")
        # warnings.simplefilter takes a single category class (a list is not
        # a valid issubclass target), so one filter is registered per
        # category; a single class is accepted as well.
        categories = (warns,) if isinstance(warns, type) else tuple(warns)

        @functools.wraps(fct)
        def call_f(self, *args, **kwargs):
            # catch_warnings restores the previous filters on exit, so the
            # "ignore" filters only apply while fct runs.
            with warnings.catch_warnings():
                for category in categories:
                    warnings.simplefilter("ignore", category)
                return fct(self, *args, **kwargs)

        return call_f

    return wrapper
def measure_time(
    stmt: Union[Callable, str],
    context: Optional[Dict[str, Any]] = None,
    repeat: int = 10,
    number: int = 50,
    div_by_number: bool = True,
    max_time: Optional[float] = None,
) -> Dict[str, Any]:
    """
    Measures a statement and returns the results as a dictionary.

    :param stmt: callable or string to measure (a string is executed
        with *context* as its globals)
    :param context: variables to know, stored in a dictionary
    :param repeat: average over *repeat* experiments
    :param number: number of executions in one row
    :param div_by_number: divide by the number of executions
    :param max_time: execute the statement until the total goes
        beyond this time (approximately), *repeat* is ignored,
        *div_by_number* must be set to True
    :return: dictionary with keys ``average``, ``deviation``,
        ``min_exec``, ``max_exec``, ``repeat``, ``number``, ``ttime``
        and either ``size`` (first dimension or length of
        ``context["values"]``) or ``context_size``
        (``sys.getsizeof(context)``)
    :raises TypeError: if *stmt* is neither callable nor a string
    :raises ValueError: if *max_time* is set and *div_by_number* is False

    .. runpython::
        :showcode:

        from onnx_array_api.ext_test_case import measure_time
        from math import cos

        res = measure_time(lambda: cos(0.5))
        print(res)

    See `Timer.repeat <https://docs.python.org/3/library/timeit.html#timeit.Timer.repeat>`_
    for a better understanding of parameter *repeat* and *number*.
    When *div_by_number* is False, the durations correspond to *number*
    executions of the statement; when True (default), to one execution.

    .. versionchanged:: 0.4
        Parameter *max_time* was added.
    """
    if not callable(stmt) and not isinstance(stmt, str):
        raise TypeError(
            f"stmt is not callable or a string but is of type {type(stmt)!r}."
        )
    if context is None:
        context = {}

    if isinstance(stmt, str):
        tim = Timer(stmt, globals=context)
    else:
        tim = Timer(stmt)

    if max_time is not None:
        if not div_by_number:
            raise ValueError(
                "div_by_number must be set to True if max_time is defined."
            )
        mes = _measure_time_budget(tim, max_time)
    else:
        mes = _measure_time_repeat(tim, repeat, number, div_by_number)

    if "values" in context:
        values = context["values"]
        mes["size"] = values.shape[0] if hasattr(values, "shape") else len(values)
    else:
        mes["context_size"] = sys.getsizeof(context)
    return mes


def _measure_time_budget(tim: Timer, max_time: float) -> Dict[str, Any]:
    """
    Times *tim* with batches of growing size until *max_time* seconds
    have been spent, returns per-execution statistics.
    """
    i = 1
    total_time = 0.0
    results = []
    # At least one batch is always timed so that results is never empty.
    while True:
        for j in (1, 2):
            n_exec = i * j
            time_taken = tim.timeit(n_exec)
            results.append((n_exec, time_taken))
            total_time += time_taken
            if total_time >= max_time:
                break
        if total_time >= max_time:
            break
        # Grow the batch size with the remaining budget, never shrink it.
        # A null elapsed time (coarse clock) would divide by zero: double.
        if total_time > 0:
            ratio = max((max_time - total_time) / total_time, 1)
        else:
            ratio = 2
        i = int(i * ratio)

    res = numpy.array(results)
    tw = res[:, 0].sum()
    ttime = res[:, 1].sum()
    mean = ttime / tw
    ave = res[:, 1] / res[:, 0]
    # Deviation of per-execution times weighted by the batch sizes.
    dev = (((ave - mean) ** 2 * res[:, 0]).sum() / tw) ** 0.5
    return dict(
        average=mean,
        deviation=dev,
        min_exec=numpy.min(ave),
        max_exec=numpy.max(ave),
        repeat=1,
        number=tw,
        ttime=ttime,
    )


def _measure_time_repeat(
    tim: Timer, repeat: int, number: int, div_by_number: bool
) -> Dict[str, Any]:
    """
    Times *repeat* batches of *number* executions of *tim*, returns the
    statistics over the batches (per execution if *div_by_number*).
    """
    res = numpy.array(tim.repeat(repeat=repeat, number=number))
    if div_by_number:
        res /= number
    mean = numpy.mean(res)
    # numpy.std (ddof=0) is the same population deviation as
    # sqrt(mean(res**2) - mean**2) without the cancellation that can make
    # the difference slightly negative and the result NaN.
    dev = numpy.std(res)
    return dict(
        average=mean,
        deviation=dev,
        min_exec=numpy.min(res),
        max_exec=numpy.max(res),
        repeat=repeat,
        number=number,
        ttime=res.sum(),
    )
class ExtTestCase(unittest.TestCase):
    """
    :class:`unittest.TestCase` with extra assertions for numpy arrays,
    exceptions, containers and strings, plus a helper capturing the
    standard output and error of a function.
    """

    # Triples (name, line, warning) re-emitted by tearDownClass.
    # NOTE(review): class attribute created once, so it is shared by every
    # subclass not overriding it; nothing here appends to it - presumably
    # subclasses do, confirm.
    _warns = []

    def assertEqualArray(
        self,
        expected: numpy.ndarray,
        value: numpy.ndarray,
        atol: float = 0,
        rtol: float = 0,
    ):
        """
        Checks two arrays have the same dtype, the same shape and close values.

        :param expected: expected array
        :param value: array to check
        :param atol: absolute tolerance
        :param rtol: relative tolerance (applied to *value*, the ``desired``
            argument of :func:`numpy.testing.assert_allclose`)
        :raises AssertionError: if dtypes, shapes or values differ
        """
        self.assertEqual(expected.dtype, value.dtype)
        self.assertEqual(expected.shape, value.shape)
        assert_allclose(expected, value, atol=atol, rtol=rtol)

    def assertRaise(self, fct: Callable, exc_type: Type[BaseException]):
        """
        Checks that calling *fct* without arguments raises *exc_type*.
        Exceptions of any other type propagate unchanged.

        :param fct: function to call
        :param exc_type: expected exception class
        :raises AssertionError: if *fct* returns without raising
        """
        try:
            fct()
        except exc_type:
            return
        raise AssertionError("No exception was raised.")

    def assertEmpty(self, value: Any):
        """
        Checks *value* is None or has a length of zero.

        :raises AssertionError: otherwise
        """
        if value is None:
            return
        if len(value) == 0:
            return
        raise AssertionError(f"value is not empty: {value!r}.")

    def assertNotEmpty(self, value: Any):
        """
        Checks *value* is not None and, when it has a length (list, string,
        array, ...), that this length is not zero. Objects without a length
        (numbers, ...) only need to be different from None.

        :raises AssertionError: if *value* is None or empty
        """
        if value is None:
            raise AssertionError(f"value is empty: {value!r}.")
        try:
            size = len(value)
        except TypeError:
            # No usable length: being not None is enough.
            return
        if size == 0:
            raise AssertionError(f"value is empty: {value!r}.")

    def assertStartsWith(self, prefix: str, full: str):
        """
        Checks string *full* starts with *prefix*.

        :raises AssertionError: otherwise
        """
        if not full.startswith(prefix):
            raise AssertionError(f"prefix={prefix!r} does not start string {full!r}.")

    @classmethod
    def tearDownClass(cls):
        """Re-emits every warning stored in ``cls._warns`` as a warning."""
        for name, line, w in cls._warns:
            warnings.warn(f"\n{name}:{line}: {type(w)}\n {str(w)}")
        super().tearDownClass()

    def capture(self, fct: Callable) -> Tuple[Any, str, str]:
        """
        Runs a function and captures its standard output and error.

        :param fct: function to run, called without arguments
        :return: result of *fct*, captured output, captured error
        """
        sout = StringIO()
        serr = StringIO()
        with redirect_stdout(sout), redirect_stderr(serr):
            res = fct()
        return res, sout.getvalue(), serr.getvalue()