# Notebook on nbviewer (original) (raw)

def plot_grid(idata, data, sd_h, sd_w, mean_h, mean_w):
    """Create plots like figures 17.3 and 17.4 in the book.

    The figure has a large scatter plot of the observed data overlaid with
    regression lines drawn from the posterior, followed by two rows of
    three panels: posterior of the intercept, posterior of the slope, a
    slope/intercept scatter, posterior of the scale, posterior of
    log10(nu), and a log10(nu)/scale scatter.

    Parameters
    ----------
    idata
        Object whose ``posterior`` group provides the variables ``"beta0"``,
        ``"beta1"``, ``"sigma"`` and ``"nu"``.
    data
        Object exposing ``height`` and ``weight`` (plotted as x and y).
    sd_h, sd_w, mean_h, mean_w
        Scaling constants used to map the parameters back to the original
        scale. Presumably the standard deviations and means of height and
        weight that were used to standardize the data -- confirm with the
        caller.

    Returns
    -------
    The matplotlib figure that holds all seven axes.

    Notes
    -----
    Relies on the module-level names ``color`` and ``f_dict`` (line/marker
    color and axis-label font settings) as well as ``plt``, ``gridspec``,
    ``np``, ``pd``, ``az`` and ``pm``.
    """
    fig = plt.figure(figsize=(13, 13))

    # Define gridspec: the data plot spans the two top rows, the six
    # parameter panels fill the two bottom rows.
    gs = gridspec.GridSpec(4, 6)
    ax1 = plt.subplot(gs[:2, 1:5])
    ax2 = plt.subplot(gs[2, :2])
    ax3 = plt.subplot(gs[2, 2:4])
    ax4 = plt.subplot(gs[2, 4:6])
    ax5 = plt.subplot(gs[3, :2])
    ax6 = plt.subplot(gs[3, 2:4])
    ax7 = plt.subplot(gs[3, 4:6])

    # Scatter plot of the observed data
    ax1.scatter(
        data.height,
        data.weight,
        s=40,
        linewidths=1,
        facecolor="none",
        edgecolor="k",
        zorder=10,
    )
    ax1.set_xlabel("height", fontdict=f_dict)
    ax1.set_ylabel("weight", fontdict=f_dict)
    ax1.set(xlim=(0, 80), ylim=(-350, 250))

    # Convert parameters to original scale
    beta0 = (
        idata.posterior["beta0"] * sd_w
        + mean_w
        - idata.posterior["beta1"] * mean_h * sd_w / sd_h
    )
    beta1 = idata.posterior["beta1"] * (sd_w / sd_h)
    sigma = idata.posterior["sigma"] * sd_w
    B = pd.DataFrame(
        {"beta0": beta0.values.flatten(), "beta1": beta1.values.flatten()}
    )

    # Credible regression lines from posterior.
    # The HDI bounds are used unrounded: rounding them to whole numbers can
    # shift the interval or collapse it entirely (e.g. a slope interval
    # narrower than one unit), which would leave no draws to sample from.
    b0_hdi = az.hdi(B["beta0"].to_numpy(), hdi_prob=0.95)
    b1_hdi = az.hdi(B["beta1"].to_numpy(), hdi_prob=0.95)
    B_hdi = B[B["beta0"].between(*b0_hdi) & B["beta1"].between(*b1_hdi)]
    x_line = np.arange(0, data.height.max() * 1.05)
    # Skip the overlay when no draw lies inside both intervals;
    # np.random.randint raises ValueError for an empty range.
    if not B_hdi.empty:
        for i in np.random.randint(0, len(B_hdi), 30):
            ax1.plot(
                x_line,
                B_hdi.iloc[i, 0] + B_hdi.iloc[i, 1] * x_line,
                c=color,
                alpha=0.6,
                zorder=0,
            )

    # intercept
    pm.plot_posterior(beta0, point_estimate="mode", ax=ax2, color=color)
    ax2.set_xlabel(r"$\beta_0$", fontdict=f_dict)
    ax2.set_title("Intercept", fontdict={"weight": "bold"})

    # slope
    pm.plot_posterior(
        beta1, point_estimate="mode", ax=ax3, color=color, ref_val=0
    )
    ax3.set_xlabel(r"$\beta_1$", fontdict=f_dict)
    ax3.set_title("Slope", fontdict={"weight": "bold"})

    # scatter plot beta1, beta0
    ax4.scatter(beta1, beta0, edgecolor=color, facecolor="none", alpha=0.6)
    ax4.set_xlabel(r"$\beta_1$", fontdict=f_dict)
    ax4.set_ylabel(r"$\beta_0$", fontdict=f_dict)

    # scale
    pm.plot_posterior(sigma, point_estimate="mode", ax=ax5, color=color)
    ax5.set_xlabel(r"$\sigma$", fontdict=f_dict)
    ax5.set_title("Scale", fontdict={"weight": "bold"})

    # normality (nu is shown on a log10 scale)
    log_nu = np.log10(idata.posterior["nu"])
    pm.plot_posterior(log_nu, point_estimate="mode", ax=ax6, color=color)
    ax6.set_xlabel(r"log10($\nu$)", fontdict=f_dict)
    ax6.set_title("Normality", fontdict={"weight": "bold"})

    # scatter plot normality, sigma
    ax7.scatter(
        log_nu,
        sigma,
        edgecolor=color,
        facecolor="none",
        alpha=0.6,
    )
    ax7.set_xlabel(r"log10($\nu$)", fontdict=f_dict)
    ax7.set_ylabel(r"$\sigma$", fontdict=f_dict)

    plt.tight_layout()

    return fig