Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# coding: utf-8
2"""
3Common methods about simulation for :epkg:`SIR` models.
4"""
5import numpy
6from sympy import Symbol, diff as sympy_diff
class BaseSIRSimulation:
    """
    Common methods about simulation for :epkg:`SIR` models.

    This class is a mixin: it relies on members that are not defined here
    and must be provided by the class it is combined with
    (``_q``, ``_eq``, ``_leq``, ``vect``, ``update``, ``quantity_names``,
    ``_lambdified_``, ``_lambdify_``, ``evalf_eq``, ``evalf_leq``).
    """

    def eqsign(self, eqname, name):
        """
        Returns the sign of the derivative of equation *eqname*
        with respect to symbol *name* (the equation describing itself
        a derivative, this is a second order derivative).

        The derivative is computed symbolically and lambdified the first
        time it is requested, under the name ``'d<eqname>/d<name>'``.

        :param eqname: equation name
        :param name: symbol name
        :return: ``1`` if the derivative evaluated by ``evalf_leq``
            is positive or null, ``-1`` otherwise
        :raises ValueError: if the lambdified derivative and the symbolic
            one disagree by more than ``1e-5`` when first built
        """
        leqname = 'd' + eqname + '/d' + name
        if self._lambdified_(leqname) is None:
            # First request: differentiate symbolically, register the
            # lambdified version, then check both evaluations agree.
            df = sympy_diff(self._eq[eqname], Symbol(name))
            self._lambdify_(leqname, df)
            eval1 = self.evalf_eq(df)
            eval2 = self.evalf_leq(leqname)
            if abs(eval1 - eval2) > 1e-5:
                raise ValueError(  # pragma: no cover
                    "Lambdification failed for derivative '{}' by '{}' "
                    "({} != {})".format(eqname, name, eval1, eval2))
        ev = self.evalf_leq(leqname)
        return 1 if ev >= 0 else -1

    def iterate(self, n=10, t=0, derivatives=False):
        """
        Evaluates the quantities for *n* iterations.

        At every step, the current values are yielded, then the value of
        each lambdified equation (``_leq``) is added to the quantity of the
        same name and the model is refreshed with ``update``.
        Assuming ``_leq`` holds time derivatives, this is an explicit
        Euler step with a unit time step.

        :param n: number of iterations
        :param t: first *t*
        :param derivatives: returns the derivative as well
        :return: iterator on dictionaries ``{quantity name: value}``,
            or on tuples ``(values, derivatives)`` if *derivatives*
            is True
        """
        for i in range(t, t + n):
            x = self.vect(t=i)
            derivs = {k: v(*x) for k, v in self._leq.items()}
            # _q entries are sequences whose first element is the name;
            # zip stops at len(_q) so extra entries of x are ignored.
            vals = {k[0]: v for k, v in zip(self._q, x)}

            if derivatives:
                yield vals.copy(), derivs
            else:
                yield vals.copy()

            for k, v in derivs.items():
                vals[k] += v
            self.update(**vals)

    def iterate2array(self, n=10, t=0, derivatives=False):
        """
        Evaluates the quantities for *n* iterations.
        Returns matrices with one row per iteration and one column per
        quantity, following the order of ``quantity_names``. Missing
        values are filled with NaN.

        :param n: number of iterations
        :param t: first *t*
        :param derivatives: returns the derivative as well
        :return: quantities or (quantities, derivatives)
            if *derivatives* is True, both as ``float32`` arrays
        """
        names = self.quantity_names
        res = list(self.iterate(n=n, t=t, derivatives=derivatives))
        qu = numpy.zeros((len(res), len(names)), dtype=numpy.float32)
        if derivatives:
            de = numpy.zeros((len(res), len(names)), dtype=numpy.float32)
            for i, (values, derivs) in enumerate(res):
                qu[i, :] = [values.get(name, numpy.nan) for name in names]
                de[i, :] = [derivs.get(name, numpy.nan) for name in names]
            return qu, de
        for i, values in enumerate(res):
            qu[i, :] = [values.get(name, numpy.nan) for name in names]
        return qu

    def R0(self, t=0):
        '''Returns R0 coefficient.'''
        raise NotImplementedError()