Source code for pyquickhelper.pycode.unittestclass

"""
Overwrites unit test class with additional testing functions.


:githublink:`%|py|5`
"""
from io import StringIO
import os
import sys
import unittest
import warnings
import decimal
from contextlib import redirect_stdout, redirect_stderr
from .ci_helper import is_travis_or_appveyor
from .profiling import profile
from ..loghelper import fLOG


class ExtTestCase(unittest.TestCase):
    """
    Overwrites unit test class with additional testing functions.
    Unless *setUp* is overwritten, warnings *FutureWarning*,
    *PendingDeprecationWarning*, *ImportWarning* and *DeprecationWarning*
    are filtered out.

    :githublink:`%|py|22`
    """

    def setUp(self):
        """
        Filters out *FutureWarning*, *PendingDeprecationWarning*,
        *ImportWarning* and *DeprecationWarning*.

        :githublink:`%|py|27`
        """
        warnings.simplefilter("ignore", (FutureWarning, PendingDeprecationWarning,
                                         ImportWarning, DeprecationWarning))

    def tearDown(self):
        """
        Stops filtering out *FutureWarning*, *PendingDeprecationWarning*,
        *ImportWarning* and *DeprecationWarning*: the four categories are
        set back to the ``"default"`` action (the filters that were active
        before *setUp* are not restored).

        :githublink:`%|py|37`
        """
        warnings.simplefilter("default", (FutureWarning, PendingDeprecationWarning,
                                          ImportWarning, DeprecationWarning))

    @staticmethod
    def _format_str(s):
        """
        Returns ``'s'`` (quoted) when *s* looks like a string (it has a
        ``replace`` method), ``s`` unchanged otherwise.

        :githublink:`%|py|48`
        """
        if hasattr(s, "replace"):
            return "'{0}'".format(s)
        return s

    @staticmethod
    def _import_pandas():
        """
        Imports :epkg:`pandas` while silencing *ImportWarning*.

        :return: the pandas module, or None if it cannot be imported
        """
        try:
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=ImportWarning)
                import pandas
        except ImportError:
            return None
        return pandas

    def assertNotEmpty(self, x):
        """
        Checks that *x* is not empty: it must not be None and, if it
        defines a length, that length must be positive.

        :raises AssertionError: if *x* is None or has a zero length

        :githublink:`%|py|57`
        """
        if x is None or (hasattr(x, "__len__") and len(x) == 0):
            raise AssertionError("x is empty")

    def assertEmpty(self, x, none_allowed=True):
        """
        Checks that *x* is empty.

        :param x: object to check
        :param none_allowed: if True, None is considered as empty
        :raises AssertionError: if *x* is not empty; for lists, tuples,
            dictionaries and sets, the first five elements (keys for a
            dictionary) are shown in the message

        :githublink:`%|py|64`
        """
        if (none_allowed and x is None) or (hasattr(x, "__len__") and len(x) == 0):
            return
        if isinstance(x, (list, tuple, dict, set)):
            # dict and set do not support slicing, convert to a list first.
            disp = "\n" + '\n'.join(map(str, list(x)[:5]))
        else:
            disp = ""
        raise AssertionError("x is not empty{0}".format(disp))

    def assertGreater(self, x, y, strict=False):  # pylint: disable=W0221
        """
        Checks that ``x >= y`` (``x > y`` if *strict* is True).

        :raises AssertionError: the message shows the relation which
            actually holds (``x < y`` or ``x <= y``)

        :githublink:`%|py|76`
        """
        if x < y or (strict and x == y):
            # non-strict failure means x < y, strict failure means x <= y
            raise AssertionError("x <{2} y with x={0} and y={1}".format(
                ExtTestCase._format_str(x), ExtTestCase._format_str(y),
                "=" if strict else ""))

    def assertLesser(self, x, y, strict=False):
        """
        Checks that ``x <= y`` (``x < y`` if *strict* is True).

        :raises AssertionError: the message shows the relation which
            actually holds (``x > y`` or ``x >= y``)

        :githublink:`%|py|85`
        """
        if x > y or (strict and x == y):
            # non-strict failure means x > y, strict failure means x >= y
            raise AssertionError("x >{2} y with x={0} and y={1}".format(
                ExtTestCase._format_str(x), ExtTestCase._format_str(y),
                "=" if strict else ""))

    def assertExists(self, name):
        """
        Checks that *name* exists.

        :raises FileNotFoundError: if *name* does not exist

        :githublink:`%|py|94`
        """
        if not os.path.exists(name):
            raise FileNotFoundError("Unable to find '{0}'.".format(name))

    def assertNotExists(self, name):
        """
        Checks that *name* does not exist.

        :raises FileNotFoundError: if *name* exists (kept for backward
            compatibility with callers catching this type)

        :githublink:`%|py|101`
        """
        if os.path.exists(name):
            raise FileNotFoundError("Able to find '{0}'.".format(name))

    def assertEqualDataFrame(self, d1, d2, **kwargs):
        """
        Checks that two dataframes are equal.
        Calls :epkg:`pandas:testing:assert_frame_equal`.

        :githublink:`%|py|109`
        """
        from pandas.testing import assert_frame_equal
        assert_frame_equal(d1, d2, **kwargs)

    def assertNotEqualDataFrame(self, d1, d2, **kwargs):
        """
        Checks that two dataframes are different.
        Calls :epkg:`pandas:testing:assert_frame_equal`.

        :githublink:`%|py|117`
        """
        from pandas.testing import assert_frame_equal
        try:
            assert_frame_equal(d1, d2, **kwargs)
        except AssertionError:
            return
        raise AssertionError("Two dataframes are identical.")

    def assertEqualArray(self, d1, d2, **kwargs):
        """
        Checks that two arrays are equal. Two None values are equal.
        Relies on :epkg:`numpy:testing:assert_almost_equal.html`.

        :githublink:`%|py|129`
        """
        if d1 is None and d2 is None:
            return
        if d1 is None:
            raise AssertionError("d1 is None, d2 is not")
        if d2 is None:
            raise AssertionError("d1 is not None, d2 is")
        from numpy.testing import assert_almost_equal
        assert_almost_equal(d1, d2, **kwargs)

    def assertNotEqualArray(self, d1, d2, **kwargs):
        """
        Checks that two arrays are different. Fails if both are None,
        succeeds if only one of them is None.
        Relies on :epkg:`numpy:testing:assert_almost_equal.html`.

        :githublink:`%|py|143`
        """
        if d1 is None and d2 is None:
            raise AssertionError("d1 and d2 are equal to None")
        if d1 is None or d2 is None:
            return
        from numpy.testing import assert_almost_equal
        try:
            assert_almost_equal(d1, d2, **kwargs)
        except AssertionError:
            return
        raise AssertionError("Two arrays are identical.")

    def assertEqualNumber(self, d1, d2, **kwargs):
        """
        Checks that two numbers are equal.

        :param d1: int, float, decimal.Decimal or numpy number
        :param d2: int, float, decimal.Decimal or numpy number
        :param kwargs: optional ``precision``; without it the numbers must
            be exactly equal, otherwise the tolerance is absolute when
            ``min(|d1|, |d2|) == 0`` and relative to that minimum otherwise
        :raises TypeError: if *d1* or *d2* is not a number
        :raises AssertionError: if the numbers differ

        :githublink:`%|py|158`
        """
        from numpy import number
        if not isinstance(d1, (int, float, decimal.Decimal, number)):
            raise TypeError('d1 is not a number but {0}'.format(type(d1)))
        if not isinstance(d2, (int, float, decimal.Decimal, number)):
            raise TypeError('d2 is not a number but {0}'.format(type(d2)))
        diff = abs(float(d1 - d2))
        mi = float(min(abs(d1), abs(d2)))
        tol = kwargs.get('precision', None)
        if tol is None:
            if diff != 0:
                raise AssertionError("d1 != d2: {0} != {1}".format(d1, d2))
        else:
            if mi == 0:
                if diff > tol:
                    raise AssertionError(
                        "d1 != d2: {0} != {1} +/- {2}".format(d1, d2, tol))
            else:
                rel = diff / mi
                if rel > tol:
                    raise AssertionError(
                        "d1 != d2: {0} != {1} +/- {2}".format(d1, d2, tol))

    def assertRaise(self, fct, exc=None, msg=None):
        """
        Checks that function *fct* with no parameter raises
        an exception of a given type.

        :param fct: function to test (no parameter)
        :param exc: exception type to catch (None for all)
        :param msg: error message to check (None for no message to check)

        :githublink:`%|py|189`
        """
        try:
            fct()
        except Exception as e:
            if exc is None:
                return
            elif isinstance(e, exc):
                if msg is None:
                    return
                if msg not in str(e):
                    raise AssertionError(
                        "Function '{0}' raise exception with wrong message '{1}' (must contain '{2}').".format(fct, e, msg))
                return
            raise AssertionError(
                "Function '{0}' does not raise exception '{1}' but '{2}' of type '{3}'.".format(fct, exc, e, type(e))) from e
        raise AssertionError(
            "Function '{0}' does not raise exception.".format(fct))

    def capture(self, fct):
        """
        Runs a function and captures standard output and error.

        :param fct: function to run (no parameter)
        :return: tuple ``(result, stdout text, stderr text)``

        :githublink:`%|py|210`
        """
        sout = StringIO()
        serr = StringIO()
        with redirect_stdout(sout):
            with redirect_stderr(serr):
                res = fct()
        return res, sout.getvalue(), serr.getvalue()

    def assertStartsWith(self, sub, whole):
        """
        Checks that string *whole* starts with *sub*.

        :githublink:`%|py|221`
        """
        if not whole.startswith(sub):
            # only display the beginning of a long string
            if len(whole) > len(sub) * 2:
                whole = whole[:len(sub) * 2]
            raise AssertionError(
                "'{1}' does not start with '{0}'".format(sub, whole))

    def assertNotStartsWith(self, sub, whole):
        """
        Checks that string *whole* does not start with *sub*.

        :githublink:`%|py|231`
        """
        if whole.startswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[:len(sub) * 2]
            raise AssertionError(
                "'{1}' starts with '{0}'".format(sub, whole))

    def assertEndsWith(self, sub, whole):
        """
        Checks that string *whole* ends with *sub*.

        :githublink:`%|py|241`
        """
        if not whole.endswith(sub):
            # only display the end of a long string
            if len(whole) > len(sub) * 2:
                whole = whole[-len(sub) * 2:]
            raise AssertionError(
                "'{1}' does not end with '{0}'".format(sub, whole))

    def assertNotEndsWith(self, sub, whole):
        """
        Checks that string *whole* does not end with *sub*.

        :githublink:`%|py|251`
        """
        if whole.endswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[-len(sub) * 2:]
            raise AssertionError(
                "'{1}' ends with '{0}'".format(sub, whole))

    def assertEqual(self, a, b, msg=None):  # pylint: disable=W0221
        """
        Checks that ``a == b``. Dataframes and numpy arrays, whose
        comparison does not return a boolean, are delegated to
        :meth:`assertEqualDataFrame` and :meth:`assertEqualArray`.

        :param a: first object
        :param b: second object
        :param msg: optional message forwarded to
            :meth:`unittest.TestCase.assertEqual`
        :raises AssertionError: if the objects differ

        :githublink:`%|py|261`
        """
        if a is None and b is not None:
            raise AssertionError("a is None, b is not")
        if a is not None and b is None:
            raise AssertionError("a is not None, b is")
        try:
            unittest.TestCase.assertEqual(self, a, b, msg)
        except ValueError as e:
            text = str(e)
            # Prefix only: the pandas wording changed across versions
            # ("... identically-labeled (both index and columns) DataFrame objects").
            if ("The truth value of a DataFrame is ambiguous" in text or
                    "Can only compare identically-labeled" in text or
                    "The truth value of an array with more than one element is ambiguous." in text):
                pandas = ExtTestCase._import_pandas()
                if (pandas is not None and isinstance(a, pandas.DataFrame) and
                        isinstance(b, pandas.DataFrame)):
                    self.assertEqualDataFrame(a, b)
                    return
                import numpy
                if isinstance(a, numpy.ndarray) and isinstance(b, numpy.ndarray):
                    self.assertEqualArray(a, b)
                    return
            raise AssertionError("Unable to check equality for types {0} and {1}".format(
                type(a), type(b))) from e

    def assertNotEqual(self, a, b, msg=None):  # pylint: disable=W0221
        """
        Checks that ``a != b``. Fails if both are None, succeeds if only
        one of them is None. Dataframes and numpy arrays are delegated to
        :meth:`assertNotEqualDataFrame` and :meth:`assertNotEqualArray`.

        :param a: first object
        :param b: second object
        :param msg: optional message forwarded to
            :meth:`unittest.TestCase.assertNotEqual`
        :raises AssertionError: if the objects are equal

        :githublink:`%|py|287`
        """
        if a is None and b is None:
            raise AssertionError("a is None, b is too")
        if a is None or b is None:
            # exactly one of them is None: they are different
            return
        try:
            unittest.TestCase.assertNotEqual(self, a, b, msg)
        except ValueError as e:
            text = str(e)
            # Prefix only: the pandas wording changed across versions.
            if ("Can only compare identically-labeled" in text or
                    "The truth value of a DataFrame is ambiguous." in text or
                    "The truth value of an array with more than one element is ambiguous." in text):
                pandas = ExtTestCase._import_pandas()
                if (pandas is not None and isinstance(a, pandas.DataFrame) and
                        isinstance(b, pandas.DataFrame)):
                    self.assertNotEqualDataFrame(a, b)
                    return
                import numpy
                if isinstance(a, numpy.ndarray) and isinstance(b, numpy.ndarray):
                    self.assertNotEqualArray(a, b)
                    return
            raise

    def assertEqualFloat(self, a, b, precision=1e-5):
        """
        Checks that ``abs(a-b) < precision``. The difference is absolute
        when ``min(|a|, |b|) == 0`` and relative to that minimum otherwise.

        :githublink:`%|py|315`
        """
        mi = min(abs(a), abs(b))
        if mi == 0:
            d = abs(a - b)
            self.assertLesser(d, precision)
        else:
            r = float(abs(a - b)) / mi
            self.assertLesser(r, precision)

    def assertCallable(self, fct):
        """
        Checks that *fct* is callable.

        :githublink:`%|py|327`
        """
        if not callable(fct):
            raise AssertionError("fct is not callable: {0}".format(type(fct)))

    def assertEqualDict(self, a, b):
        """
        Checks that ``a == b`` for two dictionaries and lists every
        added, removed or different key in the error message.

        :raises TypeError: if *a* or *b* is not a dictionary
        :raises AssertionError: if the dictionaries differ

        :githublink:`%|py|334`
        """
        if not isinstance(a, dict):
            raise TypeError('a is not dict but {0}'.format(type(a)))
        if not isinstance(b, dict):
            raise TypeError('b is not dict but {0}'.format(type(b)))
        rows = []
        for key in sorted(b):
            if key not in a:
                rows.append("** Added key '{0}' in b".format(key))
            else:
                if a[key] != b[key]:
                    rows.append(
                        "** Value != for key '{0}': != id({1}) != id({2})\n==1 {3}\n==2 {4}".format(
                            key, id(a[key]), id(b[key]), a[key], b[key]))
        for key in sorted(a):
            if key not in b:
                rows.append("** Removed key '{0}' in a".format(key))
        if len(rows) > 0:
            raise AssertionError(
                "Dictionaries are different\n{0}".format('\n'.join(rows)))

    def fLOG(self, *args, **kwargs):
        """
        Prints out some information.
        :func:`fLOG <pyquickhelper.loghelper.flog.fLOG>`.

        :githublink:`%|py|359`
        """
        fLOG(*args, **kwargs)

    def profile(self, fct, sort='cumulative', rootrem=None):
        """
        Profiles the execution of a function.

        :param fct: function to profile
        :param sort: see `sort_stats <https://docs.python.org/3/library/profile.html#pstats.Stats.sort_stats>`_
        :param rootrem: root to remove in filenames
        :return: statistics text dump

        :githublink:`%|py|370`
        """
        return profile(fct, sort=sort, rootrem=rootrem)

    def read_file(self, filename, mode='r', encoding="utf-8"):
        """
        Returns the content of a file.

        :param filename: filename
        :param encoding: encoding (ignored in binary mode)
        :param mode: reading mode, ``'rb'`` returns bytes
        :return: content
        :raises FileNotFoundError: if *filename* does not exist

        :githublink:`%|py|381`
        """
        self.assertExists(filename)
        if 'b' in mode:
            # open() raises ValueError if an encoding is given in binary mode
            encoding = None
        with open(filename, mode, encoding=encoding) as f:
            return f.read()

    def write_file(self, filename, content, mode='w', encoding='utf-8'):
        """
        Writes the content of a file.

        :param filename: filename
        :param content: content to write (bytes in binary mode)
        :param encoding: encoding (ignored in binary mode)
        :param mode: writing mode, ``'wb'`` expects bytes
        :return: value returned by ``write`` (number of characters or bytes written)

        :githublink:`%|py|395`
        """
        if 'b' in mode:
            # open() raises ValueError if an encoding is given in binary mode
            encoding = None
        with open(filename, mode, encoding=encoding) as f:
            return f.write(content)
