Coverage for aftercovid/models/_base_sir.py: 99%
284 statements
« prev ^ index » next coverage.py v7.1.0, created at 2024-05-09 03:09 +0200
« prev ^ index » next coverage.py v7.1.0, created at 2024-05-09 03:09 +0200
1# coding: utf-8
2"""
3Common functions for :epkg:`SIR` models.
4"""
5import numpy
6from sympy import symbols, Symbol, latex, lambdify
7import sympy.printing as printing
8from sympy.parsing.sympy_parser import (
9 parse_expr, standard_transformations, implicit_application)
10from ._sympy_helper import enumerate_traverse
11from ._base_sir_sim import BaseSIRSimulation
12from ._base_sir_estimation import BaseSIREstimation
class BaseSIR(BaseSIRSimulation, BaseSIREstimation):
    """
    Base model for :epkg:`SIR` models.

    :param p: list of `[(name, initial value or None, comment)]` (parameters)
    :param q: list of `[(name, initial value or None, comment)]` (quantities)
    :param c: list of `[(name, initial value or None, comment)]` (constants),
        None means no constant
    :param eq: equations, dictionary `{quantity name: expression}` where
        every expression is a string parsed with :epkg:`sympy` giving
        the derivative of the quantity over time *t*, or None
    :raises TypeError: if *p*, *q*, *c* are not lists or *eq* is not
        a dictionary
    :raises RuntimeError: if an equation cannot be parsed
    """
    # Attributes saved by __getstate__ and duplicated by copy.
    _pickled_atts = [
        '_p', '_q', '_c', '_eq', '_val_p', '_val_q', '_val_c',
        '_val_ind', '_val_len', '_syms']

    def __init__(self, p, q, c=None, eq=None, **kwargs):
        if c is None:
            # The signature defaults c to None: accept it as "no constant"
            # instead of rejecting the default value.
            c = []
        if not isinstance(p, list):
            raise TypeError("p must be a list of tuple.")
        if not isinstance(q, list):
            raise TypeError("q must be a list of tuple.")
        if not isinstance(c, list):
            raise TypeError("c must be a list of tuple.")
        if eq is not None and not isinstance(eq, dict):
            raise TypeError("eq must be a dictionary.")
        self._p = p
        self._q = q
        self._c = c
        # One sympy symbol per name plus the time 't'. Always built
        # (even without equations) because copy and __getstate__ read
        # every attribute listed in _pickled_atts, including _syms.
        locs = {'t': symbols('t', cls=Symbol)}
        for v in self._p + self._c + self._q:
            locs[v[0]] = symbols(v[0], cls=Symbol)
        self._syms = locs
        if eq is not None:
            tr = standard_transformations + (implicit_application, )
            self._eq = {}
            for k, v in eq.items():
                try:
                    self._eq[k] = parse_expr(v, locs, transformations=tr)
                except (TypeError, ValueError) as e:  # pragma: no cover
                    raise RuntimeError(
                        f"Unable to parse '{v}'.") from e
        else:
            self._eq = None
        if len(kwargs) != 0:
            raise NotImplementedError(  # pragma: no cover
                "Not implemented.")
        self._init()

    def copy(self):
        """
        Returns a copy of the model.

        The value arrays are duplicated so that updating the copy
        (for example with ``copy[name] = value``) leaves the original
        model unchanged. Equations are lambdified again on the copy.
        """
        inst = self.__class__.__new__(self.__class__)
        for k in BaseSIR._pickled_atts:
            value = getattr(self, k)
            if isinstance(value, numpy.ndarray):
                # __setitem__ modifies these arrays in place, sharing
                # them would couple the copy and the original.
                value = value.copy()
            setattr(inst, k, value)
        if getattr(inst, '_eq', None) is not None:
            inst._init_lambda_()
        return inst

    def __getstate__(self):
        '''
        Returns the pickled data.
        '''
        return {k: getattr(self, k) for k in BaseSIR._pickled_atts}

    def __setstate__(self, state):
        '''
        Sets the pickled data.
        '''
        for k, v in state.items():
            setattr(self, k, v)
        # Lambdified functions are not pickled, they are rebuilt.
        if hasattr(self, '_eq') and self._eq is not None:
            self._init_lambda_()

    def _init(self):
        """
        Starts from the initial values.

        Builds the value arrays ``_val_p``, ``_val_q``, ``_val_c`` from
        the initial values (None becomes 10000 for a name ``'N'``, 0
        otherwise) and the offsets ``_val_ind`` of quantities, parameters
        and constants inside the vector returned by :meth:`vect`.
        """
        def _def_(name, v):
            if v is not None:
                return v
            if name == 'N':  # pragma: no cover
                return 10000.
            return 0.  # pragma: no cover

        self._val_p = numpy.array(
            [_def_(v[0], v[1]) for v in self._p], dtype=numpy.float64)
        self._val_q = numpy.array(
            [_def_(v[0], v[1]) for v in self._q], dtype=numpy.float64)
        self._val_c = numpy.array(
            [_def_(v[0], v[1]) for v in self._c], dtype=numpy.float64)
        self._val_len = (len(self._val_p) + len(self._val_q) +
                         len(self._val_c))
        # Boundaries of the blocks quantities | parameters | constants.
        self._val_ind = numpy.array([
            0, len(self._val_q), len(self._val_q) + len(self._val_p),
            len(self._val_q) + len(self._val_p) + len(self._val_c)])

        if hasattr(self, '_eq') and self._eq is not None:
            self._init_lambda_()

    def _init_lambda_(self):
        """
        Lambdifies every equation and checks the lambdified function
        agrees with the :epkg:`sympy` evaluation at the current values
        (relative error at most 1e-8). Results are stored in ``_leq``
        (by name) and ``_leqa`` (in the order of the quantities).

        :raises ValueError: if both evaluations disagree
        """
        self._leq = {}
        for k, v in self._eq.items():
            fct = self._lambdify_(k, v)
            eval1 = float(self.evalf_eq(v))
            eval2 = self.evalf_leq(k)
            scale = max(abs(eval1), abs(eval2))
            # Absolute relative error: a signed error would let any large
            # negative discrepancy through. Both values null means equal.
            err = 0. if scale == 0 else abs(eval2 - eval1) / scale
            if err > 1e-8:
                raise ValueError(  # pragma: no cover
                    "Lambdification failed for function '{}': {} "
                    "({} ({}) != {} ({}), error={})".format(
                        k, v, eval1, type(eval1), eval2, type(eval2), err))
            self._leq[k] = fct
        self._leqa = [self._leq[_[0]] for _ in self._q]

    def get_index(self, name):
        '''
        Returns where a name is stored: a tuple ``(kind, position)``
        where *kind* is ``'p'`` (parameter), ``'q'`` (quantity)
        or ``'c'`` (constant).

        :raises ValueError: if the name is unknown
        '''
        for i, v in enumerate(self._p):
            if v[0] == name:
                return 'p', i
        for i, v in enumerate(self._q):
            if v[0] == name:
                return 'q', i
        for i, v in enumerate(self._c):
            if v[0] == name:
                return 'c', i
        raise ValueError(f"Unable to find name '{name}'.")

    def __setitem__(self, name, value):
        """
        Updates a value whether it is a parameter, a quantity
        or a constant.

        :param name: name
        :param value: new value
        """
        p, pos = self.get_index(name)
        if p == 'p':
            self._val_p[pos] = value
        elif p == 'q':
            self._val_q[pos] = value
        elif p == 'c':
            self._val_c[pos] = value

    def __getitem__(self, name):
        """
        Retrieves a value whether it is a parameter, a quantity
        or a constant.

        :param name: name
        :return: value
        """
        p, pos = self.get_index(name)
        if p == 'p':
            return self._val_p[pos]
        if p == 'q':
            return self._val_q[pos]
        if p == 'c':
            return self._val_c[pos]

    @property
    def names(self):
        'Returns the sorted list of all names.'
        return sorted(self.param_names + self.quantity_names +
                      self.cst_names)

    @property
    def quantity_names(self):
        'Returns the list of quantities names (unsorted).'
        return [v[0] for v in self._q]

    @property
    def param_names(self):
        'Returns the list of parameters names (unsorted).'
        return [v[0] for v in self._p]

    @property
    def params_dict(self):
        'Returns the parameters values in a dictionary `{name: value}`.'
        return {k: self[k] for k in self.param_names}

    @property
    def cst_names(self):
        'Returns the list of constants names (unsorted).'
        return [v[0] for v in self._c]

    @property
    def vect_names(self):
        'Returns the names of the values returned by :meth:`vect`.'
        return ([v[0] for v in self._q] + [v[0] for v in self._p] +
                [v[0] for v in self._c] + ['t'])

    def vect(self, t=0, out=None, derivative=False):
        """
        Returns all values as a vector.

        The vector contains the quantities, the parameters, the constants
        and finally *t*, in the order of :attr:`vect_names`.
        With *derivative=True*, it is followed by the derivatives of
        the quantities (in the order of the quantities).

        :param t: time *t*
        :param out: alternative output array in which to place the
            result. It must have the same shape as the expected output.
        :param derivative: appends the derivatives after the values
        :return: values, or values followed by derivatives
        """
        if derivative:
            if out is None:
                out = numpy.empty((self._val_len + 1 + self._val_ind[1], ),
                                  dtype=numpy.float64)
            self.vect(t=t, out=out)
            for i, v in enumerate(self._leqa):
                # i - nq is negative: derivatives fill the last nq slots.
                out[i - self._val_ind[1]] = v(*out[:self._val_len + 1])
        else:
            if out is None:
                out = numpy.empty((self._val_len + 1, ), dtype=numpy.float64)
            out[:self._val_ind[1]] = self._val_q
            out[self._val_ind[1]:self._val_ind[2]] = self._val_p
            out[self._val_ind[2]:self._val_ind[3]] = self._val_c
            out[self._val_ind[3]] = t
        return out

    @property
    def P(self):
        '''
        Returns the parameters as `[(name, value, comment)]`.
        '''
        return [(a[0], b, a[2]) for a, b in zip(self._p, self._val_p)]

    @property
    def Q(self):
        '''
        Returns the quantities as `[(name, value, comment)]`.
        '''
        return [(a[0], b, a[2]) for a, b in zip(self._q, self._val_q)]

    @property
    def C(self):
        '''
        Returns the constants as `[(name, value, comment)]`.
        '''
        return [(a[0], b, a[2]) for a, b in zip(self._c, self._val_c)]

    def update(self, **values):
        """Updates values."""
        for k, v in values.items():
            self[k] = v

    def get(self):
        """Retrieves all values."""
        return {n: self[n] for n in self.names}

    def to_rst(self):
        '''
        Returns a string formatted in RST.
        '''
        rows = [
            f'*{self.__class__.__name__}*',
            '',
            '*Quantities*',
            ''
        ]
        for name, _, doc in self._q:
            rows.append(f'* *{name}*: {doc}')
        rows.extend(['', '*Constants*', ''])
        for name, _, doc in self._c:
            rows.append(f'* *{name}*: {doc}')
        rows.extend(['', '*Parameters*', ''])
        for name, _, doc in self._p:
            rows.append(f'* *{name}*: {doc}')
        if self._eq is not None:
            rows.extend(['', '*Equations*', '', '.. math::',
                         '', ' \\begin{array}{l}'])
            for i, (k, v) in enumerate(sorted(self._eq.items())):
                line = "".join(
                    [" ", "\\frac{d%s}{dt} = " % k, printing.latex(v)])
                if i < len(self._eq) - 1:
                    line += " \\\\"
                rows.append(line)
            rows.append(" \\end{array}")

        return '\n'.join(rows)

    def _repr_html_(self):
        '''
        Returns a string formatted in HTML.
        '''
        rows = [
            f'<p><b>{self.__class__.__name__}</b></p>',
            '',
            '<p><i>Quantities</i></p>',
            '',
            '<ul>'
        ]
        for name, _, doc in self._q:
            rows.append(f'<li><i>{name}</i>: {doc}</li>')
        rows.extend(['</ul>', '', '<p><i>Constants</i></p>', '', '<ul>'])
        for name, _, doc in self._c:
            rows.append(f'<li><i>{name}</i>: {doc}</li>')
        rows.extend(['</ul>', '', '<p><i>Parameters</i></p>', '', '<ul>'])
        for name, _, doc in self._p:
            rows.append(f'<li><i>{name}</i>: {doc}</li>')
        # Close the parameters list whether or not equations follow.
        rows.append('</ul>')
        if self._eq is not None:
            rows.extend(['', '<p><i>Equations</i></p>', '', '<ul>'])
            for i, (k, v) in enumerate(sorted(self._eq.items())):
                lats = "\\frac{d%s}{dt} = %s" % (k, printing.latex(v))
                lat = latex(lats, mode='equation')
                line = "".join(["<li>", str(lat), '</li>'])
                rows.append(line)
            rows.append("</ul>")

        return '\n'.join(rows)

    def enumerate_edges(self):
        """
        Enumerates the list of quantities contributing
        to others. It ignores constants.

        Yields tuples ``(sign, quantity, derived quantity, parameter)``
        where *quantity* appears next to *parameter* in the equation of
        *derived quantity*. For an equation where no such quantity other
        than the derived one is found, it also yields ``(0, '?', name, '?')``.
        """
        if self._eq is not None:
            params = set(_[0] for _ in self.P)
            quants = set(_[0] for _ in self.Q)
            for k, v in sorted(self._eq.items()):
                n2 = k
                n = []
                for dobj in enumerate_traverse(v):
                    term = dobj['e']
                    if not hasattr(term, 'name'):
                        continue
                    if term.name not in params:
                        continue
                    parent = dobj['p']
                    others = list(
                        _['e'] for _ in enumerate_traverse(parent))
                    for o in others:
                        if hasattr(o, 'name') and o.name in quants:
                            sign = self.eqsign(n2, o.name)
                            yield (sign, o.name, n2, term.name)
                            if o.name != n2:
                                n.append((sign, o.name, n2, term.name))
                if len(n) == 0:
                    yield (0, '?', n2, '?')

    def to_dot(self, verbose=False, full=False):
        """
        Produces a graph in :epkg:`DOT` format.

        :param verbose: adds comments to nodes and values to edges
        :param full: keeps self edges and edges with a negative sign
        :return: string
        """
        rows = ['digraph{']

        pattern = (' {name} [label="{name}\\n{doc}" shape=record];'
                   if verbose else
                   ' {name} [label="{name}"];')
        for name, _, doc in self._q:
            rows.append(pattern.format(name=name, doc=doc))
        for name, _, doc in self._c:
            rows.append(pattern.format(name=name, doc=doc))

        if self._eq is not None:
            pattern = (
                ' {n1} -> {n2} [label="{sg}{name}\\nvalue={v:1.2g}"];'
                if verbose else ' {n1} -> {n2} [label="{sg}{name}"];')
            # Sorted: iterating a set of strings depends on hash
            # randomization and would make the output non deterministic.
            for sg, a, b, name in sorted(set(self.enumerate_edges())):
                if not full and (a == b or sg < 0):
                    continue
                if name == '?':
                    rows.append(  # pragma: no cover
                        pattern.format(n1=a, n2=b, name=name,
                                       v=numpy.nan, sg='0'))
                    continue  # pragma: no cover
                value = self[name]
                stsg = '' if sg > 0 else '-'
                rows.append(
                    pattern.format(n1=a, n2=b, name=name, v=value, sg=stsg))

        rows.append('}')
        return '\n'.join(rows)

    @property
    def cst_param(self):
        '''
        Returns a dictionary with the constant and the parameters.
        '''
        res = {}
        for k, v in zip(self._c, self._val_c):
            res[k[0]] = v
        for k, v in zip(self._p, self._val_p):
            res[k[0]] = v
        return res

    def _eval_subs(self, t=0):
        """
        Returns the substitution dictionary `{sympy symbol: value}`
        for constants, parameters, quantities and time *t*.
        """
        svalues = self._eval_cache()
        svalues[self._syms['t']] = t
        for k, v in zip(self._q, self._val_q):
            svalues[self._syms[k[0]]] = v
        return svalues

    def evalf_eq(self, eq, t=0):
        """
        Evaluates an :epkg:`sympy` expression with the current values.
        """
        return eq.evalf(subs=self._eval_subs(t))

    def evalf_leq(self, name, t=0):
        """
        Evaluates a lambdified expression.

        :param name: name of the lambdified expresion
        :param t: t values
        :return: evaluation
        """
        leq = self._lambdified_(name)
        if leq is None:
            raise RuntimeError(  # pragma: no cover
                f"Equation '{name}' was not lambdified.")
        return leq(*self.vect(t))

    def _eval_cache(self):
        'Returns `{sympy symbol: value}` for constants and parameters.'
        values = self.cst_param
        svalues = {self._syms[k]: v for k, v in values.items()}
        return svalues

    def _lambdify_(self, name, eq, derivative=False):
        """
        Lambdifies an expression and caches in member `_lambda_`.

        The function takes the values in the order of :attr:`vect_names`
        (quantities, parameters, constants, t), followed by the
        derivatives of the quantities if *derivative* is True.
        If *name* is already cached, *eq* is ignored.
        """
        if not hasattr(self, '_lambda_'):
            self._lambda_ = {}
        if name not in self._lambda_:
            names = (self.quantity_names + self.param_names +
                     self.cst_names + ['t'])
            sym = [Symbol(n) for n in names]
            if derivative:
                sym += [Symbol('d' + n) for n in self.quantity_names]
            self._lambda_[name] = {
                'names': names,
                'symbols': sym,
                'eq': eq,
                'pos': {n: i for i, n in enumerate(names)},
            }
            ll = lambdify(sym, eq, 'numpy')
            self._lambda_[name]['la'] = ll
        return self._lambda_[name]['la']

    def _lambdified_(self, name):
        """
        Returns the lambdified expression of name *name*
        or None if it does not exist.
        """
        if hasattr(self, '_lambda_'):
            r = self._lambda_.get(name, None)
            if r is not None:
                return r['la']
        return None

    def _eval_diff_sympy(self, t=0):
        """
        Evaluates derivatives with :epkg:`sympy` (not the lambdified
        functions). Returns a dictionary `{quantity name: value}`.
        """
        svalues = self._eval_subs(t)
        return {k: v.evalf(subs=svalues) for k, v in self._eq.items()}

    def eval_diff(self, t=0):
        """
        Evaluates derivatives with the lambdified functions.
        Returns a dictionary `{quantity name: value}`.
        """
        x = self.vect(t=t)
        return {k: v(*x) for k, v in self._leq.items()}