Truncated multivariate normal likelihood

Hello,

I am trying to fit the mean and covariance of a 2D normal distribution to some data with the complication that my data is truncated. I only have observations within a window, although the underlying distribution really is normal.

Some astronomers had this same problem (among others) while fitting GMMs and developed an expectation-maximization algorithm that I’ve successfully used to solve this problem (Filling the gaps: Gaussian mixture models from noisy, truncated or incomplete samples), but I’d like to try Bayesian inference as well.

As far as I understand, PyMC only has a univariate truncated normal (pymc.TruncatedNormal), so I’m trying to define my own multivariate truncated normal with constant bounds of truncation. I’m really struggling to implement the normalization constant part. Here’s an example, where the third and fourth-to-last lines don’t actually work since they use scipy to show what I’m imagining.

```python
import numpy as np
import pymc as pm
from scipy import stats

rng = np.random.default_rng()

lower_bounds = np.zeros(2)
upper_bounds = np.array([128, 32])

# I have some guesses about the mean and covariance
prior_mean = np.mean([lower_bounds, upper_bounds], axis=0)
prior_covar = np.diag([2000., 400.])

# Some mock data
data = rng.multivariate_normal(
    prior_mean + np.array([10, 15]),
    prior_covar + np.array([[10, 11], [11, 12]]),
    size=10000
)
# Simulating truncation: keep only the points inside the observation window
mask = np.all((data >= lower_bounds) & (data <= upper_bounds), axis=1)
data = data[mask]

tune = 500
draws = 500

# Modeling
with pm.Model() as model:
    # Priors
    mu = pm.Normal(
        'mu',
        mu=prior_mean,
        sigma=50,
        shape=2
    )
    chol, corr, std_devs = pm.LKJCholeskyCov(
        'chol_cov',
        eta=5.,  # I expect low correlation
        n=2,
        # note: sd_dist is a prior on the standard deviations, not the variances
        sd_dist=pm.InverseGamma.dist(mu=np.diag(prior_covar), sigma=np.array([200, 50]), shape=2)
    )
    cov = pm.Deterministic('cov', chol.dot(chol.T))

    # This is what isn't working. Trying to define the log prob of a truncated normal at `value`
    def truncated_mvnormal_logp(value, mu, cov):
        # Un-truncated normal
        untruncated_logp = pm.logp(
            pm.MvNormal.dist(mu, cov),
            value
        )

        # Normalization constant: probability mass inside the window.
        # I know .eval() can't be used here, since mu and cov are symbolic inside logp
        mvn = stats.multivariate_normal(mean=mu.eval(), cov=cov.eval())
        norm_const = mvn.cdf(upper_bounds, lower_limit=lower_bounds)

        return untruncated_logp - np.log(norm_const)

    pm.CustomDist('x', mu, cov, logp=truncated_mvnormal_logp, observed=data)
```

I don’t want to just kill samples outside the window with pymc.Potential since those out-of-the-window samples really are probable, it’s just that my data doesn’t fully encompass the true distribution since I can only see samples in a window. But this is doing the same sort of thing, saying that the samples aren’t probable… Anyway, I’m all mixed up.

Any guidance on how to do this kind of thing?
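For checking values outside the model, here is a minimal NumPy/SciPy sketch of the box-truncated log-likelihood. The normalization constant is the mass of the window $[l_1,u_1]\times[l_2,u_2]$, which by inclusion-exclusion on the joint CDF $F$ is $F(u_1,u_2) - F(l_1,u_2) - F(u_1,l_2) + F(l_1,l_2)$. Since it uses SciPy, it can't go inside a PyMC logp and has no gradients; `box_probability` and `truncated_mvn_loglik` are hypothetical helper names.

```python
import numpy as np
from scipy import stats

def box_probability(mean, cov, lower, upper):
    """P(lower <= X <= upper) for X ~ N(mean, cov) in 2D, by inclusion-exclusion."""
    mvn = stats.multivariate_normal(mean=mean, cov=cov)
    (l1, l2), (u1, u2) = lower, upper
    return (
        mvn.cdf([u1, u2])
        - mvn.cdf([l1, u2])
        - mvn.cdf([u1, l2])
        + mvn.cdf([l1, l2])
    )

def truncated_mvn_loglik(data, mean, cov, lower, upper):
    """Sum over rows of log N(x | mean, cov) - log P(window)."""
    logpdf = stats.multivariate_normal(mean=mean, cov=cov).logpdf(data)
    # Every observation is renormalized by the same window probability
    return np.sum(logpdf) - len(data) * np.log(box_probability(mean, cov, lower, upper))
```

With a diagonal covariance the window probability factorizes into a product of univariate probabilities, which makes a handy sanity check.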

You can’t use scipy code like that in the logp. Someone tried to implement it as PyTensor code a while ago, but it got abandoned: Added MVN_cdf.py by JessSpearing (pymc-devs/pymc-extras pull request #60 on GitHub)

Importantly you’ll also want gradients, so that you can use NUTS.

Okay, thanks for clarifying!

There’s a subtle distinction between truncation and censoring. With a truncated normal distribution, the underlying process itself is truncated. For example, I might be using a normal model of age, which is constrained to be positive and hence truncated below at zero. It’s physically impossible to have ages less than zero, so it’s truncation.
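In the truncated case the normal density is renormalized over the allowed region: for $l \le x \le u$, $\log p(x) = \log \mathcal{N}(x \mid \mu, \sigma) - \log\left(\Phi\left(\frac{u-\mu}{\sigma}\right) - \Phi\left(\frac{l-\mu}{\sigma}\right)\right)$. A minimal univariate SciPy sketch; `truncated_normal_logpdf` is a hypothetical name, and the age model N(10, 8) is only an illustration.

```python
import numpy as np
from scipy import stats

def truncated_normal_logpdf(x, mu, sigma, lower=-np.inf, upper=np.inf):
    """Log-density of N(mu, sigma) truncated to [lower, upper]."""
    x = np.asarray(x, dtype=float)
    # Probability mass of the untruncated normal inside [lower, upper]
    log_mass = np.log(stats.norm.cdf(upper, mu, sigma) - stats.norm.cdf(lower, mu, sigma))
    inside = (x >= lower) & (x <= upper)
    # Renormalized density inside the bounds; zero density (log = -inf) outside
    return np.where(inside, stats.norm.logpdf(x, mu, sigma) - log_mass, -np.inf)

# Hypothetical ages modeled as N(10, 8) truncated below at zero
ages = np.array([1.0, 5.0, 30.0])
print(truncated_normal_logpdf(ages, mu=10.0, sigma=8.0, lower=0.0))
```

This should agree with `scipy.stats.truncnorm`, which takes the bounds in standardized units.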

In a censored normal distribution, the underlying process is normal, but when you observe values beyond some point, they get reported as being beyond the threshold, but the exact values aren’t given. A typical example of censored data is actuarial. You have data that some people died, but other people you only know they have lived to their current age and will die in the future. The people who are still alive provide censored values—the only thing you know about their death age is that it’s greater than their current age.
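In the censored case there is no renormalization: exact observations contribute the density, and each censored observation contributes the probability of lying beyond its censoring point. A minimal univariate sketch for right censoring, assuming (for illustration only) a normal model of age at death; `right_censored_normal_loglik` is a hypothetical name.

```python
import numpy as np
from scipy import stats

def right_censored_normal_loglik(observed, censored_at, mu, sigma):
    """Log-likelihood with exact values plus right-censored values."""
    # Exact observations (e.g. recorded ages at death) contribute the density
    ll_observed = np.sum(stats.norm.logpdf(observed, mu, sigma))
    # Censored observations (e.g. current ages of people still alive) contribute
    # the survival probability P(X > c), since only the bound c is known
    ll_censored = np.sum(stats.norm.logsf(censored_at, mu, sigma))
    return ll_observed + ll_censored
```

For example, a single person still alive at exactly the mean age contributes $\log 0.5$.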

This makes it sound like you have censored data from an untruncated distribution, rather than data from a truncated distribution.

For the observations that are censored, you either need to sample with constraints or evaluate the cdfs. I’m not sure how easy it is to sample with constraints in PyMC, but it’s always challenging to evaluate bivariate normal cdfs, as you need to use some form of numerical integration. For the constrained form, you introduce a variable that is constrained to be beyond the censoring point, then give it the underlying distribution. This is super confusing terminologically, because you use a truncated normal as the pdf. Then MCMC does the integration that the cdf would otherwise have to do.
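To make the numerical integration concrete: since $Y \mid X = x \sim \mathcal{N}(\rho x, 1-\rho^2)$ for standardized coordinates, the bivariate normal cdf reduces to a one-dimensional integral, $\Phi_2(h, k; \rho) = \int_{-\infty}^{h} \phi(x)\, \Phi\left(\frac{k - \rho x}{\sqrt{1-\rho^2}}\right) dx$. A SciPy sketch using generic quadrature (assumes $|\rho| < 1$; `bivariate_normal_cdf` is a hypothetical name), checked against `scipy.stats.multivariate_normal.cdf`. It is an illustration only; dedicated implementations use specialized algorithms rather than generic quadrature.

```python
import numpy as np
from scipy import integrate, stats

def bivariate_normal_cdf(upper, mean, cov):
    """P(X1 <= upper[0], X2 <= upper[1]) for a 2D normal, by 1-D quadrature."""
    mean = np.asarray(mean, dtype=float)
    cov = np.asarray(cov, dtype=float)
    sd = np.sqrt(np.diag(cov))
    rho = cov[0, 1] / (sd[0] * sd[1])  # assumes |rho| < 1
    # Standardize the upper limits
    h, k = (np.asarray(upper, dtype=float) - mean) / sd
    cond_sd = np.sqrt(1.0 - rho**2)

    def integrand(x):
        # Density of the first coordinate times the conditional cdf of the second
        return stats.norm.pdf(x) * stats.norm.cdf((k - rho * x) / cond_sd)

    value, _abserr = integrate.quad(integrand, -np.inf, h)
    return value
```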

I’m not sure how easy it is to sample with constraints

I think this is what the first model in this example is doing with a univariate normal? Censored Data Models — PyMC example gallery

I guess you can do the same for multivariate normal, but then you have to sample one multivariate per observed point. No idea how that fares vs numerical integration.
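A rough way to see the sampling vs. numerical integration trade-off: estimate the window probability by plain Monte Carlo (not MCMC) and compare it with SciPy's quadrature-based cdf. The Monte Carlo error only shrinks like $1/\sqrt{n}$ in the number of draws. The window and covariance below are taken from the question's mock data; `mc_box_probability` is hypothetical, and `lower_limit` needs SciPy 1.10 or newer.

```python
import numpy as np
from scipy import stats

def mc_box_probability(mean, cov, lower, upper, n_draws, seed=0):
    """Plain Monte Carlo estimate of P(lower <= X <= upper) for X ~ N(mean, cov)."""
    rng = np.random.default_rng(seed)
    draws = rng.multivariate_normal(mean, cov, size=n_draws)
    inside = np.all((draws >= lower) & (draws <= upper), axis=1)
    estimate = inside.mean()
    # Binomial standard error of the estimated proportion
    std_error = np.sqrt(estimate * (1.0 - estimate) / n_draws)
    return estimate, std_error

mean = np.array([64.0, 16.0])
cov = np.array([[2000.0, 11.0], [11.0, 400.0]])
lower, upper = np.zeros(2), np.array([128.0, 32.0])
exact = stats.multivariate_normal(mean=mean, cov=cov).cdf(upper, lower_limit=lower)
estimate, std_error = mc_box_probability(mean, cov, lower, upper, n_draws=100_000)
print(exact, estimate, std_error)
```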

I quite like this plot @drbenvincent made for this notebook on censored/truncated GLMs, to visualize the difference:

[Figure: @drbenvincent's plot visualizing the difference between censored and truncated data]

The nice thing about HMC is that it scales really well with dimension. I’ve found that introducing a bunch more parameters isn’t such a big deal as long as the probability mass doesn’t bunch up near constraint boundaries, which can be hard on the unconstrained geometry.

Right, but here it’s one extra variable per censored data point.

Right, that’s always the case with censoring. I haven’t tried this particular model at scale with NUTS, but with something like a negative binomial model, I’ve often found it more efficient to reparameterize with a Poisson with a latent gamma random effect, even when that adds 100K or more extra parameters. $\mathcal{O}(D^{1/4})$ scaling is pretty great in practice as well as in theory.
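The reparameterization rests on the negative binomial being a Poisson whose rate is gamma-distributed; in the latent version each observation's rate becomes a parameter that MCMC integrates over. A quick SciPy check that integrating the rate out by quadrature reproduces `scipy.stats.nbinom`, assuming SciPy's $(n, p)$ parameterization, where the matching gamma has shape $n$ and scale $(1-p)/p$; `poisson_gamma_pmf` is a hypothetical name.

```python
import numpy as np
from scipy import integrate, stats

def poisson_gamma_pmf(k, r, p):
    """P(K = k) when lam ~ Gamma(shape=r, scale=(1 - p) / p) and K | lam ~ Poisson(lam)."""
    scale = (1.0 - p) / p

    def integrand(lam):
        # Poisson likelihood weighted by the gamma density of the latent rate
        return stats.poisson.pmf(k, lam) * stats.gamma.pdf(lam, a=r, scale=scale)

    value, _abserr = integrate.quad(integrand, 0.0, np.inf)
    return value
```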

The geometry’s usually much more of a problem. For example, you might have only a few hundred spatial random effects in a Poisson model, but if they’re spatially smoothed you have a hierarchical model, and in my experience that geometry can be much more challenging to sample than adding more parameters.

One of the serious drawbacks to letting MCMC do the integration for you is the overall size of the output for the sample of MCMC draws. It’s slow just for the I/O to disk and memory, and if you save all the random effects and try to do something like pass them to ArviZ, it’ll take a long time to calculate all the summary statistics. And if you try to be super picky about R-hat thresholds, you might never be happy with the results of 200K multiple comparisons.
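Back-of-the-envelope arithmetic for the output size, assuming uncompressed float64 draws (8 bytes each) and, purely as an illustration, 200K latent parameters, 1,000 draws per chain and 4 chains: 200,000 × 1,000 × 4 × 8 bytes = 6.4 GB. `trace_size_gb` is a hypothetical helper.

```python
def trace_size_gb(n_params, n_draws, n_chains, bytes_per_value=8):
    """Uncompressed size in GB (1e9 bytes) of storing every draw of every parameter."""
    return n_params * n_draws * n_chains * bytes_per_value / 1e9

print(trace_size_gb(200_000, 1_000, 4))  # 6.4
```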

One advantage of the random effect approach is that it’s a bit more flexible, especially if you have multilevel structure that’s not strictly hierarchical (in the random effects case) or densities more complicated than the normal (in the truncated case). So it’s usually just simpler to use the random effect approach. Another advantage is that it doesn’t rely on numerical (log) cdfs. In Stan, these are some of our least stable functions when differentiating with respect to parameters, because we just have the traditional cdf implementations and take their log. We should really have custom iterative log cdf algorithms for both values and gradients, but that’s way beyond our person-power and expertise, so we just autodiff through taking the log of the cdf. We couldn’t find better implementations, so I’m wondering if you did for PyMC, especially for log cdfs? I should really be asking this in a separate topic.
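A small univariate illustration of why taking the log of a cdf is fragile: far in the lower tail, $\Phi(x)$ underflows to zero in double precision, so the naive log is $-\infty$ and the naive gradient $\phi(x)/\Phi(x)$ becomes $0/0$. SciPy's `scipy.special.log_ndtr` computes $\log\Phi$ directly, and the gradient can then be formed on the log scale. This only covers the univariate normal and says nothing about how Stan or PyMC implement their log cdfs; the function names are hypothetical.

```python
import numpy as np
from scipy import special, stats

def naive_log_cdf(x):
    # log of a cdf that may already have underflowed to 0
    with np.errstate(divide="ignore"):
        return np.log(stats.norm.cdf(x))

def stable_log_cdf(x):
    # log Phi(x) computed directly on the log scale
    return special.log_ndtr(x)

def naive_dlog_cdf(x):
    # d/dx log Phi(x) = phi(x) / Phi(x), both factors on the natural scale
    with np.errstate(invalid="ignore"):
        return stats.norm.pdf(x) / stats.norm.cdf(x)

def stable_dlog_cdf(x):
    # same ratio, formed as exp(log phi(x) - log Phi(x))
    return np.exp(stats.norm.logpdf(x) - special.log_ndtr(x))
```

At $x = -40$ the naive versions give `-inf` and `nan`, while the log-scale gradient is close to $-x = 40$, as the asymptotic behaviour of the inverse Mills ratio suggests.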

Re: the memory aspect, you can control which variables PyMC actually stores in the trace with var_names, so that’s not a problem if they’re just nuisance parameters. You do lose R-hat and the other convergence statistics for them, of course.

Re: iterative cdfs, we only have a few, for the hypergeometric and for the incomplete beta and gamma functions, which are computed directly on the log scale (IIRC). I think we copied them mostly from stan-math, except for one where we found an alternative algorithm that was slightly more stable. So we’re behind you guys on that front.
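As a generic illustration of what "directly on the log scale" can mean (not the PyMC or stan-math code), the regularized lower incomplete gamma function has the power series $P(a, x) = x^a e^{-x} \sum_{n \ge 0} \frac{x^n}{\Gamma(a+n+1)}$, and summing its terms with `logsumexp` keeps $\log P$ finite where $P$ itself underflows. This sketch assumes $x > 0$ and a fixed number of terms; the series converges slowly when $x$ is much larger than $a$, where implementations typically switch to a continued fraction. `log_gammainc_series` is a hypothetical name.

```python
import numpy as np
from scipy import special

def log_gammainc_series(a, x, n_terms=200):
    """log P(a, x) via the power series, summed on the log scale (assumes x > 0)."""
    n = np.arange(n_terms)
    # log of x^n / Gamma(a + n + 1) for each series term
    log_terms = n * np.log(x) - special.gammaln(a + n + 1.0)
    return a * np.log(x) - x + special.logsumexp(log_terms)
```

For $a = 2.5$ and $x = 10^{-300}$, $P(a, x)$ is about $10^{-750}$, far below the smallest double, so `np.log(special.gammainc(2.5, 1e-300))` is `-inf`, while the log-scale series stays finite.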

As mentioned above we don’t even have the mvnormal cdf.

I think the R implementations of the cdfs on the log scale are more stable, but all their code’s GPL-ed and I’ve been afraid to use it. I don’t think you could be behind Stan. Stan doesn’t have an mvnormal or even binormal cdf, nor does it have any implementations of which I’m aware on the log scale. And absolutely none that derive a good algorithm for the derivatives beyond autodiffing the algorithm for the values. I think a lot of the expertise in writing special functions died out decades ago in the applied math community.