Hot-keys on this page:
r, m, x, p: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1"""
2@file
3@brief Overwrites unit test class with additional testing functions.
4"""
5from io import StringIO
6import os
7import sys
8import unittest
9import warnings
10import decimal
11import pprint
12from logging import getLogger, INFO, StreamHandler
13from contextlib import redirect_stdout, redirect_stderr
14from .ci_helper import is_travis_or_appveyor
15from .profiling import profile
16from ..texthelper import compare_module_version
class ExtTestCase(unittest.TestCase):
    """
    Overwrites unit test class with additional testing functions.
    Unless *setUp* is overwritten, warnings *FutureWarning*,
    *PendingDeprecationWarning*, *ImportWarning* and
    *DeprecationWarning* are filtered out.
    """

    # Warning categories silenced by setUp and reset by tearDown.
    _FILTERED_WARNINGS = (FutureWarning, PendingDeprecationWarning,
                          ImportWarning, DeprecationWarning)

    def setUp(self):
        """
        Filters out *FutureWarning*, *PendingDeprecationWarning*,
        *ImportWarning* and *DeprecationWarning*.
        """
        # simplefilter stores the tuple as is; filters are later matched
        # with issubclass(), which accepts a tuple of classes.
        warnings.simplefilter("ignore", self._FILTERED_WARNINGS)

    def tearDown(self):
        """
        Stops filtering out the warnings silenced by *setUp*: their action
        is reset to ``"default"`` (the filter state that existed before
        *setUp* is not restored).
        """
        warnings.simplefilter("default", self._FILTERED_WARNINGS)

    @staticmethod
    def _format_str(s):
        """
        Returns ``s`` surrounded by quotes when it looks like a string
        (it has a ``replace`` method), ``s`` unchanged otherwise.
        """
        if hasattr(s, "replace"):
            return "'{0}'".format(s)
        return s

    @staticmethod
    def _optional_import(name):
        """
        Imports module *name* while ignoring *ImportWarning*.

        :param name: module name
        :return: the module, or None if it cannot be imported
        """
        import importlib
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=ImportWarning)
            try:
                return importlib.import_module(name)
            except ImportError:
                return None

    def assertNotEmpty(self, x):
        """
        Checks that *x* is not empty.

        :raises AssertionError: if *x* is None or has a zero length
        """
        if x is None or (hasattr(x, "__len__") and len(x) == 0):
            raise AssertionError("x is empty")

    def assertEmpty(self, x, none_allowed=True):
        """
        Checks that *x* is empty.

        :param x: object to check
        :param none_allowed: if True, None is considered as empty
        :raises AssertionError: if *x* is not empty; for a list, tuple,
            dict or set, the message shows its first five elements
        """
        if none_allowed and x is None:
            return
        if hasattr(x, "__len__") and len(x) == 0:
            return
        if isinstance(x, (list, tuple, dict, set)):
            # dict and set cannot be sliced, convert to a list first
            disp = "\n" + '\n'.join(map(str, list(x)[:5]))
        else:
            disp = ""
        raise AssertionError("x is not empty{0}".format(disp))

    def assertGreater(self, x, y, strict=False):  # pylint: disable=W0221,W0237
        """
        Checks that ``x >= y`` (``x > y`` if *strict* is True).

        :raises AssertionError: if the comparison does not hold
        """
        if x < y or (strict and x == y):
            # strict failure means x <= y, non strict failure means x < y
            raise AssertionError("x <{2} y with x={0} and y={1}".format(
                ExtTestCase._format_str(x), ExtTestCase._format_str(y),
                "=" if strict else ""))

    def assertLesser(self, x, y, strict=False):
        """
        Checks that ``x <= y`` (``x < y`` if *strict* is True).

        :raises AssertionError: if the comparison does not hold
        """
        if x > y or (strict and x == y):
            # strict failure means x >= y, non strict failure means x > y
            raise AssertionError("x >{2} y with x={0} and y={1}".format(
                ExtTestCase._format_str(x), ExtTestCase._format_str(y),
                "=" if strict else ""))

    def assertExists(self, name):
        """
        Checks that *name* exists.

        :raises FileNotFoundError: if *name* does not exist
        """
        if not os.path.exists(name):
            raise FileNotFoundError("Unable to find '{0}'.".format(name))

    def assertNotExists(self, name):
        """
        Checks that *name* does not exist.

        :raises FileNotFoundError: if *name* exists (this exception type
            is kept for backward compatibility)
        """
        if os.path.exists(name):
            raise FileNotFoundError(  # pragma: no cover
                "Able to find '{0}'.".format(name))

    def assertEqualDataFrame(self, d1, d2, **kwargs):
        """
        Checks that two dataframes are equal.
        Calls :epkg:`pandas:testing:assert_frame_equal`.
        """
        from pandas.testing import assert_frame_equal
        assert_frame_equal(d1, d2, **kwargs)

    def assertNotEqualDataFrame(self, d1, d2, **kwargs):
        """
        Checks that two dataframes are different.
        Calls :epkg:`pandas:testing:assert_frame_equal`.
        """
        from pandas.testing import assert_frame_equal
        try:
            assert_frame_equal(d1, d2, **kwargs)
        except AssertionError:
            return
        raise AssertionError("Two dataframes are identical.")

    def assertEqualArray(self, d1, d2, squeeze=False, **kwargs):
        """
        Checks that two arrays are equal.
        Relies on :epkg:`numpy:testing:assert_almost_equal`.

        :param d1: first array or None
        :param d2: second array or None
        :param squeeze: squeeze both arrays before comparing them
        :param kwargs: forwarded to *assert_almost_equal*
        """
        if d1 is None and d2 is None:
            return
        if d1 is None:
            raise AssertionError("d1 is None, d2 is not")
        if d2 is None:
            raise AssertionError("d1 is not None, d2 is")
        from numpy.testing import assert_almost_equal
        import numpy
        if squeeze:
            d1 = numpy.squeeze(d1)
            d2 = numpy.squeeze(d2)
        assert_almost_equal(d1, d2, **kwargs)

    def assertHasNoNan(self, a):  # pylint: disable=W0221
        """
        Checks that there is no NaN in ``a``.

        :param a: array or anything :func:`numpy.asarray` accepts
        :raises AssertionError: if *a* is None or contains NaN
        """
        if a is None:
            raise AssertionError("a is None")
        import numpy
        arr = numpy.asarray(a)
        if arr.dtype == object:
            # numpy.isnan is not defined on object arrays, check each element
            has_nan = any(map(numpy.isnan, arr.ravel()))
        else:
            has_nan = bool(numpy.isnan(arr).any())
        if has_nan:
            raise AssertionError("a has nan:\n{}".format(a))

    def assertEqualSparseArray(self, d1, d2, **kwargs):
        """
        Checks that two sparse matrices are equal. Objects exposing
        ``data``, ``row``, ``col`` (COO layout) or ``data``, ``indices``,
        ``indptr`` (CSR/CSC layout) are compared attribute by attribute.

        :param d1: first matrix or None
        :param d2: second matrix or None
        :param kwargs: forwarded to :meth:`assertEqualArray` when
            comparing the stored values (``data``)
        :raises AssertionError: if types, shapes or content differ
        :raises NotImplementedError: for any other layout
        """
        if type(d1) is not type(d2):
            raise AssertionError("d1 and d2 have different types {} != {}.".format(
                type(d1), type(d2)))
        if d1 is None and d2 is None:
            return

        def _has_all(*names):
            return all(hasattr(d, n) for d in (d1, d2) for n in names)

        if _has_all('data', 'row', 'col'):
            # COO layout
            self.assertEqual(d1.shape, d2.shape)
            self.assertEqualArray(d1.data, d2.data, **kwargs)
            self.assertEqualArray(d1.row, d2.row)
            self.assertEqualArray(d1.col, d2.col)
            return
        if _has_all('data', 'indices', 'indptr'):
            # CSR / CSC layout
            self.assertEqual(d1.shape, d2.shape)
            self.assertEqualArray(d1.data, d2.data, **kwargs)
            self.assertEqualArray(d1.indices, d2.indices)
            self.assertEqualArray(d1.indptr, d2.indptr)
            return
        raise NotImplementedError(  # pragma: no cover
            "Comparison not implemented for types {} and {}.".format(
                type(d1), type(d2)))

    def assertNotEqualArray(self, d1, d2, squeeze=False, **kwargs):
        """
        Checks that two arrays are not equal.
        Relies on :epkg:`numpy:testing:assert_almost_equal`.

        :param d1: first array or None
        :param d2: second array or None
        :param squeeze: squeeze both arrays before comparing them
        :param kwargs: forwarded to *assert_almost_equal*
        :raises AssertionError: if both are None or the arrays are equal
        """
        if d1 is None and d2 is None:
            raise AssertionError("d1 and d2 are equal to None")
        if d1 is None or d2 is None:
            return
        from numpy.testing import assert_almost_equal
        import numpy
        if squeeze:
            d1 = numpy.squeeze(d1)
            d2 = numpy.squeeze(d2)
        try:
            assert_almost_equal(d1, d2, **kwargs)
        except AssertionError:
            return
        raise AssertionError("Two arrays are identical.")

    def assertEqualNumber(self, d1, d2, **kwargs):
        """
        Checks that two numbers are equal.

        :param d1: first number (int, float, :class:`decimal.Decimal`
            or :epkg:`numpy` scalar)
        :param d2: second number, same accepted types
        :param kwargs: only ``precision`` is used, a relative tolerance
            (absolute when one number is zero); without it the numbers
            must be exactly equal
        :raises TypeError: if one input is not a number
        :raises AssertionError: if the numbers differ
        """
        from numpy import number
        if not isinstance(d1, (int, float, decimal.Decimal, number)):
            raise TypeError('d1 is not a number but {0}'.format(type(d1)))
        if not isinstance(d2, (int, float, decimal.Decimal, number)):
            raise TypeError('d2 is not a number but {0}'.format(type(d2)))
        diff = abs(float(d1 - d2))
        mi = float(min(abs(d1), abs(d2)))
        tol = kwargs.get('precision', None)
        if tol is None:
            if diff != 0:
                raise AssertionError("d1 != d2: {0} != {1}".format(d1, d2))
        else:
            if mi == 0:
                if diff > tol:  # pragma: no cover
                    raise AssertionError(
                        "d1 != d2: {0} != {1} +/- {2}".format(d1, d2, tol))
            else:
                rel = diff / mi
                if rel > tol:
                    raise AssertionError(  # pragma: no cover
                        "d1 != d2: {0} != {1} +/- {2}".format(d1, d2, tol))

    def assertRaise(self, fct, exc=None, msg=None):
        """
        Checks that function *fct* with no parameter
        raises an exception of a given type.

        @param      fct     function to test (no parameter)
        @param      exc     exception type to catch (None for all)
        @param      msg     error message to check (None for no message to check)
        """
        try:
            fct()
        except Exception as e:
            if exc is None:
                return  # pragma: no cover
            elif isinstance(e, exc):
                if msg is None:
                    return
                if msg not in str(e):
                    raise AssertionError(  # pragma: no cover
                        "Function '{0}' raise exception with wrong message '{1}' "
                        "(must contain '{2}').".format(fct, e, msg))
                return
            raise AssertionError(
                "Function '{0}' does not raise exception '{1}' but '{2}' of type "
                "'{3}'.".format(fct, exc, e, type(e)))
        raise AssertionError(  # pragma: no cover
            "Function '{0}' does not raise exception.".format(fct))

    def capture(self, fct):
        """
        Runs a function and captures standard output and error.

        @param      fct     function to run
        @return             result of *fct*, output, error
        """
        sout = StringIO()
        serr = StringIO()
        with redirect_stdout(sout):
            with redirect_stderr(serr):
                res = fct()
        return res, sout.getvalue(), serr.getvalue()

    def assertStartsWith(self, sub, whole):
        """
        Checks that string *whole* starts with *sub*.
        """
        if not whole.startswith(sub):
            if len(whole) > len(sub) * 2:
                # only the beginning of *whole* is useful in the message
                whole = whole[:len(sub) * 2]  # pragma: no cover
            raise AssertionError(
                "'{1}' does not start with '{0}'".format(sub, whole))

    def assertNotStartsWith(self, sub, whole):
        """
        Checks that string *whole* does not start with *sub*.
        """
        if whole.startswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[:len(sub) * 2]  # pragma: no cover
            raise AssertionError(
                "'{1}' starts with '{0}'".format(sub, whole))

    def assertEndsWith(self, sub, whole):
        """
        Checks that string *whole* ends with *sub*.
        """
        if not whole.endswith(sub):
            if len(whole) > len(sub) * 2:
                # only the end of *whole* is useful in the message
                whole = whole[-len(sub) * 2:]  # pragma: no cover
            raise AssertionError(
                "'{1}' does not end with '{0}'".format(sub, whole))

    def assertNotEndsWith(self, sub, whole):
        """
        Checks that string *whole* does not end with *sub*.
        """
        if whole.endswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[-len(sub) * 2:]
            raise AssertionError(
                "'{1}' ends with '{0}'".format(sub, whole))

    def assertEqual(self, a, b, msg=None):  # pylint: disable=W0221
        """
        Checks that ``a == b``. Two :epkg:`numpy` arrays are compared
        with :meth:`assertEqualArray`, two :epkg:`pandas` dataframes with
        :meth:`assertEqualDataFrame`, because their ``==`` is element-wise.

        :param a: first value
        :param b: second value
        :param msg: optional message forwarded to
            :meth:`unittest.TestCase.assertEqual`
        :raises AssertionError: if the values differ or cannot be compared
        """
        if a is None and b is not None:
            raise AssertionError("a is None, b is not")
        if a is not None and b is None:
            raise AssertionError("a is not None, b is")
        try:
            unittest.TestCase.assertEqual(self, a, b, msg)
        except ValueError as e:
            text = str(e)
            if ("The truth value of a DataFrame is ambiguous" in text or
                    "The truth value of an array with more than one "
                    "element is ambiguous." in text):
                # numpy and pandas are optional: only import what is installed
                numpy = ExtTestCase._optional_import("numpy")
                if (numpy is not None and isinstance(a, numpy.ndarray) and
                        isinstance(b, numpy.ndarray)):
                    self.assertEqualArray(a, b)
                    return
                pandas = ExtTestCase._optional_import("pandas")
                if (pandas is not None and isinstance(a, pandas.DataFrame) and
                        isinstance(b, pandas.DataFrame)):
                    self.assertEqualDataFrame(a, b)
                    return
            raise AssertionError(  # pragma: no cover
                "Unable to check equality for types {0} and {1}".format(
                    type(a), type(b))) from e

    def assertNotEqual(self, a, b, msg=None):  # pylint: disable=W0221
        """
        Checks that ``a != b``. Two :epkg:`numpy` arrays are compared
        with :meth:`assertNotEqualArray`, two :epkg:`pandas` dataframes
        with :meth:`assertNotEqualDataFrame`.

        :param a: first value
        :param b: second value
        :param msg: optional message forwarded to
            :meth:`unittest.TestCase.assertNotEqual`
        :raises AssertionError: if the values are equal
        """
        if a is None and b is None:
            raise AssertionError("a is None, b is too")  # pragma: no cover
        if a is None or b is None:
            # exactly one of them is None, they are different
            return  # pragma: no cover
        try:
            unittest.TestCase.assertNotEqual(self, a, b, msg)
        except ValueError as e:
            text = str(e)
            if ("Can only compare identically-labeled DataFrame objects" in text or
                    "The truth value of a DataFrame is ambiguous." in text or
                    "The truth value of an array with more than one "
                    "element is ambiguous." in text):
                numpy = ExtTestCase._optional_import("numpy")
                if (numpy is not None and isinstance(a, numpy.ndarray) and
                        isinstance(b, numpy.ndarray)):
                    self.assertNotEqualArray(a, b)
                    return
                pandas = ExtTestCase._optional_import("pandas")
                if (pandas is not None and isinstance(a, pandas.DataFrame) and
                        isinstance(b, pandas.DataFrame)):
                    self.assertNotEqualDataFrame(a, b)
                    return
            raise  # pragma: no cover

    def assertEqualFloat(self, a, b, precision=1e-5):
        """
        Checks that *a* and *b* are close: ``abs(a - b) <= precision``
        when one of them is zero,
        ``abs(a - b) / min(abs(a), abs(b)) <= precision`` otherwise.
        A NaN value always fails.

        :raises AssertionError: if the numbers are not close
        """
        mi = min(abs(a), abs(b))
        diff = abs(a - b)
        err = diff if mi == 0 else float(diff) / mi
        # written with "not <=" so that a NaN error fails the check
        if not err <= precision:
            raise AssertionError("{} != {} (p={})".format(a, b, precision))

    def assertCallable(self, fct):
        """
        Checks that *fct* is callable.
        """
        if not callable(fct):
            raise AssertionError("fct is not callable: {0}".format(type(fct)))

    def assertEqualDict(self, a, b):
        """
        Checks that ``a == b`` for two dictionaries, the error message
        lists added, removed and different keys.

        :raises TypeError: if *a* or *b* is not a dictionary
        :raises AssertionError: if the dictionaries differ
        """
        if not isinstance(a, dict):
            raise TypeError('a is not dict but {0}'.format(type(a)))
        if not isinstance(b, dict):
            raise TypeError('b is not dict but {0}'.format(type(b)))
        rows = []
        for key in sorted(b):
            if key not in a:
                rows.append("** Added key '{0}' in b".format(key))
            else:
                if a[key] != b[key]:
                    rows.append(
                        "** Value != for key '{0}': != id({1}) != id({2})\n==1 {3}\n==2 {4}".format(
                            key, id(a[key]), id(b[key]), a[key], b[key]))
        for key in sorted(a):
            if key not in b:
                rows.append("** Removed key '{0}' in a".format(key))
        if len(rows) > 0:
            raise AssertionError(
                "Dictionaries are different\n{0}".format('\n'.join(rows)))

    def fLOG(self, *args, **kwargs):
        """
        Prints out some information.
        @see fn fLOG.
        """
        # delayed import
        from ..loghelper import fLOG as _flog  # pragma: no cover
        _flog(*args, **kwargs)  # pragma: no cover

    @staticmethod
    def profile(fct, sort='cumulative', rootrem=None,
                return_results=False):
        """
        Profiles the execution of a function with function
        :func:`profile <pyquickhelper.pycode.profiling.profile>`.

        :param fct: function to profile
        :param sort: see :meth:`pstats.Stats.sort_stats`
        :param rootrem: root to remove in filenames
        :param return_results: return the results as well
        :return: statistics text dump

        .. versionchanged:: 1.11
            Parameter *return_results* was added.
        """
        return profile(fct, sort=sort, rootrem=rootrem,
                       return_results=return_results)

    def read_file(self, filename, mode='r', encoding="utf-8"):
        """
        Returns the content of a file.

        @param      filename        filename
        @param      encoding        encoding, ignored in binary mode
        @param      mode            reading mode
        @return                     content
        """
        self.assertExists(filename)
        if 'b' in mode:
            # open() rejects an encoding in binary mode
            encoding = None
        with open(filename, mode, encoding=encoding) as f:
            return f.read()

    def write_file(self, filename, content, mode='w', encoding='utf-8'):
        """
        Writes the content of a file.

        @param      filename        filename
        @param      content         content to write
        @param      encoding        encoding, ignored in binary mode
        @param      mode            writing mode
        @return                     value returned by ``write``
        """
        if 'b' in mode:
            # open() rejects an encoding in binary mode
            encoding = None
        with open(filename, mode, encoding=encoding) as f:
            return f.write(content)

    def assertIn(self, sub, ensemble, msg=None):  # pylint: disable=W0221,W0237
        """
        Checks that *sub* is in *ensemble*.

        @param      sub         sub set, nothing is checked if None
        @param      ensemble    full set
        @param      msg         error message
        @raises                 AssertionError
        """
        if sub is None:
            return  # pragma: no cover
        if ensemble is None:
            raise AssertionError(msg or "'text' is None")  # pragma: no cover
        if sub not in ensemble:
            raise AssertionError(  # pragma: no cover
                msg or "Unable to find '{}' in\n{}".format(
                    sub, pprint.pformat(ensemble)))

    def assertWarning(self, fct):
        """
        Returns the list of warnings raised while
        executing function *fct*.

        @param      fct         function to run
        @return                 result, list of warnings
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            r = fct()
        return r, list(w)

    def assertLogging(self, fct, logger_name, level=INFO, log_sphinx=False):
        """
        Returns the logged information in a logger defined
        by its name. The logger's handlers are temporarily replaced by a
        capturing handler and restored afterwards, even if *fct* raises.

        @param      fct             function to run
        @param      logger_name     logger name
        @param      level           level to intercept
        @param      log_sphinx      logging from :epkg:`sphinx`
        @return                     result, logged information
        """
        class MyStream:
            def __init__(self):
                self.rows = []

            def write(self, text):
                self.rows.append(text)

            def getvalue(self):
                return "\n".join(self.rows)

            def __len__(self):
                return len(self.rows)

        if log_sphinx:
            # sphinx is only needed in that case
            from sphinx.util import logging as logging_sphinx
            logger = logging_sphinx.getLogger(logger_name).logger
        else:
            logger = getLogger(logger_name)

        # iterate over a copy, removeHandler mutates logger.handlers
        hs = list(logger.handlers)
        for h in hs:
            logger.removeHandler(h)

        log_capture_string = MyStream()
        ch = StreamHandler(log_capture_string)
        ch.setLevel(level)
        logger.addHandler(ch)
        try:
            res = fct()
        finally:
            logger.removeHandler(ch)
            for h in hs:
                logger.addHandler(h)
        return res, log_capture_string.getvalue()

    @staticmethod
    def abs_path_join(filename, *args):
        """
        Returns an absolute and normalized path from this location.

        :param filename: filename, the folder which contains it
            is used as the base
        :param args: list of subpaths to the previous path
        :return: absolute and normalized path
        """
        dirname = os.path.join(os.path.dirname(filename), *args)
        return os.path.normpath(os.path.abspath(dirname))
def skipif_appveyor(msg):
    """
    Skips a unit test if it runs on :epkg:`appveyor`
    (``is_travis_or_appveyor()`` returns ``'appveyor'``).

    :param msg: reason appended to the skip message
    :return: decorator, the identity outside appveyor,
        :func:`unittest.skip` on appveyor
    """
    if is_travis_or_appveyor() == 'appveyor':
        reason = 'Test does not work on appveyor due to: ' + msg  # pragma: no cover
        return unittest.skip(reason)  # pragma: no cover
    return lambda x: x
def skipif_travis(msg):
    """
    Skips a unit test if it runs on :epkg:`travis`
    (``is_travis_or_appveyor()`` returns ``'travis'``).

    :param msg: reason appended to the skip message
    :return: decorator, the identity outside travis,
        :func:`unittest.skip` on travis
    """
    if is_travis_or_appveyor() == 'travis':
        reason = 'Test does not work on travis due to: ' + msg  # pragma: no cover
        return unittest.skip(reason)  # pragma: no cover
    return lambda x: x
def skipif_circleci(msg):
    """
    Skips a unit test if it runs on :epkg:`circleci`
    (``is_travis_or_appveyor()`` returns ``'circleci'``).

    :param msg: reason appended to the skip message
    :return: decorator, the identity outside circleci,
        :func:`unittest.skip` on circleci
    """
    if is_travis_or_appveyor() == 'circleci':
        reason = 'Test does not work on circleci due to: ' + msg  # pragma: no cover
        return unittest.skip(reason)  # pragma: no cover
    return lambda x: x
def skipif_azure(msg):
    """
    Skips a unit test if it runs on :epkg:`azure pipeline`
    (``is_travis_or_appveyor()`` returns ``'azurepipe'``).

    :param msg: reason appended to the skip message
    :return: decorator, the identity outside azure pipeline,
        :func:`unittest.skip` on azure pipeline
    """
    if is_travis_or_appveyor() == 'azurepipe':
        reason = 'Test does not work on azure pipeline due to: ' + msg  # pragma: no cover
        return unittest.skip(reason)  # pragma: no cover
    return lambda x: x  # pragma: no cover
def skipif_azure_linux(msg):
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`linux`.
    The test is skipped only when both conditions hold.

    :param msg: reason appended to the skip message
    :return: decorator, :func:`unittest.skip` on azure pipeline under
        linux, the identity otherwise
    """
    # Skip only on linux AND azure: keep the test if either is false.
    if not sys.platform.startswith('lin') or is_travis_or_appveyor() != 'azurepipe':
        return lambda x: x  # pragma: no cover
    msg = 'Test does not work on azure pipeline (linux) due to: ' + msg
    return unittest.skip(msg)
