# Notebook on nbviewer (original) (raw)
def plot_grid(
    idata,
    data,
    sd_h,
    sd_w,
    mean_h,
    mean_w,
    xlim=(0, 80),
    ylim=(-350, 250),
    n_lines=30,
):
    """Create a grid of posterior plots like figures 17.3 and 17.4 in the book.

    The top panel overlays a sample of credible regression lines on the
    observed height/weight scatter. The six panels below show the posterior
    of the intercept, slope, scale and log10(nu), plus two joint scatters
    (slope vs. intercept and log10(nu) vs. scale).

    Parameters
    ----------
    idata : arviz.InferenceData
        Fitted model. ``idata.posterior`` must contain ``beta0``, ``beta1``,
        ``sigma`` and ``nu``. ``beta0``/``beta1``/``sigma`` are presumably on
        the standardized (z-score) scale defined by the means and standard
        deviations below; the conversion here assumes exactly that.
    data : pandas.DataFrame
        Observed data with ``height`` and ``weight`` columns.
    sd_h, sd_w, mean_h, mean_w : float
        Standard deviation and mean of height and weight, used to convert
        the parameters back to the original units.
    xlim, ylim : tuple of float, optional
        Axis limits of the top scatter panel; the defaults reproduce the book.
    n_lines : int, optional
        Maximum number of credible regression lines to draw (default 30).

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing all seven panels.

    Notes
    -----
    Styling relies on the module-level ``color`` and ``f_dict`` objects.
    """
    fig = plt.figure(figsize=(13, 13))

    # 4x6 grid: a large scatter panel on top, two rows of three panels below.
    gs = gridspec.GridSpec(4, 6, figure=fig)
    ax1 = fig.add_subplot(gs[:2, 1:5])
    ax2 = fig.add_subplot(gs[2, :2])
    ax3 = fig.add_subplot(gs[2, 2:4])
    ax4 = fig.add_subplot(gs[2, 4:6])
    ax5 = fig.add_subplot(gs[3, :2])
    ax6 = fig.add_subplot(gs[3, 2:4])
    ax7 = fig.add_subplot(gs[3, 4:6])

    # Scatter plot of the observed data, drawn above the regression lines.
    ax1.scatter(
        data.height,
        data.weight,
        s=40,
        linewidths=1,
        facecolor="none",
        edgecolor="k",
        zorder=10,
    )
    ax1.set_xlabel("height", fontdict=f_dict)
    ax1.set_ylabel("weight", fontdict=f_dict)
    ax1.set(xlim=xlim, ylim=ylim)

    # Convert the standardized-scale parameters back to the original units.
    post = idata.posterior
    beta0 = (
        post["beta0"] * sd_w
        + mean_w
        - post["beta1"] * mean_h * sd_w / sd_h
    )
    beta1 = post["beta1"] * (sd_w / sd_h)
    sigma = post["sigma"] * sd_w
    log_nu = np.log10(post["nu"])

    draws = pd.DataFrame(
        {"beta0": beta0.values.ravel(), "beta1": beta1.values.ravel()}
    )

    # Keep only draws that lie inside both marginal 95% HDIs. The bounds are
    # used as-is: rounding them to integers would distort the interval and can
    # collapse it (e.g. a slope HDI of [0.2, 0.4] would become [0, 0]),
    # leaving no draws to plot.
    b0_lo, b0_hi = az.hdi(draws["beta0"].to_numpy(), hdi_prob=0.95)
    b1_lo, b1_hi = az.hdi(draws["beta1"].to_numpy(), hdi_prob=0.95)
    credible = draws[
        draws["beta0"].between(b0_lo, b0_hi)
        & draws["beta1"].between(b1_lo, b1_hi)
    ]

    x_grid = np.arange(0, data.height.max() * 1.05)
    if not credible.empty:
        # Sample distinct draws (no duplicate lines); pandas uses NumPy's
        # global random state here, so np.random.seed still controls it.
        chosen = credible.sample(n=min(n_lines, len(credible)))
        for b0, b1 in chosen.itertuples(index=False):
            ax1.plot(x_grid, b0 + b1 * x_grid, c=color, alpha=0.6, zorder=0)

    # Intercept
    az.plot_posterior(beta0, point_estimate="mode", ax=ax2, color=color)
    ax2.set_xlabel(r"$\beta_0$", fontdict=f_dict)
    ax2.set_title("Intercept", fontdict={"weight": "bold"})

    # Slope, with a reference value at zero
    az.plot_posterior(beta1, point_estimate="mode", ax=ax3, color=color, ref_val=0)
    ax3.set_xlabel(r"$\beta_1$", fontdict=f_dict)
    ax3.set_title("Slope", fontdict={"weight": "bold"})

    # Joint posterior of slope and intercept
    ax4.scatter(
        beta1.values.ravel(),
        beta0.values.ravel(),
        edgecolor=color,
        facecolor="none",
        alpha=0.6,
    )
    ax4.set_xlabel(r"$\beta_1$", fontdict=f_dict)
    ax4.set_ylabel(r"$\beta_0$", fontdict=f_dict)

    # Scale
    az.plot_posterior(sigma, point_estimate="mode", ax=ax5, color=color)
    ax5.set_xlabel(r"$\sigma$", fontdict=f_dict)
    ax5.set_title("Scale", fontdict={"weight": "bold"})

    # Normality (log10 of the degrees-of-freedom parameter nu)
    az.plot_posterior(log_nu, point_estimate="mode", ax=ax6, color=color)
    ax6.set_xlabel(r"log10($\nu$)", fontdict=f_dict)
    ax6.set_title("Normality", fontdict={"weight": "bold"})

    # Joint posterior of normality and scale
    ax7.scatter(
        log_nu.values.ravel(),
        sigma.values.ravel(),
        edgecolor=color,
        facecolor="none",
        alpha=0.6,
    )
    ax7.set_xlabel(r"log10($\nu$)", fontdict=f_dict)
    ax7.set_ylabel(r"$\sigma$", fontdict=f_dict)

    fig.tight_layout()
    return fig