Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# coding: utf-8 

2""" 

3Common methods about simulation for :epkg:`SIR` models. 

4""" 

5import numpy 

6from sympy import Symbol, diff as sympy_diff 

7 

8 

class BaseSIRSimulation:
    """
    Common methods about simulation for :epkg:`SIR` models.

    This class is meant to be mixed into a model class. The methods below
    rely on the following members being provided elsewhere:

    * ``_eq``: mapping equation name -> :epkg:`sympy` expression,
    * ``_lambdified_(name)``: cached lambdified function for *name*,
      or None if it was not lambdified yet,
    * ``_lambdify_(name, expr)``: lambdifies and caches *expr*,
    * ``evalf_eq(expr)`` / ``evalf_leq(name)``: evaluate an expression /
      a lambdified function at the current state,
    * ``vect(t=...)``: current values (quantities first) as a sequence,
    * ``_leq``: mapping equation name -> callable taking ``*vect()``,
    * ``_q``: sequence of quantity descriptions whose first item is the
      quantity name,
    * ``update(**values)``: sets new quantity values,
    * ``quantity_names``: list of quantity names.
    """

    def eqsign(self, eqname, name):
        """
        Returns the sign of the partial derivative of equation *eqname*
        with respect to symbol *name* (a second derivative when the
        equation is itself a time derivative), evaluated at the
        current state.

        The derivative is computed symbolically and lambdified on the
        first request, then reused from the cache on later calls.

        :param eqname: equation name (key of ``self._eq``)
        :param name: symbol name
        :return: ``1`` if the evaluated derivative is >= 0, ``-1`` otherwise
        :raises ValueError: if the lambdified derivative disagrees with
            the symbolic evaluation (only checked on the first request)
        """
        leqname = 'd' + eqname + '/d' + name
        if self._lambdified_(leqname) is None:
            # First request: differentiate symbolically, cache the
            # lambdified function, then cross-check both evaluations.
            df = sympy_diff(self._eq[eqname], Symbol(name))
            self._lambdify_(leqname, df)
            eval1 = self.evalf_eq(df)
            ev = self.evalf_leq(leqname)
            if abs(eval1 - ev) > 1e-5:
                raise ValueError(  # pragma: no cover
                    "Lambdification failed for derivative '{}' by '{}' "
                    "({} != {})".format(eqname, name, eval1, ev))
        else:
            ev = self.evalf_leq(leqname)
        return 1 if ev >= 0 else -1

    def iterate(self, n=10, t=0, derivatives=False):
        """
        Evaluates the quantities for *n* iterations, starting at time *t*.

        This is a generator: each item describes the state *before* the
        step is applied. Once an item has been consumed and the generator
        is resumed, the state is advanced by one explicit step
        (``value += derivative``) and stored through ``self.update``, so
        iterating this generator mutates the model.

        :param n: number of iterations
        :param t: first *t*
        :param derivatives: if True, yields ``(quantities, derivatives)``
            pairs instead of the quantities alone
        :return: iterator on dictionaries ``{quantity name: value}``
            (or on pairs of dictionaries if *derivatives* is True)
        """
        for i in range(t, t + n):
            x = self.vect(t=i)
            # Derivative of every equation at the current state.
            deltas = {k: v(*x) for k, v in self._leq.items()}
            # zip stops at len(self._q): only the leading values of x,
            # matched with the quantity names, are kept.
            vals = {k[0]: v for k, v in zip(self._q, x)}

            if derivatives:
                yield vals.copy(), deltas
            else:
                yield vals.copy()

            # Explicit Euler step with a unit time step.
            for k, v in deltas.items():
                vals[k] += v
            self.update(**vals)

    def iterate2array(self, n=10, t=0, derivatives=False,
                      dtype=numpy.float32):
        """
        Evaluates the quantities for *n* iterations.
        Returns matrices with one row per iteration and one column per
        name in ``self.quantity_names``. A name missing from an
        iteration's dictionary is stored as NaN.

        :param n: number of iterations
        :param t: first *t*
        :param derivatives: returns the derivative as well
        :param dtype: dtype of the returned matrices; the default
            ``float32`` may round large populations, so use
            ``numpy.float64`` when full precision matters
        :return: quantities or (quantities, derivatives)
            if *derivatives* is True
        """
        names = list(self.quantity_names)
        res = list(self.iterate(n=n, t=t, derivatives=derivatives))

        def _fill(rows):
            # Converts a list of {name: value} dictionaries into a matrix
            # aligned on *names*.
            mat = numpy.zeros((len(rows), len(names)), dtype=dtype)
            for i, row in enumerate(rows):
                for j, col in enumerate(names):
                    mat[i, j] = row.get(col, numpy.nan)
            return mat

        if derivatives:
            return (_fill([quant for quant, _ in res]),
                    _fill([deriv for _, deriv in res]))
        return _fill(res)

    def R0(self, t=0):
        '''
        Returns R0 coefficient.

        Must be overridden by the concrete model.

        :param t: time at which the coefficient is computed
        :raises NotImplementedError: always, in this base class
        '''
        raise NotImplementedError()