Step-by-step guide — ICoMo Toolbox

1.0.3 (original) (raw)

Step-by-step guide#

This guide shows step by step how a compartmental model is built and how its parameters can be inferred.

First some imports:

import importlib import multiprocessing import os import time

if importlib.metadata.version("jax") >= "0.4.33": os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false" # starting with jax 0.4.33, there is a speed regression for small ODEs, this flag # restores the old behavior. See https://github.com/jax-ml/jax/discussions/24501 and # https://github.com/patrick-kidger/diffrax/issues/518

import arviz as az import jax import jax.numpy as jnp import jaxopt import matplotlib.pyplot as plt import numpy as np import numpyro import optax import pymc as pm import pytensor.tensor as pt from tqdm.auto import tqdm

import icomo

# Set the host device count of JAX, as otherwise parallel sampling with JAX doesn't
# work (see https://github.com/jax-ml/jax/issues/1408).
# One host device is exposed per CPU core.
n_host_devices = multiprocessing.cpu_count()
numpyro.set_host_device_count(n_host_devices)

Let us first define a system of ordinary differential equations (ODEs). As an example, we will make an SEIR model with an Erlang distributed latent period.

# ODEs should always be defined with t, y and args: t is the time variable, y the
# variables proper and args the other arguments.
# t is a float, y and args any python structure (dict, tuple, lists) containing arrays,
# or for args potentially a function.


def Erlang_SEIR(t, y, args):
    """SEIR model whose latent period is Erlang distributed.

    Parameters
    ----------
    t : float
        Time at which the derivatives are evaluated.
    y : dict
        Compartments "S", "Es", "I" and "R". ``y["Es"]`` is an array of exposed
        sub-compartments; its length sets the number of stages of the Erlang kernel.
    args : dict
        "N" (population size), "beta_t_func" (callable returning the infection rate
        at time t), "rate_latent" and "rate_infectious".

    Returns
    -------
    dict
        The derivatives, with the same structure as ``y``.
    """
    # y and the args are passed as dictionary in this example to facilitate
    # keeping track of the meaning of the variables
    N = args["N"]

    # beta is here time-dependent, so we pass a function which will be evaluated at t
    beta_t = args["beta_t_func"](t)

    # New infections per unit time, beta(t) * I * S / N; they leave S and enter Es
    infections = beta_t * y["I"] * y["S"] / N

    # Latent period, use a helper function
    dEs, outflow = icomo.erlang_kernel(
        inflow=infections,
        comp=y["Es"],  # y["Es"] is assumed to be an array of compartments/variables to
        # be able to model the kernel
        rate=args["rate_latent"],
    )

    # Return the derivatives; the dictionary has the same structure as y
    return {
        "S": -infections,
        "Es": dEs,
        "I": outflow - args["rate_infectious"] * y["I"],
        "R": args["rate_infectious"] * y["I"],
    }

def SEIR(t, y, args):
    """Classic SEIR model with one exposed compartment (exponential latent period).

    Parameters
    ----------
    t : float
        Time at which the derivatives are evaluated.
    y : dict
        Compartments "S", "Es", "I" and "R"; here ``y["Es"]`` is treated as a
        single compartment.
    args : dict
        "N" (population size), "beta_t_func" (callable returning the infection rate
        at time t), "rate_latent" and "rate_infectious".

    Returns
    -------
    dict
        The derivatives, with the same structure as ``y``.
    """
    # y and the args are passed as dictionary in this example to facilitate
    # keeping track of the meaning of the variables
    N = args["N"]

    # beta is here time-dependent, so we pass a function which will be evaluated at t
    beta_t = args["beta_t_func"](t)

    infections = beta_t * y["I"] * y["S"] / N  # S -> Es
    end_of_latency = args["rate_latent"] * y["Es"]  # Es -> I
    recoveries = args["rate_infectious"] * y["I"]  # I -> R

    return {
        "S": -infections,
        "Es": infections - end_of_latency,
        "I": end_of_latency - recoveries,
        "R": recoveries,
    }

