Minibatch not working

@mcao @ricardoV94 @ferrine

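I ran into the same problem.
Debugging it with your example, I found what I believe are two bugs.
The first is that you are generating separate minibatches for X and y, so each one is drawn with its own random row indices and the rows of X no longer line up with the corresponding values of y. Fixing this also eliminates the following warning: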

UserWarning: RNG Variable RandomGeneratorSharedVariable(<Generator(PCG64) at 0x151ADB760>) has multiple clients. This is likely an inconsistent random graph.

The second problem is that `total_size` does not seem to work. `total_size` is supposed to rescale the minibatch log-likelihood by `N / batch_size` so that it approximates the full-data log-likelihood, but that scaling does not appear to take effect. As a workaround, I compute the log-likelihood manually and scale it myself with `pm.Potential` (the "Minibatch3" model below).
Here is the code.

That said, I'm not sure this is the proper fix either. If anyone more familiar with PyMC knows the right solution, I'd love to hear it.

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import matplotlib.pyplot as plt
import arviz as az

# generate data
# Synthetic linear-regression data: N observations of P features drawn from
# uniform(2, 10), true coefficients `beta`, and standard-normal noise on y.
N = 10000
P = 3
rng = np.random.default_rng(88)
# Size the feature matrix with P (not a literal 3) so it stays in sync with the
# models' shape=(P,) coefficient prior and with `beta` below.
X = rng.uniform(2, 10, size=(N, P))
beta = np.array([1.5, 0.2, -0.9])
# `beta` is hard-coded, so fail loudly if someone changes P without updating it.
if beta.shape != (P,):
    raise ValueError(f"beta must have P={P} entries, got shape {beta.shape}")
y = np.matmul(X, beta) + rng.normal(0, 1, size=(N,))

# --- Minibatch1: the original (buggy) setup, kept as a baseline --------------
# pm.Minibatch is called separately for X and y, so each call draws its own
# random row indices: within a batch, the rows of X_mb and the entries of y_mb
# do not belong to the same observations. Per the post, this is also what
# triggers the "RNG Variable ... has multiple clients" warning.
batch_size = 100
X_mb = pm.Minibatch(X, batch_size=batch_size)
y_mb = pm.Minibatch(y, batch_size=batch_size)

# Linear-regression model fit on the (mismatched) minibatches.
# total_size=N is meant to rescale the minibatch likelihood to the full N rows.
with pm.Model() as model_mb:
    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pt.matmul(X_mb, b)
    likelihood = pm.Normal(
        "likelihood", mu=mu, sigma=sigma, observed=y_mb, total_size=N
    )

    # ADVI for at most 100000 iterations; the convergence callback can stop
    # the run early. A fixed seed keeps the four fits comparable.
    fit_mb1 = pm.fit(
        n=100000,
        method="advi",
        progressbar=True,
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        random_seed=88,
    )
    # Draw 500 samples from the fitted approximation for plotting.
    idata_mb1 = fit_mb1.sample(500)


# --- Minibatch2: joint minibatch of X and y, scaled via total_size -----------
# Passing X and y to a single pm.Minibatch call slices both with the same
# random row indices, so each batch keeps X rows paired with their y values.
batch_size = 100
X_mb, y_mb = pm.Minibatch(X, y, batch_size=batch_size)

# Same model as Minibatch1, now fit on correctly paired minibatches.
# NOTE(review): the post reports that total_size=N did not appear to rescale
# the minibatch log-likelihood in the PyMC version used; confirm this against
# the installed PyMC version before relying on the workaround in Minibatch3.
with pm.Model() as model_mb:
    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pt.matmul(X_mb, b)
    likelihood = pm.Normal(
        "likelihood", mu=mu, sigma=sigma, observed=y_mb, total_size=N
    )

    # ADVI for at most 100000 iterations; the convergence callback can stop
    # the run early. Same seed as the other fits.
    fit_mb2 = pm.fit(
        n=100000,
        method="advi",
        progressbar=True,
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        random_seed=88,
    )
    # Draw 500 samples from the fitted approximation for plotting.
    idata_mb2 = fit_mb2.sample(500)


# --- Minibatch3: joint minibatch, likelihood scaled by hand -------------------
# Workaround for total_size: instead of an observed RV, the minibatch
# log-likelihood is added as a Potential multiplied by N / batch_size, so the
# summed minibatch term stands in for the full-data log-likelihood.
batch_size = 100
X_mb, y_mb = pm.Minibatch(X, y, batch_size=batch_size)

# Same regression model as above, with the likelihood expressed as a Potential.
with pm.Model() as model_mb:
    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pt.matmul(X_mb, b)
    # pm.logp gives the elementwise log-density of y_mb under Normal(mu, sigma);
    # the Potential adds that term, scaled by N / batch_size, to the model logp.
    # The factor assumes every batch holds exactly batch_size rows.
    pm.Potential("likelihood", (N/batch_size)*pm.logp(rv=pm.Normal.dist(mu=mu, sigma=sigma), value=y_mb))

    # ADVI for at most 100000 iterations; the convergence callback can stop
    # the run early. Same seed as the other fits.
    fit_mb3 = pm.fit(
        n=100000,
        method="advi",
        progressbar=True,
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        random_seed=88,
    )
    # Draw 500 samples from the fitted approximation for plotting.
    idata_mb3 = fit_mb3.sample(500)


# --- Full Data: reference fit on all N observations ---------------------------
# The minibatch fits above should reproduce this posterior if they are correct.
with pm.Model() as model:
    b = pm.Normal("b", mu=0, sigma=3, shape=(P,))
    sigma = pm.HalfCauchy("sigma", 1)
    mu = pt.matmul(X, b)
    # observed holds all N rows, so total_size=N implies a scale factor of
    # N / N = 1 and is redundant here.
    likelihood = pm.Normal("likelihood", mu=mu, sigma=sigma, observed=y, total_size=N)

    # ADVI for at most 100000 iterations; the convergence callback can stop
    # the run early. Same seed as the minibatch fits.
    fit = pm.fit(
        n=100000,
        method="advi",
        progressbar=True,
        callbacks=[pm.callbacks.CheckParametersConvergence()],
        random_seed=88,
    )
    # Draw 500 samples from the fitted approximation for plotting.
    idata = fit.sample(500)


# compare models
# One row per fit (Minibatch1, Minibatch2, Minibatch3, Full Data) and one column
# per coefficient of b; the true beta values are drawn as reference values.
fig, ax = plt.subplots(nrows=4, ncols=3, figsize=(10, 8), layout="constrained")
for row, idata_row in enumerate((idata_mb1, idata_mb2, idata_mb3, idata)):
    az.plot_posterior(
        idata_row,
        var_names="b",
        ref_val=beta.tolist(),
        ax=ax[row, :],
        textsize=8,
    )

# Put every row on the same x-range per column so the four posteriors can be
# compared by eye. Aligning only some rows would leave the others on their own
# scales; taking the union of all rows' limits keeps every posterior visible.
for col in range(ax.shape[1]):
    lows, highs = zip(*(axis.get_xlim() for axis in ax[:, col]))
    for axis in ax[:, col]:
        axis.set_xlim(min(lows), max(highs))

# Label each row of the comparison grid with the fit it shows, written
# vertically in bold to the left of the row's first panel.
row_labels = ("Minibatch1", "Minibatch2", "Minibatch3", "Full Data")
for row_idx, row_label in enumerate(row_labels):
    ax[row_idx, 0].annotate(
        text=row_label,
        xy=(-0.5, 0.5),
        xycoords="axes fraction",
        rotation=90,
        size=15,
        fontweight="bold",
        va="center",
    )

[Image: the 4×3 posterior comparison figure produced by the code above]