def skipif_azure_macosx(msg):
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on macosx
    (``sys.platform`` starts with ``'darwin'``). The test is skipped only
    when both conditions hold.

    :param msg: reason appended to the skip message
    :return: decorator, :func:`unittest.skip` on azure pipeline under
        macosx, the identity otherwise
    """
    # Skip only on macosx AND azure: keep the test if either is false.
    if not sys.platform.startswith('darwin') or is_travis_or_appveyor() != 'azurepipe':
        return lambda x: x
    msg = 'Test does not work on azure pipeline (macosx) due to: ' + msg
    return unittest.skip(msg)
def skipif_linux(msg):
    """
    Skips a unit test if it runs on :epkg:`linux`
    (``sys.platform`` starts with ``'lin'``).

    :param msg: reason appended to the skip message
    :return: decorator, :func:`unittest.skip` on linux,
        the identity otherwise
    """
    if not sys.platform.startswith('lin'):
        return lambda x: x
    msg = 'Test does not work on linux due to: ' + msg  # pragma: no cover
    return unittest.skip(msg)  # pragma: no cover
def skipif_vless(version, msg):
    """
    Skips a unit test if the running Python version is strictly below
    *version*.

    :param version: tuple compared with ``sys.version_info[:3]``,
        for example ``(3, 8, 0)``
    :param msg: reason appended to the skip message
    :return: decorator, :func:`unittest.skip` for older versions,
        the identity otherwise
    """
    current = sys.version_info[:3]
    if not current >= version:
        reason = 'Python {} < {}: {}'.format(
            current, version, msg)  # pragma: no cover
        return unittest.skip(reason)  # pragma: no cover
    return lambda x: x
def unittest_require_at_least(mod, version, msg=""):
    """
    Skips a unit test if the version of one module
    is not at least the provided version.

    @param      mod         module (the module must have an attribute ``__version__``)
    @param      version     expected version or more recent
    @param      msg         message
    @raises                 RuntimeError if *mod* has no ``__version__``

    .. versionadded:: 1.9
    """
    installed = getattr(mod, '__version__', None)
    if installed is None:
        raise RuntimeError(  # pragma: no cover
            "Module '{}' has no version.".format(mod))
    recent_enough = compare_module_version(installed, version) >= 0
    if recent_enough:
        return lambda x: x
    reason = "Module '{}' is older than '{}' (= '{}'). {}".format(
        mod, version, installed, msg)
    return unittest.skip(reason)
def ignore_warnings(warns):
    """
    Decorator ignoring warnings of the given categories while the
    decorated test method runs.

    @param      warns   warning category (or tuple of categories)
                        to ignore, it cannot be None
    @return             decorator
    @raises             AssertionError when the decorator is applied
                        with *warns* equal to None
    """
    import functools

    def wrapper(fct):
        if warns is None:
            raise AssertionError(  # pragma: no cover
                "warns cannot be None for '{}'.".format(fct))

        # wraps keeps the test name and docstring (used by test runners)
        @functools.wraps(fct)
        def call_f(self, *args, **kwargs):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", warns)
                return fct(self, *args, **kwargs)
        return call_f
    return wrapper
def testlog(logtype="print"):
    """
    Logs before and after a function is called.

    :param logtype: kind of logging, only `'print'` is implemented
        and None to disable it
    :return: decorator, the decorated function returns the result
        of the wrapped one
    :raises ValueError: for any other *logtype*
    """
    import functools

    if logtype is None:
        def logfct(arg):  # pylint: disable=W0613
            pass
    elif logtype == 'print':
        logfct = print
    else:
        raise ValueError("Unexpected logtype %r." % (logtype,))

    def wrapper(fct):
        # wraps keeps the test name and docstring (used by test runners)
        @functools.wraps(fct)
        def call_f(self, *args, **kwargs):
            logfct('START %r' % fct.__name__)
            res = fct(self, *args, **kwargs)
            logfct('DONE- %r' % fct.__name__)
            return res
        return call_f
    return wrapper
def assert_almost_equal_detailed(expected, value, **kwargs):
    """
    Calls :epkg:`numpy:testing:assert_almost_equal`.
    Adds more information to the exception message by comparing
    the first dimension row by row.

    :param expected: expected value
    :param value: value
    :param kwargs: forwarded to *assert_almost_equal*
    :raises: AssertionError
    """
    from numpy.testing import assert_almost_equal
    import numpy
    try:
        assert_almost_equal(expected, value, **kwargs)
    except AssertionError as exc:
        exp = numpy.asarray(expected)
        val = numpy.asarray(value)
        # a row by row comparison needs the same non-empty first dimension
        if exp.ndim == 0 or val.ndim == 0 or exp.shape[0] != val.shape[0]:
            raise  # pragma: no cover
        rows = ['INNER EXCEPTION:', str(exc), '------', 'ROWS BY ROWS']
        for i, (r1, r2) in enumerate(zip(exp, val)):
            try:
                assert_almost_equal(r1, r2, **kwargs)
            except AssertionError as row_exc:
                rows.append('----------------------')
                rows.append("ISSUE WITH ROW {}/{}:0 {}".format(
                    i, exp.shape[0], str(row_exc)))
                if len(rows) > 10:
                    break  # pragma: no cover
        raise AssertionError("\n".join(rows)) from exc