The first function above (Erlang_SEIR) defines the following set of equations:

\[\begin{split} \begin{align} \frac{\mathrm dS(t)}{\mathrm dt} &= -\tfrac{1}{N}\beta(t)SI,\\ \frac{\mathrm dE^{(1)}(t)}{\mathrm dt} &= \tfrac{1}{N}\beta(t) SI-\mathrm{rate_{latent}} \cdot n \cdot E^{(1)},\\ \frac{\mathrm dE^{(2)}(t)}{\mathrm dt} &= \mathrm{rate_{latent}} \cdot n \cdot E^{(1)} -\mathrm{rate_{latent}} \cdot n \cdot E^{(2)},\\ \frac{\mathrm dE^{(3)}(t)}{\mathrm dt} &= \mathrm{rate_{latent}} \cdot n \cdot E^{(2)} -\mathrm{rate_{latent}} \cdot n \cdot E^{(3)},\\ \frac{\mathrm dI(t)}{\mathrm dt} &= \mathrm{rate_{latent}} \cdot n \cdot E^{(3)} -\mathrm{rate_{infectious}} \cdot I,\\ \frac{\mathrm dR(t)}{\mathrm dt} &=\mathrm{rate_{infectious}} \cdot I. \end{align}\end{split}\]

Here \(n = 3\) is the number of exposed compartments \(E\). In the function above, the exposed compartments are stored as an array in y["Es"]. The three equations \(\frac{\mathrm dE^{(1 \dots 3)}(t)}{\mathrm dt}\) are defined in the icomo.erlang_kernel function for convenience. This function also sets \(n\) dynamically to the length of y["Es"]. Take care to only use jax operations inside the differential equation, for instance jnp.cos instead of np.cos: the function will be compiled with jax later, and other operations would lead to an error.

Integrating ODEs#

Given some starting conditions and parameters we can integrate our system of ODEs:

len_sim = 365  # days
num_points = len_sim