def skipif_appveyor(msg):
    """
    Builds a decorator which skips the decorated unit test when it runs
    on :epkg:`appveyor`; anywhere else the test is returned unchanged.

    :param msg: reason appended to the skip message
    :return: decorator

    :githublink:`%|py|403`
    """
    if is_travis_or_appveyor() == 'appveyor':
        return unittest.skip('Test does not work on appveyor due to: ' + msg)
    return lambda x: x
def skipif_travis(msg):
    """
    Builds a decorator which skips the decorated unit test when it runs
    on :epkg:`travis`; anywhere else the test is returned unchanged.

    :param msg: reason appended to the skip message
    :return: decorator

    :githublink:`%|py|413`
    """
    if is_travis_or_appveyor() == 'travis':
        return unittest.skip('Test does not work on travis due to: ' + msg)
    return lambda x: x
def skipif_circleci(msg):
    """
    Builds a decorator which skips the decorated unit test when it runs
    on :epkg:`circleci`; anywhere else the test is returned unchanged.

    :param msg: reason appended to the skip message
    :return: decorator

    :githublink:`%|py|423`
    """
    if is_travis_or_appveyor() == 'circleci':
        return unittest.skip('Test does not work on circleci due to: ' + msg)
    return lambda x: x
def skipif_azure(msg):
    """
    Builds a decorator which skips the decorated unit test when it runs
    on :epkg:`azure pipeline`; anywhere else the test is returned unchanged.

    :param msg: reason appended to the skip message
    :return: decorator

    :githublink:`%|py|433`
    """
    if is_travis_or_appveyor() == 'azurepipe':
        return unittest.skip('Test does not work on azure pipeline due to: ' + msg)
    return lambda x: x
