Minibatch not working (original) (raw)
I was facing the same problem.
I debugged it using your example and found what look like two bugs.
The first is that separate mini-batches are generated for X and y. Fixing this eliminates the following warning:
UserWarning: RNG Variable RandomGeneratorSharedVariable(<Generator(PCG64) at 0x151ADB760>) has multiple clients. This is likely an inconsistent random graph.
The second is that total_size does not seem to work. It is supposed to rescale the mini-batch logp up to the full data set, but that rescaling does not appear to happen. As a workaround, I computed the logp manually and scaled it myself.
I will share the code.
That said, I am not sure this is the right fix either — if any experts can weigh in, I would love to know the proper solution.
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import matplotlib.pyplot as plt
import arviz as az
# generate data
# Synthetic linear-regression data: y = X @ beta + eps, eps ~ N(0, 1).
N = 10000  # number of observations
P = 3      # number of predictors
rng = np.random.default_rng(88)
# Fix: use P instead of the hard-coded 3 so the design matrix stays in
# sync with the prior shape below if P is ever changed.
X = rng.uniform(2, 10, size=(N, P))
beta = np.array([1.5, 0.2, -0.9])  # true regression weights
y = np.matmul(X, beta) + rng.normal(0, 1, size=(N,))
# -- Attempt 1: separate Minibatch objects for X and y -------------------
# NOTE: creating two independent Minibatch objects is exactly what triggers
# the "multiple clients" RNG warning — kept as-is to demonstrate the issue.
batch_size = 100
x_batch = pm.Minibatch(X, batch_size=batch_size)
y_batch = pm.Minibatch(y, batch_size=batch_size)

with pm.Model() as model_mb:
    # Priors on the regression weights and the noise scale.
    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)

    # Linear predictor evaluated on the current mini-batch.
    mu = pt.matmul(x_batch, b)

    # total_size is meant to rescale the mini-batch logp to the full data.
    likelihood = pm.Normal(
        "likelihood", mu=mu, sigma=sigma, observed=y_batch, total_size=N
    )

    fit_mb1 = pm.fit(
        n=100000,
        method="advi",
        progressbar=True,
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        random_seed=88,
    )

idata_mb1 = fit_mb1.sample(500)
# -- Attempt 2: one joint Minibatch for X and y --------------------------
# Drawing both arrays from a single Minibatch call keeps the same rows of
# X and y together and silences the RNG warning from attempt 1.
batch_size = 100
x_joint, y_joint = pm.Minibatch(X, y, batch_size=batch_size)

with pm.Model() as model_mb:
    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pt.matmul(x_joint, b)

    # Still relying on total_size to upscale the mini-batch logp.
    likelihood = pm.Normal(
        "likelihood", mu=mu, sigma=sigma, observed=y_joint, total_size=N
    )

    fit_mb2 = pm.fit(
        n=100000,
        method="advi",
        progressbar=True,
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        random_seed=88,
    )

idata_mb2 = fit_mb2.sample(500)
# -- Attempt 3: manual logp scaling via pm.Potential ---------------------
# Since total_size did not appear to rescale the likelihood, scale the
# mini-batch log-probability up to the full data set by hand.
batch_size = 100
x_joint, y_joint = pm.Minibatch(X, y, batch_size=batch_size)

with pm.Model() as model_mb:
    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pt.matmul(x_joint, b)

    # (N / batch_size) is the factor total_size was expected to apply.
    pm.Potential(
        "likelihood",
        (N / batch_size) * pm.logp(rv=pm.Normal.dist(mu=mu, sigma=sigma), value=y_joint),
    )

    fit_mb3 = pm.fit(
        n=100000,
        method="advi",
        progressbar=True,
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        random_seed=88,
    )

idata_mb3 = fit_mb3.sample(500)
# -- Reference fit: full data, no mini-batching --------------------------
# Same model, but the likelihood sees every observation each step; this is
# the ground truth the three mini-batch fits are compared against.
with pm.Model() as model:
    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pt.matmul(X, b)
    likelihood = pm.Normal("likelihood", mu=mu, sigma=sigma, observed=y, total_size=N)

    fit = pm.fit(
        n=100000,
        method="advi",
        progressbar=True,
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        random_seed=88,
    )

idata = fit.sample(500)
# -- Compare the four posteriors row by row ------------------------------
fig, ax = plt.subplots(nrows=4, ncols=3, figsize=(10, 8), layout="constrained")

# One row per fit: (trace, row label).
rows = [
    (idata_mb1, "Minibatch1"),
    (idata_mb2, "Minibatch2"),
    (idata_mb3, "Minibatch3"),
    (idata, "Full Data"),
]

for row, (trace, label) in enumerate(rows):
    az.plot_posterior(
        trace,
        var_names="b",
        ref_val=beta.tolist(),
        ax=ax[row, :],
        textsize=8,
    )
    # Row label down the left edge of the first column.
    ax[row, 0].annotate(
        text=label,
        xy=(-0.5, 0.5),
        xycoords="axes fraction",
        rotation=90,
        size=15,
        fontweight="bold",
        va="center",
    )

# Put rows 0 and 1 on a common x-scale for a fair visual comparison.
for col in range(3):
    ax[1, col].set_xlim(ax[0, col].get_xlim())