# First set the time variables
t_out = np.linspace(0, len_sim, num_points)  # timepoints of the output
# timepoints at which the time-dependent variable is defined (roughly every 2 weeks)
t_beta = np.linspace(0, len_sim, num_points // 14)

# Define parameters
N = 1e5  # population
R0 = 1.5
duration_latent = 3  # the average in days
duration_infectious = 4  # the average in days
beta0 = R0 / duration_infectious  # infection rate

# Set parameters for ODE
args = {
    "N": N,
    "rate_latent": 1 / duration_latent,
    "rate_infectious": 1 / duration_infectious,
}

# Define starting conditions
y0 = {
    "Es": np.array([100, 100, 100]),  # multiple values to model the Erlang kernel
    "I": 300,
    "R": 0,
}
# It is important that the Es for the Erlang-kernel is a numpy or jax array, not a
# list, as otherwise it is interpreted as separate compartments.

# Susceptible compartment is N minus all other compartments (including I, so that
# the compartments sum to N)
y0["S"] = N - np.sum(y0["Es"]) - y0["I"] - y0["R"]

# We want to have a time-dependent beta. As the system of ODEs requires a value of beta
# for any t, we define an interpolation function that can be evaluated at arbitrary t.
beta_t_func = icomo.interpolate_func(
    ts_in=t_beta,
    values=beta0 * np.ones(len(t_beta)),  # assume constant beta for now
)
args["beta_t_func"] = beta_t_func

# Solve differential equation, passing the interpolation function and other arguments
# in the dictionary args, which can be nested at will.
solution = icomo.diffeqsolve(
    ts_out=t_out,
    y0=y0,
    args=args,
    ODE=Erlang_SEIR,
)

f = plt.figure(figsize=(4, 3))

# The output is stored in the ys attribute of the solution object, which has the same
# structure as y0
infectious = solution.ys["I"]
plt.plot(infectious)
plt.xlabel("Time")
plt.ylabel("Infectious compartment III");

_images/example_5_0.png

We here use an ODE integrator built with icomo.diffeqsolve, which wraps the ODE solver from Diffrax. In general, the args and variables y0 passed to the integrator, and subsequently to the ODE function, can be any nested list, tuple and/or dict, which is also called a pytree. The output will have the same structure as y0, except that its variables will receive a prepended time dimension.

Simplify the construction of ODEs with icomo.CompModel#

The system of ODEs can be vastly simplified. Notice how the population subtracted from one compartment is always added exactly to another compartment. Furthermore, the subtracted amount is always proportional to the population currently in the compartment. Making use of these two properties, one can specify such a system by a number of flows that start and end in different compartments and are parametrized by rates, which are multiplied by the starting compartment. Such a specification is possible with the class icomo.CompModel:

comp_model = icomo.CompModel()


def Erlang_SEIR_v2(t, y, args):
    """Erlang SEIR model expressed as flows of the module-level ``comp_model``.

    The current state ``y`` is assigned to ``comp_model``, the S -> Es -> I -> R
    flows are registered, and the accumulated derivatives ``comp_model.dy`` are
    returned. ``args`` needs the keys "N", "beta_t_func", "rate_latent" and
    "rate_infectious".
    """
    beta_t_func = args["beta_t_func"]
    comp_model.y = y
    comp_model.flow(
        start_comp="S",
        end_comp="Es",
        rate=y["I"] / args["N"] * beta_t_func(t),
        label="beta(t) * I/N",  # label of the graph edge
        # One has to specify that "Es" refers to an array of compartments
        end_comp_is_erlang=True,
    )
    comp_model.erlang_flow("Es", "I", args["rate_latent"], label="rate_latent (erlang)")
    comp_model.flow("I", "R", args["rate_infectious"], label="rate_infectious")
    return comp_model.dy

# Check whether the resulting dynamics are the same as the previous version
solution = icomo.diffeqsolve(ts_out=t_out, y0=y0, args=args, ODE=Erlang_SEIR_v2)

f = plt.figure(figsize=(4, 3))
plt.plot(solution.ys["I"])
plt.xlabel("Time")
plt.ylabel("Infectious compartment II");

_images/example_7_0.png

Take notice that CompModel assumes that the variables/compartments y are saved in a dictionary. It also assumes that, for flows that follow an Erlang kernel, the corresponding compartments are stored in an array under the key of start_comp.

Another advantage is that a CompModel object can display its graph with the view_graph() method, which helps to verify that the parametrisation is correct. In this case, specify the label keyword when adding flows; the labels are displayed on the edges of the graph.

_images/example_10_0.svg

Another advantage is the a CompModel object can display the graph with the view_graph() method, which helps the verify the parametrisation was correct. Specify in this case the label keyword when adding flows, these are displayed on the edges of the graph.

Fitting/optimizing the model using Data#

We might want to optimize some parameters of the model. This is readily achieved, as the model can easily be differentiated using jax.grad.

Concretely, we will optimize the beta_t variable, which we defined every 14 days. The icomo.interpolate_func class uses a cubic interpolation to obtain a continuous approximation in between. For completeness, we will also need to optimize the initial number of infected I0.

# Data are the number of COVID-19 cases in England during 2022
data = [236576, 188567, 118275, 181068, 229901, 275647, 222174, 172980, 120287, 93615, 95790, 132959, 115245, 103219, 96412, 86325, 77950, 98097, 131514, 118857, 111760, 100324, 91730, 81733, 102365, 127364, 113201, 108668, 96200, 84026, 73512, 87702, 105051, 93949, 88418, 76644, 63264, 52305, 60552, 76483, 69661, 64339, 52293, 44559, 37529, 41755, 54745, 52502, 51572, 44716, 33241, 34647, 37720, 45131, 39188, 36476, 30631, 27837, 25089, 31360, 44649, 44795, 46179, 43761, 41408, 38759, 49026, 68520, 70111, 73531, 70510, 67030, 62078, 75738, 99832, 93708, 91752, 82853, 75285, 67417, 81054, 109286, 99095, 94185, 82905, 72430, 61456, 68336, 92857, 80126, 73643, 54791, 47462, 37592, 42179, 53727, 49101, 44750, 38709, 32988, 27885, 30952, 37371, 33771, 31685, 26638, 21615, 19924, 20144, 25223, 25772, 21457, 18464, 15703, 13210, 14715, 17108, 14560, 13055, 11608, 10012, 8354, 9034, 11907, 13318, 11672, 10221, 8954, 7517, 9062, 10702, 9748, 8725, 7642, 6942, 6120, 7973, 9417, 8286, 7638, 6820, 5821, 5019, 5997, 7313, 6679, 6293, 5753, 5409, 4988, 5680, 7115, 7093, 7011, 6209, 6635, 7492, 9815, 11747, 11411, 11784, 11319, 10784, 10077, 12693, 15272, 15141, 14837, 14225, 13919, 13623, 16997, 19942, 20896, 20873, 19590, 18349, 16800, 20778, 25104, 25041, 24601, 24240, 23650, 21961, 27780, 33704, 31415, 29698, 26454, 24611, 21246, 25096, 29826, 26946, 23856, 21100, 18359, 15017, 17236, 18718, 16670, 16169, 14158, 12378, 9934, 11275, 13299, 11725, 10544, 9329, 8508, 7325, 8426, 9835, 9170, 8353, 7296, 6548, 5420, 6547, 8068, 7309, 6748, 6193, 5339, 4495, 5363, 6424, 5559, 5032, 4564, 4012, 3502, 4085, 5224, 4778, 4554, 3796, 3451, 3112, 3439, 4662, 5226, 4615, 4350, 3808, 3417, 4288, 5125, 4578, 4235, 3747, 3553, 3268, 4347, 5307, 5147, 4973, 4634, 4134, 3521, 4339, 6166, 7828, 7553, 6815, 6485, 5719, 7240, 8617, 8251, 8414, 7895, 7615, 7065, 9204, 11692, 10965, 10231, 9389, 8950, 7053, 8558, 10497, 9461, 8768, 8134, 7970, 6442, 7780, 10002, 8233, 7598, 6576, 6297, 4968, 5523, 6807, 5665, 5449, 4855, 4401, 3583, 4271, 5130, 4549, 4014, 3678, 3107, 2665, 3267, 4321, 3829, 3576, 3205, 2893, 2433, 2912, 3923, 3534, 3260, 3037, 2773, 2430, 2890, 3937, 3805, 3616, 3329, 2902, 2551, 3448, 4639, 4386, 3980, 3505, 3596, 2748, 3786, 5325, 5241, 5195, 4537, 4216, 3570, 4631, 6986, 7111, 6533, 6094, 6084, 4743, 5772, 9073, 8631, 7802, 6877, 5741, 4132, 3694, 4950, 6605, 8419, 7718]  # noqa: E501 # fmt: skip
data = np.array(data)
N_England = 50e6  # population size used for the England fits

# Setup a function that simulates the spread given a time-dependent beta_t and the
# initial infected I0


def simulation(args_optimization):
    """Simulate the Erlang SEIR model for the population of England.

    Parameters
    ----------
    args_optimization : dict
        "beta_t": values of beta at the interpolation knots ``t_beta``;
        "I0": initial number of infected. Half of it is put into the infectious
        compartment, the other half is split evenly over the 3 exposed compartments.

    Returns
    -------
    dict
        The ODE solution at ``t_out`` (keys "S", "Es", "I", "R") plus
        "beta_t_interpolated", the interpolated beta evaluated at ``t_out``.
    """
    beta_t = args_optimization["beta_t"]

    # Spread out the infected over the exposed and infectious compartments
    I0 = args_optimization["I0"] / 2
    Es_0 = args_optimization["I0"] / 6 * jnp.ones(3)

    # Update starting conditions
    y0 = {
        "Es": Es_0,
        "I": I0,
        "R": 0,
    }
    y0["S"] = N_England - jnp.sum(y0["Es"]) - y0["I"] - y0["R"]

    # beta is now time-dependent
    beta_t_func = icomo.interpolate_func(ts_in=t_beta, values=beta_t)

    # Build a local copy of the ODE arguments instead of mutating the global ``args``.
    # Under jax.jit, beta_t_func holds traced values, and storing it in a global dict
    # would keep those tracers alive outside the traced function (a leaked tracer).
    sim_args = {**args, "N": N_England, "beta_t_func": beta_t_func}

    output = icomo.diffeqsolve(
        ts_out=t_out,
        y0=y0,
        args=sim_args,
        ODE=Erlang_SEIR_v2,
    ).ys

    # Save also the beta at the output times
    output["beta_t_interpolated"] = beta_t_func(t_out)

    return output

# Define our loss function
@jax.jit
def loss(args_optimization):
    """Weighted mean squared error between modelled and reported new cases.

    Each squared daily residual is divided by the number of modelled new
    infections plus one; the +1 keeps the weight finite when the model predicts
    no new infections.
    """
    output = simulation(args_optimization)

    # The decrease of the susceptible population between output times is the
    # number of newly infected
    new_infected = -jnp.diff(output["S"])

    # Use the mean squared difference as our loss, weighted by the number of new
    # infected.
    # Notice the use of jax.numpy instead of numpy for the calculation. This is
    # necessary, as it allows the auto-differentiation of our loss function.
    squared_residuals = (new_infected - data[1:]) ** 2
    return jnp.mean(squared_residuals / (new_infected + 1))

# Define initial parameters
init_params = {
    "beta_t": beta0 * np.ones_like(t_beta),
    "I0": np.array(float(data[0] * duration_infectious)),
}
# np.arrays are used here, as ScipyMinimize converts the parameters to np.arrays
# anyway, and as we want to compile it beforehand for runtime measurements, it would
# otherwise elicit a recompilation when called by jaxopt.ScipyMinimize

start_time = time.perf_counter()

# Differentiate our loss and compile it by calling it once. JAX dispatches work
# asynchronously, so wait for the result before stopping the clock.
value_and_grad_loss = jax.jit(jax.value_and_grad(loss))
jax.block_until_ready(value_and_grad_loss(init_params))
print(f"Compilation duration: {(time.perf_counter() - start_time):.1f}s")

# Solve our minimization problem
solver = jaxopt.ScipyMinimize(
    fun=value_and_grad_loss, value_and_grad=True, method="L-BFGS-B", jit=False
)

start_time = time.perf_counter()
res = solver.run(init_params)
end_time = time.perf_counter()
print(f"Minimization duration: {(end_time - start_time):.3f}s")

# iter_num is the number of optimizer iterations (not of function evaluations)
print(
    f"Number of iterations: {res.state.iter_num}\n"
    f"Final cost: {res.state.fun_val:.3f}"
)

Compilation duration: 13.6s Minimization duration: 0.946s Number of function evaluations: 180 Final cost: 1233.409

We decided here to use jaxopt.ScipyMinimize as the minimization function, which wraps the scipy.minimize function. Its advantage over scipy.minimize is that we can use pytrees as optimization variables instead of flat arrays. Otherwise scipy.minimize works equally well.

In order to speed up the fitting procedure, we compile our loss function using jax.jit, a just-in-time (jit) compiler. This improves the runtime of the minimization by about a factor of 20.

Notice the use of jax.numpy inside the loss function but not outside. It is a good habit to only use jax.numpy for calculations that need to be automatically differentiated, and the usual numpy otherwise. This avoids unnecessary tracing/graph-building of such variables. Using jax.numpy outside can also lead to errors if functions still depend on traced variables outside the current scope.

Let us check the results:

# Simulate the fitted model once and reuse the result for both panels
sim_fit = simulation(res.params)

f, axes = plt.subplots(2, 1, figsize=(4, 5), height_ratios=(1, 2.5))

# Upper panel: reproduction number R_t = beta_t * duration_infectious
plt.sca(axes[0])
plt.plot(
    t_out,
    sim_fit["beta_t_interpolated"] * duration_infectious,
    color="tab:blue",
    label="model",
    lw=2,
)
plt.ylabel("Reproduction\nnumber R_t")
plt.xlim(t_out[0], t_out[-1])
plt.axhline(1, color="lightgray", ls="--")

# Lower panel: reported cases vs. modelled new infections
plt.sca(axes[1])
plt.plot(t_out, data, color="gray", ls="", marker="d", ms=3, label="data")
plt.plot(
    t_out[1:],
    -np.diff(sim_fit["S"]),
    color="tab:blue",
    label="model",
    lw=2,
)
plt.xlabel("Time")
plt.ylabel("Cases")
plt.legend()
plt.xlim(t_out[0], t_out[-1]);

_images/example_14_0.png

Fitting using Adam#

For high-dimensional optimization problems that are significantly underdetermined, it might be advantageous to use a gradient descent algorithm instead of L-BFGS. This is not the case for this system, but we show it here as an example using optax:

start_learning_rate = 5e-2
schedule = optax.exponential_decay(
    init_value=start_learning_rate,
    transition_steps=1000,
    decay_rate=1 / 2,
    transition_begin=50,
    staircase=False,
    end_value=None,
)
optimizer = optax.adam(learning_rate=schedule)

# Initialize parameters of the model + optimizer.
n_iterations = 2000
opt_state = optimizer.init(init_params)
losses = []
params_adam = init_params
for i in (pbar := tqdm(range(n_iterations))):
    func_val, grads = value_and_grad_loss(params_adam)
    if i % 10 == 0:
        # Only refresh the progress-bar text every 10 iterations
        pbar.set_description(f"Loss {func_val:.5f}")
    # Record the loss at every iteration so that the x-axis below is the iteration
    losses.append(func_val)
    updates, opt_state = optimizer.update(grads, opt_state)
    params_adam = optax.apply_updates(params_adam, updates)

f = plt.figure(figsize=(3, 2))
plt.plot(losses)
plt.ylim(1e3, 1e4)
plt.xlabel("iteration")
plt.ylabel("loss");

_images/example_16_1.png

We obtain similar results:

# Simulate the Adam-fitted model once and reuse the result for both panels
sim_adam = simulation(params_adam)

f, axes = plt.subplots(2, 1, figsize=(4, 5), height_ratios=(1, 2.5))

# Upper panel: reproduction number R_t = beta_t * duration_infectious
plt.sca(axes[0])
plt.plot(
    t_out,
    sim_adam["beta_t_interpolated"] * duration_infectious,
    color="tab:blue",
    label="model",
    lw=2,
)
plt.ylabel("Reproduction\nnumber R_t")
plt.xlim(t_out[0], t_out[-1])
plt.axhline(1, color="lightgray", ls="--")

# Lower panel: reported cases vs. modelled new infections
plt.sca(axes[1])
plt.plot(t_out, data, color="gray", ls="", marker="d", ms=3, label="data")
plt.plot(
    t_out[1:],
    -np.diff(sim_adam["S"]),
    color="tab:blue",
    label="model",
    lw=2,
)
plt.xlabel("Time")
plt.ylabel("Cases")
plt.legend()
plt.xlim(t_out[0], t_out[-1]);

_images/example_18_0.png

Bayesian analysis#

With a fitting procedure, one does not obtain good error estimates of the fitted parameters. A Bayesian model helps to estimate the credible intervals of the parameters of interest. Let us make such a model for our system of equations.

The central part is the modelling of the infection rate beta_t. In a Bayesian spirit, we assume that the differences between subsequent knots of the spline interpolation follow a hierarchical model: we assume that the scale of the changes in infectiousness is similar across the changes. The equations for beta_t are therefore:

\[\begin{split}\begin{align} \sigma_\beta &\sim \mathrm{HalfCauchy}\left(0.2\right),\\ \Delta \beta_i &\sim \mathcal{N}\left(0, \sigma_\beta\right), \\ \beta_k &= \beta_0 \cdot \exp \left(\sum_i^{k} \Delta \beta_i\right), \end{align}\end{split}\]

where \(\beta_k\) is the value at the k-th knot of the cubic spline interpolation. Let us define the model:

# reduce the length of the simulation for runtime reasons
t_out_bayes = np.arange(100)
data_bayes = data[t_out_bayes]
# NOTE(review): t_solve_ODE_bayes is not used anywhere below in this guide.
t_solve_ODE_bayes = np.linspace(t_out_bayes[0], t_out_bayes[-1], len(t_out_bayes) // 2)
t_beta_bayes = np.linspace(t_out_bayes[0], t_out_bayes[-1], len(t_out_bayes) // 14)

with pm.Model(coords={"time": t_out_bayes, "t_beta": t_beta_bayes}) as model:
    # We also allow the other rates of the compartments to vary
    duration_latent_var = pm.LogNormal(
        "duration_latent", mu=np.log(duration_latent), sigma=0.1
    )
    duration_infectious_var = pm.LogNormal(
        "duration_infectious", mu=np.log(duration_infectious), sigma=0.3
    )

    # Construct beta_t = beta_0 * exp(cumulative sum of hierarchical log-differences)
    R0_var = pm.LogNormal("R0", np.log(1), 1)
    beta_0_var = R0_var / duration_infectious_var
    beta_t_var = beta_0_var * pt.exp(
        pt.cumsum(
            icomo.experimental.hierarchical_priors("beta_t_log_diff", dims=("t_beta",))
        )
    )  # The hierarchical priors implementation is in the experimental module: the API
    # might change in future.

    # Set the other parameters and initial conditions
    args_var = {
        "N": N_England,
        "rate_latent": 1 / duration_latent_var,
        "rate_infectious": 1 / duration_infectious_var,
    }
    infections_0_var = pm.LogNormal(
        "infections_0", mu=np.log(data_bayes[0] * duration_infectious), sigma=2
    )

    # Spread the initial infections over the exposed and infectious compartments in
    # the same way as `simulation` does: half goes to I, the other half is split
    # evenly over the 3 Es compartments, so that the compartments sum to infections_0.
    y0_var = {
        "Es": infections_0_var / 6 * np.ones(3),
        "I": infections_0_var / 2,
        "R": 0,
    }
    y0_var["S"] = N_England - pt.sum(y0_var["Es"]) - y0_var["I"] - y0_var["R"]

    # Define the interpolation function. For use with pymc/pytensor, the input and
    # output has to be transformed as JAX is used inside the interpolate_func,
    # that's why we transform the integration function using jax2pytensor
    beta_t_func = icomo.jax2pytensor(icomo.interpolate_func)(
        ts_in=t_beta_bayes, values=beta_t_var
    )
    args_var["beta_t_func"] = beta_t_func

    # Integrate the differential equation, transforming the function with jax2pytensor
    output = icomo.jax2pytensor(icomo.diffeqsolve)(
        ts_out=t_out_bayes,
        y0=y0_var,
        args=args_var,
        ODE=Erlang_SEIR,
    ).ys

    pm.Deterministic("I", output["I"])
    new_cases = -pt.diff(output["S"])
    pm.Deterministic("new_cases", new_cases)

    # And define our likelihood
    sigma_error = pm.HalfCauchy("sigma_error", beta=1)
    pm.StudentT(
        "cases_observed",
        nu=4,
        mu=new_cases,
        sigma=sigma_error * pt.sqrt(new_cases + 1),
        observed=data_bayes[1:],
    )

    # We also want to save the interpolated beta_t, this works as icomo.jax2pytensor
    # also wrapped the function beta_t_func, such that it produces pytensor arrays
    # if called
    beta_t_interp = beta_t_func(t_out_bayes)
    pm.Deterministic("beta_t_interp", beta_t_interp)

And then we sample from it. We use the numpyro sampler: it uses JAX, which makes it more efficient here because our differential equation solver is written using JAX. The normal PyMC sampler also works. It would compile the whole model to C, except our ODE solver, which would still run using JAX.

# Sample with the JAX-based numpyro NUTS sampler, then report convergence diagnostics
trace = pm.sample(
    model=model,
    tune=500,
    draws=500,
    chains=4,
    nuts_sampler="numpyro",
    target_accept=0.6,
)

warnings = pm.stats.convergence.run_convergence_checks(trace, model=model)
pm.stats.convergence.log_warnings(warnings)

rhat_max = max(az.rhat(trace).max().values())
print(f"Maximal R-hat value: {rhat_max:.3f}")

The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details

Maximal R-hat value: 1.016

Notice how much longer the sampling takes compared to the simple fitting of the dynamics. It is also recommended to let it run for longer, to make sure the estimated posterior distribution converged. Let us plot the inferred parameters:

f, axes = plt.subplots(2, 1, figsize=(4, 5), height_ratios=(1, 2.5))

# Upper panel: posterior of the reproduction number R_t = beta_t * duration_infectious
plt.sca(axes[0])
# Merge all leading (sample) dimensions so that each row is one posterior sample
beta_t_post = (
    trace.posterior["beta_t_interp"].to_numpy().reshape((-1, len(t_out_bayes)))
)
duration_infectious_post = trace.posterior["duration_infectious"].to_numpy().flatten()
R_t_post = beta_t_post * duration_infectious_post[:, None]
R_t_lower, R_t_upper = np.percentile(R_t_post, q=(2.5, 97.5), axis=0)

plt.plot(t_out_bayes, np.median(R_t_post, axis=0), color="tab:blue", alpha=0.3)
plt.fill_between(t_out_bayes, R_t_lower, R_t_upper, color="tab:blue", alpha=0.3)
plt.ylabel("Reproduction\nnumber R_t")
plt.xlim(t_out_bayes[0], t_out_bayes[-1])
plt.axhline([1], color="lightgray", ls="--")

# Lower panel: posterior of the new cases compared with the data
plt.sca(axes[1])
new_cases_post = (
    trace.posterior["new_cases"].to_numpy().reshape((-1, len(t_out_bayes) - 1))
)
cases_lower, cases_upper = np.percentile(new_cases_post, q=(2.5, 97.5), axis=0)

plt.plot(
    t_out_bayes[1:], np.median(new_cases_post, axis=0), color="tab:blue", alpha=0.3
)
plt.fill_between(
    t_out_bayes[1:],
    cases_lower,
    cases_upper,
    color="tab:blue",
    alpha=0.3,
    label="Model (95% CI)",
)
plt.plot(t_out_bayes, data_bayes, marker="d", color="black", ls="", ms=3, label="Data")
plt.ylabel("Cases")
plt.xlabel("Time")
plt.xlim(t_out_bayes[0], t_out_bayes[-1])
plt.legend();

_images/example_24_0.png

def _lognormal_prior_density(median, sigma, x):
    """Evaluate the density of LogNormal(mu=log(median), sigma) at the points x."""
    prior = pm.LogNormal.dist(np.log(median), sigma)
    return np.exp(pm.logp(prior, x).eval())


f, axes = plt.subplots(1, 2, figsize=(5, 2.2))

# Left panel: latent period, posterior histogram vs. prior density
x = np.linspace(1, 4, 100)
plt.sca(axes[0])
plt.hist(
    trace.posterior["duration_latent"].data.flatten(),
    bins=30,
    density=True,
    label="Posterior",
    alpha=0.5,
)
plt.plot(
    x, _lognormal_prior_density(duration_latent, 0.1, x), color="gray", label="Prior"
)
plt.xlim(1.5, 4.5)
plt.xlabel("Duration latent\nperiod")
plt.ylabel("Density")

# Right panel: infectious period, posterior histogram vs. prior density
plt.sca(axes[1])
x = np.linspace(0, 19, 100)
plt.hist(
    trace.posterior["duration_infectious"].data.flatten(),
    bins=30,
    density=True,
    label="Posterior",
    alpha=0.5,
)
plt.plot(
    x,
    _lognormal_prior_density(duration_infectious, 0.3, x),
    color="gray",
    label="Prior",
)
plt.xlim(0, 19)
plt.xlabel("Duration infectious\nperiod")
plt.legend()
plt.tight_layout();

_images/example_25_0.png