def skipif_azure_linux(msg):
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`linux`.
    The test is skipped only when both conditions hold; everywhere else
    the returned decorator leaves the test unchanged.

    :param msg: reason appended to the skip message
    :return: decorator

    :githublink:`%|py|443`
    """
    # Run the test unless we are on linux *and* on azure pipeline
    # (local linux runs and azure on other platforms must not be skipped).
    if not sys.platform.startswith('lin') or is_travis_or_appveyor() != 'azurepipe':
        return lambda x: x
    msg = 'Test does not work on azure pipeline (linux) due to: ' + msg
    return unittest.skip(msg)
def skipif_azure_macosx(msg):
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on *macOS*
    (``sys.platform`` starting with ``'darwin'``). The test is skipped only
    when both conditions hold; everywhere else the returned decorator
    leaves the test unchanged.

    :param msg: reason appended to the skip message
    :return: decorator

    :githublink:`%|py|453`
    """
    # Run the test unless we are on macOS *and* on azure pipeline.
    if not sys.platform.startswith('darwin') or is_travis_or_appveyor() != 'azurepipe':
        return lambda x: x
    msg = 'Test does not work on azure pipeline (macosx) due to: ' + msg
    return unittest.skip(msg)
def skipif_linux(msg):
    """
    Skips a unit test if it runs on :epkg:`linux`
    (``sys.platform`` starting with ``'lin'``).

    :param msg: reason appended to the skip message
    :return: decorator, identity on other platforms

    .. versionadded:: 1.7

    :githublink:`%|py|465`
    """
    if not sys.platform.startswith('lin'):
        return lambda x: x
    msg = 'Test does not work on linux due to: ' + msg
    return unittest.skip(msg)
def skipif_vless(version, msg):
    """
    Skips a unit test if the running Python version is strictly below
    *version*.

    :param version: minimal version as a tuple, compared with
        ``sys.version_info[:3]``
    :param msg: reason appended to the skip message
    :return: decorator, identity if the current version is recent enough

    .. versionadded:: 1.7

    :githublink:`%|py|477`
    """
    current = sys.version_info[:3]
    if current < version:
        reason = 'Python {} < {}: {}'.format(current, version, msg)
        return unittest.skip(reason)
    return lambda x: x