"""
Overwrites unit test class with additional testing functions.
:githublink:`%|py|5`
"""
from io import StringIO
import os
import sys
import unittest
import warnings
import decimal
import pprint
from logging import getLogger, INFO, StreamHandler
from contextlib import redirect_stdout, redirect_stderr
from .ci_helper import is_travis_or_appveyor
from .profiling import profile
from ..texthelper import compare_module_version
class ExtTestCase(unittest.TestCase):
    """
    Overwrites unit test class with additional testing functions.
    Unless *setUp* is overwritten, warnings *FutureWarning*,
    *PendingDeprecationWarning*, *ImportWarning* and *DeprecationWarning*
    are filtered out.
    """

    # Warning categories silenced by setUp and reset to "default" by tearDown.
    _filtered_warning_categories = (
        FutureWarning, PendingDeprecationWarning,
        ImportWarning, DeprecationWarning)

    def setUp(self):
        """
        Filters out *FutureWarning*, *PendingDeprecationWarning*,
        *ImportWarning* and *DeprecationWarning*.
        """
        # warnings.simplefilter documents *category* as a single class,
        # so one filter is registered per category.
        for category in self._filtered_warning_categories:
            warnings.simplefilter("ignore", category)

    def tearDown(self):
        """
        Stops filtering out the warnings silenced by *setUp*
        (their action is set back to ``"default"``).
        """
        for category in self._filtered_warning_categories:
            warnings.simplefilter("default", category)

    @staticmethod
    def _format_str(s):
        """
        Formats a value for an error message: strings are quoted,
        any other value is returned unchanged.
        """
        if isinstance(s, str):
            return "'{0}'".format(s)
        return s

    def assertNotEmpty(self, x):
        """
        Checks that *x* is not empty.

        :param x: ``None`` or any object with ``__len__``
        :raises AssertionError: if *x* is None or has length 0
        """
        if x is None or (hasattr(x, "__len__") and len(x) == 0):
            raise AssertionError("x is empty")

    def assertEmpty(self, x, none_allowed=True):
        """
        Checks that *x* is empty.

        :param x: object to check (empty means ``len(x) == 0``)
        :param none_allowed: if True, ``None`` is considered empty
        :raises AssertionError: if *x* is not empty, the message
            shows up to five items when *x* is a list, tuple, dict or set
        """
        if none_allowed and x is None:
            return
        if hasattr(x, "__len__") and len(x) == 0:
            return
        if isinstance(x, (list, tuple, dict, set)):
            # dict and set cannot be sliced: take the first items by iteration.
            head = [item for _, item in zip(range(5), x)]
            disp = "\n" + '\n'.join(map(str, head))
        else:
            disp = ""
        raise AssertionError("x is not empty{0}".format(disp))

    def assertGreater(self, x, y, strict=False):  # pylint: disable=W0221
        """
        Checks that ``x >= y`` (``x > y`` if *strict* is True).

        :raises AssertionError: if the relation does not hold
        """
        if x < y or (strict and x == y):
            # The message states the relation which was observed.
            raise AssertionError("x <{2} y with x={0} and y={1}".format(
                ExtTestCase._format_str(x), ExtTestCase._format_str(y),
                "=" if strict else ""))

    def assertLesser(self, x, y, strict=False):
        """
        Checks that ``x <= y`` (``x < y`` if *strict* is True).

        :raises AssertionError: if the relation does not hold
        """
        if x > y or (strict and x == y):
            # The message states the relation which was observed.
            raise AssertionError("x >{2} y with x={0} and y={1}".format(
                ExtTestCase._format_str(x), ExtTestCase._format_str(y),
                "=" if strict else ""))

    def assertExists(self, name):
        """
        Checks that *name* exists.

        :raises FileNotFoundError: if the path does not exist
        """
        if not os.path.exists(name):
            raise FileNotFoundError("Unable to find '{0}'.".format(name))

    def assertNotExists(self, name):
        """
        Checks that *name* does not exist.

        :raises FileNotFoundError: if the path exists
        """
        if os.path.exists(name):
            raise FileNotFoundError("Able to find '{0}'.".format(name))

    def assertEqualDataFrame(self, d1, d2, **kwargs):
        """
        Checks that two dataframes are equal.
        Calls :epkg:`pandas:testing:assert_frame_equal`.
        """
        from pandas.testing import assert_frame_equal
        assert_frame_equal(d1, d2, **kwargs)

    def assertNotEqualDataFrame(self, d1, d2, **kwargs):
        """
        Checks that two dataframes are different.
        Calls :epkg:`pandas:testing:assert_frame_equal`.

        :raises AssertionError: if the dataframes are identical
        """
        from pandas.testing import assert_frame_equal
        try:
            assert_frame_equal(d1, d2, **kwargs)
        except AssertionError:
            return
        raise AssertionError("Two dataframes are identical.")

    def assertEqualArray(self, d1, d2, squeeze=False, **kwargs):
        """
        Checks that two arrays are equal.
        Relies on :epkg:`numpy:testing:assert_almost_equal`.

        :param squeeze: applies :func:`numpy.squeeze` on both arrays first
        :param kwargs: forwarded to ``assert_almost_equal``
        """
        if d1 is None and d2 is None:
            return
        if d1 is None:
            raise AssertionError("d1 is None, d2 is not")
        if d2 is None:
            raise AssertionError("d1 is not None, d2 is")
        from numpy.testing import assert_almost_equal
        import numpy
        if squeeze:
            d1 = numpy.squeeze(d1)
            d2 = numpy.squeeze(d2)
        assert_almost_equal(d1, d2, **kwargs)

    def assertHasNoNan(self, a):  # pylint: disable=W0221
        """
        Checks that there is no NaN in ``a``.

        :param a: array-like object with a ``ravel`` method
        :raises AssertionError: if *a* is None or contains NaN
        """
        if a is None:
            raise AssertionError("a is None")
        import numpy
        if any(map(numpy.isnan, a.ravel())):
            raise AssertionError("a has nan:\n{}".format(a))

    def assertEqualSparseArray(self, d1, d2, **kwargs):
        """
        Checks that two sparse matrices are equal.
        COO matrices (``data``, ``row``, ``col``) and CSR/CSC matrices
        (``data``, ``indices``, ``indptr``) are supported.

        :param kwargs: forwarded to :meth:`assertEqualArray` when
            comparing the stored values (indices are compared exactly)
        :raises AssertionError: if types or content differ
        :raises NotImplementedError: for unsupported sparse formats
        """
        if type(d1) != type(d2):  # pylint: disable=C0123
            raise AssertionError("d1 and d2 have difference types {} != {}.".format(
                type(d1), type(d2)))
        if d1 is None and d2 is None:
            return
        if (hasattr(d1, 'data') and hasattr(d1, 'row') and hasattr(d1, 'col') and
                hasattr(d2, 'data') and hasattr(d2, 'row') and hasattr(d2, 'col')):
            # coo_matrix
            self.assertEqual(d1.shape, d2.shape)
            self.assertEqualArray(d1.data, d2.data, **kwargs)
            self.assertEqualArray(d1.row, d2.row)
            self.assertEqualArray(d1.col, d2.col)
            return
        if (hasattr(d1, 'data') and hasattr(d1, 'indices') and hasattr(d1, 'indptr') and
                hasattr(d2, 'data') and hasattr(d2, 'indices') and hasattr(d2, 'indptr')):
            # csr_matrix / csc_matrix
            self.assertEqual(d1.shape, d2.shape)
            self.assertEqualArray(d1.data, d2.data, **kwargs)
            self.assertEqualArray(d1.indices, d2.indices)
            self.assertEqualArray(d1.indptr, d2.indptr)
            return
        raise NotImplementedError(  # pragma: no cover
            "Comparison not implemented for types {} and {}.".format(
                type(d1), type(d2)))

    def assertNotEqualArray(self, d1, d2, squeeze=False, **kwargs):
        """
        Checks that two arrays are different.
        Relies on :epkg:`numpy:testing:assert_almost_equal`.

        :param squeeze: applies :func:`numpy.squeeze` on both arrays first
        :param kwargs: forwarded to ``assert_almost_equal``
        :raises AssertionError: if both are None or the arrays are equal
        """
        if d1 is None and d2 is None:
            raise AssertionError("d1 and d2 are equal to None")
        if d1 is None or d2 is None:
            return
        from numpy.testing import assert_almost_equal
        import numpy
        if squeeze:
            d1 = numpy.squeeze(d1)
            d2 = numpy.squeeze(d2)
        try:
            assert_almost_equal(d1, d2, **kwargs)
        except AssertionError:
            return
        raise AssertionError("Two arrays are identical.")

    def assertEqualNumber(self, d1, d2, **kwargs):
        """
        Checks that two numbers are equal.

        :param kwargs: ``precision`` is a relative tolerance (absolute when
            one of the numbers is zero); without it, equality must be exact
        :raises TypeError: if one argument is not a number
        :raises AssertionError: if the numbers differ
        """
        from numpy import number
        if not isinstance(d1, (int, float, decimal.Decimal, number)):
            raise TypeError('d1 is not a number but {0}'.format(type(d1)))
        if not isinstance(d2, (int, float, decimal.Decimal, number)):
            raise TypeError('d2 is not a number but {0}'.format(type(d2)))
        diff = abs(float(d1 - d2))
        mi = float(min(abs(d1), abs(d2)))
        tol = kwargs.get('precision', None)
        if tol is None:
            if diff != 0:
                raise AssertionError("d1 != d2: {0} != {1}".format(d1, d2))
        else:
            if mi == 0:
                if diff > tol:  # pragma: no cover
                    raise AssertionError(
                        "d1 != d2: {0} != {1} +/- {2}".format(d1, d2, tol))
            else:
                rel = diff / mi
                if rel > tol:
                    raise AssertionError(  # pragma: no cover
                        "d1 != d2: {0} != {1} +/- {2}".format(d1, d2, tol))

    def assertRaise(self, fct, exc=None, msg=None):
        """
        Checks that function *fct* with no parameter
        raises an exception of a given type.

        :param fct: function to test (no parameter)
        :param exc: exception type to catch (None for all)
        :param msg: error message to check (None for no message to check)
        """
        try:
            fct()
        except Exception as e:
            if exc is None:
                return  # pragma: no cover
            elif isinstance(e, exc):
                if msg is None:
                    return
                if msg not in str(e):
                    raise AssertionError(  # pragma: no cover
                        "Function '{0}' raise exception with wrong message '{1}' "
                        "(must contain '{2}').".format(fct, e, msg))
                return
            raise AssertionError(
                "Function '{0}' does not raise exception '{1}' but '{2}' of type "
                "'{3}'.".format(fct, exc, e, type(e)))
        raise AssertionError(  # pragma: no cover
            "Function '{0}' does not raise exception.".format(fct))

    def capture(self, fct):
        """
        Runs a function and capture standard output and error.

        :param fct: function to run
        :return: result of *fct*, output, error
        """
        sout = StringIO()
        serr = StringIO()
        with redirect_stdout(sout):
            with redirect_stderr(serr):
                res = fct()
        return res, sout.getvalue(), serr.getvalue()

    def assertStartsWith(self, sub, whole):
        """
        Checks that string *whole* starts with *sub*.
        """
        if not whole.startswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[:len(sub) * 2]  # pragma: no cover
            raise AssertionError(
                "'{1}' does not start with '{0}'".format(sub, whole))

    def assertNotStartsWith(self, sub, whole):
        """
        Checks that string *whole* does not start with *sub*.
        """
        if whole.startswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[:len(sub) * 2]  # pragma: no cover
            raise AssertionError(
                "'{1}' starts with '{0}'".format(sub, whole))

    def assertEndsWith(self, sub, whole):
        """
        Checks that string *whole* ends with *sub*.
        """
        if not whole.endswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[-len(sub) * 2:]  # pragma: no cover
            raise AssertionError(
                "'{1}' does not end with '{0}'".format(sub, whole))

    def assertNotEndsWith(self, sub, whole):
        """
        Checks that string *whole* does not end with *sub*.
        """
        if whole.endswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[-len(sub) * 2:]
            raise AssertionError(
                "'{1}' ends with '{0}'".format(sub, whole))

    def assertEqual(self, a, b):  # pylint: disable=W0221
        """
        Checks that ``a == b``. Dataframes and numpy arrays, whose ``==``
        cannot be turned into a boolean, are compared with
        :meth:`assertEqualDataFrame` or :meth:`assertEqualArray`.

        :raises AssertionError: if the objects differ
        :raises ValueError: any other error raised by the comparison
        """
        if a is None and b is not None:
            raise AssertionError("a is None, b is not")
        if a is not None and b is None:
            raise AssertionError("a is not None, b is")
        try:
            unittest.TestCase.assertEqual(self, a, b)
        except ValueError as e:
            text = str(e)
            if ("The truth value of a DataFrame is ambiguous" in text or
                    "The truth value of an array with more than one element is ambiguous." in text):
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=ImportWarning)
                    import pandas
                if isinstance(a, pandas.DataFrame) and isinstance(b, pandas.DataFrame):
                    self.assertEqualDataFrame(a, b)
                    return
                import numpy
                if isinstance(a, numpy.ndarray) and isinstance(b, numpy.ndarray):
                    self.assertEqualArray(a, b)
                    return
                raise AssertionError(  # pragma: no cover
                    "Unable to check equality for types {0} and {1}".format(
                        type(a), type(b))) from e
            # Any other ValueError is a genuine failure: never swallow it.
            raise

    def assertNotEqual(self, a, b):  # pylint: disable=W0221
        """
        Checks that ``a != b``. Dataframes and numpy arrays are compared
        with :meth:`assertNotEqualDataFrame` or :meth:`assertNotEqualArray`.

        :raises AssertionError: if the objects are equal
        :raises ValueError: any other error raised by the comparison
        """
        if a is None and b is None:
            raise AssertionError("a is None, b is too")  # pragma: no cover
        if a is None and b is not None:
            return  # pragma: no cover
        if a is not None and b is None:
            return  # pragma: no cover
        try:
            unittest.TestCase.assertNotEqual(self, a, b)
        except ValueError as e:
            text = str(e)
            if ("Can only compare identically-labeled DataFrame objects" in text or
                    "The truth value of a DataFrame is ambiguous." in text or
                    "The truth value of an array with more than one element is ambiguous." in text):
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=ImportWarning)
                    import pandas
                if isinstance(a, pandas.DataFrame) and isinstance(b, pandas.DataFrame):
                    self.assertNotEqualDataFrame(a, b)
                    return
                import numpy
                if isinstance(a, numpy.ndarray) and isinstance(b, numpy.ndarray):
                    self.assertNotEqualArray(a, b)
                    return
            raise  # pragma: no cover

    def assertEqualFloat(self, a, b, precision=1e-5):
        """
        Checks that ``abs(a-b) / min(abs(a), abs(b)) <= precision``,
        or ``abs(a-b) <= precision`` when one of the values is zero.
        NaN is never considered equal to anything.

        :raises AssertionError: if the values are too far apart
        """
        if a == b:
            # Also covers identical infinities whose difference is NaN.
            return
        diff = abs(a - b)
        mi = min(abs(a), abs(b))
        err = diff if mi == 0 else float(diff) / mi
        # "not <=" so that a NaN error is reported as a failure.
        if not err <= precision:
            raise AssertionError("{} != {} (p={})".format(a, b, precision))

    def assertCallable(self, fct):
        """
        Checks that *fct* is callable.
        """
        if not callable(fct):
            raise AssertionError("fct is not callable: {0}".format(type(fct)))

    def assertEqualDict(self, a, b):
        """
        Checks that ``a == b`` for two dictionaries and lists
        every added, removed or different key in the error message.

        :raises TypeError: if *a* or *b* is not a dict
        :raises AssertionError: if the dictionaries differ
        """
        if not isinstance(a, dict):
            raise TypeError('a is not dict but {0}'.format(type(a)))
        if not isinstance(b, dict):
            raise TypeError('b is not dict but {0}'.format(type(b)))
        rows = []
        for key in sorted(b):
            if key not in a:
                rows.append("** Added key '{0}' in b".format(key))
            else:
                if a[key] != b[key]:
                    rows.append(
                        "** Value != for key '{0}': != id({1}) != id({2})\n==1 {3}\n==2 {4}".format(
                            key, id(a[key]), id(b[key]), a[key], b[key]))
        for key in sorted(a):
            if key not in b:
                rows.append("** Removed key '{0}' in a".format(key))
        if len(rows) > 0:
            raise AssertionError(
                "Dictionaries are different\n{0}".format('\n'.join(rows)))

    def fLOG(self, *args, **kwargs):
        """
        Prints out some information.
        :func:`fLOG <pyquickhelper.loghelper.flog.fLOG>`.
        """
        # delayed import
        from ..loghelper import fLOG as _flog  # pragma: no cover
        _flog(*args, **kwargs)  # pragma: no cover

    def profile(self, fct, sort='cumulative', rootrem=None):
        """
        Profiles the execution of a function.

        :param fct: function to profile
        :param sort: see `sort_stats <https://docs.python.org/3/library/profile.html#pstats.Stats.sort_stats>`_
        :param rootrem: root to remove in filenames
        :return: statistics text dump
        """
        return profile(fct, sort=sort, rootrem=rootrem)

    def read_file(self, filename, mode='r', encoding="utf-8"):
        """
        Returns the content of a file.

        :param filename: filename
        :param mode: reading mode (``'rb'`` returns bytes)
        :param encoding: encoding, ignored in binary mode
        :return: content
        """
        self.assertExists(filename)
        # open() rejects an encoding in binary mode.
        if 'b' in mode:
            encoding = None
        with open(filename, mode, encoding=encoding) as f:
            return f.read()

    def write_file(self, filename, content, mode='w', encoding='utf-8'):
        """
        Writes the content of a file.

        :param filename: filename
        :param content: content to write (bytes in binary mode)
        :param mode: writing mode
        :param encoding: encoding, ignored in binary mode
        :return: what ``write`` returns (number of characters or bytes)
        """
        # open() rejects an encoding in binary mode.
        if 'b' in mode:
            encoding = None
        with open(filename, mode, encoding=encoding) as f:
            return f.write(content)

    def assertIn(self, sub, ensemble, msg=None):  # pylint: disable=W0221
        """
        Checks that *sub* is in *ensemble*. Nothing is checked
        if *sub* is None.

        :param sub: sub set
        :param ensemble: full set
        :param msg: error message
        :raises AssertionError: if *ensemble* is None or does not contain *sub*
        """
        if sub is None:
            return  # pragma: no cover
        if ensemble is None:
            raise AssertionError(msg or "'text' is None")  # pragma: no cover
        if sub not in ensemble:
            raise AssertionError(  # pragma: no cover
                msg or "Unable to find '{}' in\n{}".format(
                    sub, pprint.pformat(ensemble)))

    def assertWarning(self, fct):
        """
        Returns the list of warnings raised while
        executing function *fct*.

        :param fct: function to run
        :return: result, list of warnings
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            r = fct()
        return r, list(w)

    def assertLogging(self, fct, logger_name, level=INFO, log_sphinx=False):
        """
        Returns the logged information in a logger defined
        by its name. The logger's handlers are detached while *fct* runs
        and restored afterwards, even if *fct* raises.

        :param fct: function to run
        :param logger_name: logger name
        :param level: level to intercept
        :param log_sphinx: logging from :epkg:`sphinx`
        :return: result, logged information
        """
        class MyStream:
            "Minimal stream storing every chunk written to it."

            def __init__(self):
                self.rows = []

            def write(self, text):
                self.rows.append(text)

            def getvalue(self):
                return "\n".join(self.rows)

            def __len__(self):
                return len(self.rows)

        if log_sphinx:
            # Imported lazily: sphinx is only required when requested.
            from sphinx.util import logging as logging_sphinx
            logger = logging_sphinx.getLogger(logger_name).logger
        else:
            logger = getLogger(logger_name)

        # Iterate over a copy: removeHandler mutates logger.handlers.
        hs = list(logger.handlers)
        for h in hs:
            logger.removeHandler(h)
        log_capture_string = MyStream()
        ch = StreamHandler(log_capture_string)
        ch.setLevel(level)
        logger.addHandler(ch)
        try:
            res = fct()
        finally:
            logger.removeHandler(ch)
            for h in hs:
                logger.addHandler(h)
        return res, log_capture_string.getvalue()
def skipif_appveyor(msg):
    """
    Skips a unit test if it runs on :epkg:`appveyor`.

    :param msg: reason appended to the skip message
    :return: decorator, :func:`unittest.skip` on appveyor, identity otherwise
    """
    if is_travis_or_appveyor() == 'appveyor':
        return unittest.skip(  # pragma: no cover
            'Test does not work on appveyor due to: ' + msg)
    return lambda x: x
def skipif_travis(msg):
    """
    Skips a unit test if it runs on :epkg:`travis`.

    :param msg: reason appended to the skip message
    :return: decorator, :func:`unittest.skip` on travis, identity otherwise
    """
    if is_travis_or_appveyor() == 'travis':
        return unittest.skip(  # pragma: no cover
            'Test does not work on travis due to: ' + msg)
    return lambda x: x
def skipif_circleci(msg):
    """
    Skips a unit test if it runs on :epkg:`circleci`.

    :param msg: reason appended to the skip message
    :return: decorator, :func:`unittest.skip` on circleci, identity otherwise
    """
    if is_travis_or_appveyor() == 'circleci':
        return unittest.skip(  # pragma: no cover
            'Test does not work on circleci due to: ' + msg)
    return lambda x: x
def skipif_azure(msg):
    """
    Skips a unit test if it runs on :epkg:`azure pipeline`.

    :param msg: reason appended to the skip message
    :return: decorator, :func:`unittest.skip` on azure pipeline,
        identity otherwise
    """
    if is_travis_or_appveyor() == 'azurepipe':
        return unittest.skip(  # pragma: no cover
            'Test does not work on azure pipeline due to: ' + msg)
    return lambda x: x  # pragma: no cover
def skipif_azure_linux(msg):
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`linux`.
    The test is skipped only when both conditions hold.

    :param msg: reason appended to the skip message
    :return: decorator, :func:`unittest.skip` or the identity
    """
    # Keep the test unless the platform is linux *and* the CI is azure.
    if not sys.platform.startswith('lin') or is_travis_or_appveyor() != 'azurepipe':
        return lambda x: x  # pragma: no cover
    msg = 'Test does not work on azure pipeline (linux) due to: ' + msg
    return unittest.skip(msg)
