Jupyter Notebook Viewer (original) (raw)

Predictions for Model 3.

# Draw posterior-predictive samples of the "obs" site for Model 3.
# `rng_key` is split so the sampling call consumes `rng_key_` and the
# rebound `rng_key` remains available for subsequent splits.
rng_key, rng_key_ = random.split(rng_key)
predictions_3 = Predictive(model, samples_3)(
    rng_key_, marriage=dset.MarriageScaled.values, age=dset.AgeScaled.values
)["obs"]

# One y-axis position per observation. Derived from the predictions instead
# of hard-coding 50 so the plots keep working if the dataset size changes.
# Assumes the last axis of `predictions_3` indexes observations (the code
# below averages over axis 0 and plots the result against `y`).
y = jnp.arange(predictions_3.shape[-1])

# Two side-by-side panels: ax[0] and ax[1] are filled in by the plotting
# code that follows.
fig, ax = plt.subplots(1, 2, figsize=(12, 16))

# Residuals: observed scaled divorce rate minus each predictive draw,
# summarised by their mean over axis 0 and their 0.9 HPDI.
residuals_3 = dset.DivorceScaled.values - predictions_3
residuals_mean = jnp.mean(residuals_3, axis=0)
residuals_hpdi = hpdi(residuals_3, 0.9)

# Same summaries for the predictions themselves.
pred_mean = jnp.mean(predictions_3, axis=0)
pred_hpdi = hpdi(predictions_3, 0.9)

# Ordering by ascending mean residual.
idx = jnp.argsort(residuals_mean)

Plot posterior predictive

# Left panel: posterior-predictive mean with its 90% HPDI for every row,
# ordered by `idx`, alongside the observed values in gray.

# Dashed vertical reference line at x = 0, sized from `y` rather than a
# hard-coded 50.
ax[0].plot(jnp.zeros(len(y)), y, "--")

# Asymmetric error bars: row 0 is the distance from the mean down to the
# lower HPDI bound, row 1 the distance up to the upper bound. Using only the
# upper distance for both sides would misplace the lower edge whenever the
# interval is not symmetric about the mean.
# Assumes `pred_hpdi[0]` is the lower bound and `pred_hpdi[1]` the upper one.
pred_xerr = jnp.stack(
    [
        pred_mean[idx] - pred_hpdi[0, idx],
        pred_hpdi[1, idx] - pred_mean[idx],
    ]
)

# NOTE(review): the xlabel below says "red", but no color is passed to
# errorbar, so the markers use the default color cycle -- confirm the
# intended color.
ax[0].errorbar(
    pred_mean[idx],
    y,
    xerr=pred_xerr,
    marker="o",
    ms=5,
    mew=4,
    ls="none",
    alpha=0.8,
)
ax[0].plot(dset.DivorceScaled.values[idx], y, marker="o", ls="none", color="gray")
ax[0].set(
    xlabel="Posterior Predictive (red) vs. Actuals (gray)",
    ylabel="State",
    title="Posterior Predictive with 90% CI",
)
ax[0].set_yticks(y)
ax[0].set_yticklabels(dset.Loc.values[idx], fontsize=10)

Plot residuals

# Right panel: mean residual with its 90% HPDI for every row, ordered by
# `idx`.
#
# `residuals_3`, `residuals_mean` and `residuals_hpdi` were already computed
# before `idx` was derived from them and have not been modified since, so
# they are reused here instead of being recomputed (which would repeat the
# HPDI computation for an identical result).

# Upper half-width of the interval; kept under its original name in case
# later code refers to it.
err = residuals_hpdi[1] - residuals_mean

# Asymmetric error bars: row 0 is the distance from the mean down to the
# lower HPDI bound, row 1 the distance up to the upper bound.
# Assumes `residuals_hpdi[0]` is the lower bound and `[1]` the upper one.
residuals_xerr = jnp.stack(
    [
        residuals_mean[idx] - residuals_hpdi[0, idx],
        residuals_hpdi[1, idx] - residuals_mean[idx],
    ]
)

# Dashed vertical reference line at x = 0, sized from `y` rather than a
# hard-coded 50.
ax[1].plot(jnp.zeros(len(y)), y, "--")
ax[1].errorbar(
    residuals_mean[idx],
    y,
    xerr=residuals_xerr,
    marker="o",
    ms=5,
    mew=4,
    ls="none",
    alpha=0.8,
)
ax[1].set(xlabel="Residuals", ylabel="State", title="Residuals with 90% CI")
ax[1].set_yticks(y)
# The trailing semicolon suppresses the notebook's echo of the return value.
ax[1].set_yticklabels(dset.Loc.values[idx], fontsize=10);