def skipif_azure_macosx(msg):
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on macOS.
    The test is skipped only when both conditions hold.

    :param msg: reason appended to the skip message
    :return: decorator, :func:`unittest.skip` or the identity
    """
    # Keep the test unless the platform is macOS *and* the CI is azure.
    if not sys.platform.startswith('darwin') or is_travis_or_appveyor() != 'azurepipe':
        return lambda x: x
    msg = 'Test does not work on azure pipeline (macosx) due to: ' + msg
    return unittest.skip(msg)
def skipif_linux(msg):
    """
    Skips a unit test if it runs on :epkg:`linux`.

    :param msg: reason appended to the skip message
    :return: decorator, :func:`unittest.skip` on linux, identity otherwise

    .. versionadded:: 1.7
    """
    if not sys.platform.startswith('lin'):
        return lambda x: x
    msg = 'Test does not work on linux due to: ' + msg  # pragma: no cover
    return unittest.skip(msg)  # pragma: no cover
def skipif_vless(version, msg):
    """
    Skips a unit test if the running Python version is strictly
    below *version*.

    :param version: minimal version as a tuple, e.g. ``(3, 7, 0)``
    :param msg: reason appended to the skip message
    :return: decorator, :func:`unittest.skip` or the identity

    .. versionadded:: 1.7
    """
    current = sys.version_info[:3]
    if current < version:
        return unittest.skip('Python {} < {}: {}'.format(  # pragma: no cover
            current, version, msg))
    return lambda x: x
def unittest_require_at_least(mod, version, msg=""):
    """
    Skips a unit test if the version of one module
    is not at least the provided version.

    :param mod: module (the module must have an attribute ``__version__``)
    :param version: expected version or more recent
    :param msg: message
    :return: decorator, :func:`unittest.skip` or the identity
    :raises RuntimeError: if the module has no ``__version__``

    .. versionadded:: 1.9
    """
    current = getattr(mod, '__version__', None)
    if current is None:
        raise RuntimeError(  # pragma: no cover
            "Module '{}' has no version.".format(mod))
    if compare_module_version(current, version) < 0:
        return unittest.skip("Module '{}' is older than '{}' (= '{}'). {}".format(
            mod, version, current, msg))
    return lambda x: x
def ignore_warnings(warns):
    """
    Catches warnings: the decorated test method runs with
    the warnings of category *warns* ignored.

    :param warns: warnings to ignore
    :return: decorator for a method taking only *self*
    :raises AssertionError: (when decorating) if *warns* is None
    """
    import functools

    def wrapper(fct):
        if warns is None:
            raise AssertionError(  # pragma: no cover
                "warns cannot be None for '{}'.".format(fct))

        # wraps keeps __name__/__doc__ of the test method intact.
        @functools.wraps(fct)
        def call_f(self):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", warns)
                return fct(self)
        return call_f
    return wrapper
def testlog(logtype="print"):
    """
    Logs before and after a function is called.

    :param logtype: kind of logging, only `'print'` is implemented
        and None to disable it
    :return: decorator for a method taking only *self*, the decorated
        method returns what the original one returns
    :raises ValueError: if *logtype* is not supported
    """
    import functools

    if logtype is None:
        def logfct(arg):  # pylint: disable=W0613
            pass
    elif logtype == 'print':
        logfct = print
    else:
        raise ValueError("Unexpected logtype %r." % logtype)

    def wrapper(fct):
        # wraps keeps the original name, which is also what gets logged
        # when several decorators are stacked.
        @functools.wraps(fct)
        def call_f(self):
            logfct('START %r' % fct.__name__)
            res = fct(self)
            logfct('DONE- %r' % fct.__name__)
            return res
        return call_f
    return wrapper
def assert_almost_equal_detailed(expected, value, **kwargs):
    """
    Calls :epkg:`numpy:testing:assert_almost_equal`.
    When both inputs are arrays with the same number of rows, the
    exception message is extended with a row-by-row comparison
    (at most four failing rows are reported).

    :param expected: expected value
    :param value: value
    :param kwargs: forwarded to ``assert_almost_equal``
    :raises AssertionError: if the values differ
    """
    from numpy.testing import assert_almost_equal
    try:
        assert_almost_equal(expected, value, **kwargs)
    except AssertionError as e:
        exp_shape = getattr(expected, 'shape', None)
        val_shape = getattr(value, 'shape', None)
        # Row details only make sense for arrays of rank >= 1
        # with the same number of rows; otherwise keep the original error.
        if not exp_shape or not val_shape or exp_shape[0] != val_shape[0]:
            raise
        rows = ['INNER EXCEPTION:', str(e), '------', 'ROWS BY ROWS']
        for i, (r1, r2) in enumerate(zip(expected, value)):
            try:
                assert_almost_equal(r1, r2, **kwargs)
            except AssertionError as row_error:
                rows.append('----------------------')
                rows.append("ISSUE WITH ROW {}/{}:0 {}".format(
                    i, exp_shape[0], str(row_error)))
                if len(rows) > 10:
                    break
        raise AssertionError("\n".join(rows)) from e