Error with the `sample_posterior_predictive` function

Hi all,

I am currently learning Bayesian Additive Regression Trees (BART) and would like to replicate the model from the PyMC example notebook "Bayesian Additive Regression Trees: Introduction"
(link: Bayesian Additive Regression Trees: Introduction — PyMC example gallery).

When I try to make out-of-sample predictions, I get an error saying that the data and its dims have a different number of dimensions. I am using PyMC v5.10.2. Could a recent update be causing this?

In the following code, the call to `pm.sample_posterior_predictive` fails:

# Hold out 20% of the rows (test_size=0.2) as a test set, with a fixed random_state.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=RANDOM_SEED)

# NOTE(review): inside the model below, `X` is rebound to the pm.MutableData container
# (registered under the name "X") and `Y` is rebound to Y_train. Re-running the split
# line above after this cell would therefore pass the container and the training targets
# instead of the full data. Consider distinct names (e.g. X_data) if no later cell relies
# on `X` being the container.
# NOTE(review): np.log(Y) is -inf for zero values; confirm every target is > 0.
# shape=μ.shape ties the size of the likelihood "y" to μ. Presumably this lets "y" follow
# the number of rows in "X" when the container is swapped for test data.
# NOTE(review): the ValueError below is raised while the posterior-predictive draws are
# converted to InferenceData. PyMC builds the dims for "y" as
# sample_dims + <registered dims of "y">. "y" apparently has no registered dims, so 2 dim
# names are paired with a 3-dimensional array. This is possibly a PyMC/ArviZ version
# mismatch (unverified). Passing return_inferencedata=False presumably skips this
# conversion and can confirm that the sampling itself succeeds.
with pm.Model() as model_oos_regression:
    X = pm.MutableData("X", X_train)
    Y = Y_train
    α = pm.Exponential("α", 1)
    μ = pmb.BART("μ", X, np.log(Y))
    y = pm.NegativeBinomial("y", mu=pm.math.exp(μ), alpha=α, observed=Y, shape=μ.shape)
    idata_oos_regression = pm.sample(random_seed=RANDOM_SEED)
    posterior_predictive_oos_regression_train = pm.sample_posterior_predictive(
        trace=idata_oos_regression, random_seed=RANDOM_SEED
    )

Error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[36], line 8
      6 y = pm.NegativeBinomial("y", mu=pm.math.exp(μ), alpha=α, observed=Y, shape=μ.shape)
      7 idata_oos_regression = pm.sample(random_seed=RANDOM_SEED)
----> 8 posterior_predictive_oos_regression_train = pm.sample_posterior_predictive(
      9     trace=idata_oos_regression, random_seed=RANDOM_SEED
     10 )

File /opt/conda/envs/python3/lib/python3.9/site-packages/pymc/sampling/forward.py:673, in sample_posterior_predictive(trace, model, var_names, sample_dims, random_seed, progressbar, return_inferencedata, extend_inferencedata, predictions, idata_kwargs, compile_kwargs)
    671         ikwargs.setdefault("inplace", True)
    672     return pm.predictions_to_inference_data(ppc_trace, **ikwargs)
--> 673 idata_pp = pm.to_inference_data(posterior_predictive=ppc_trace, **ikwargs)
    675 if extend_inferencedata and idata is not None:
    676     idata.extend(idata_pp)

File /opt/conda/envs/python3/lib/python3.9/site-packages/pymc/backends/arviz.py:512, in to_inference_data(trace, prior, posterior_predictive, log_likelihood, coords, dims, sample_dims, model, save_warmup, include_transformed)
    509 if isinstance(trace, InferenceData):
    510     return trace
--> 512 return InferenceDataConverter(
    513     trace=trace,
    514     prior=prior,
    515     posterior_predictive=posterior_predictive,
    516     log_likelihood=log_likelihood,
    517     coords=coords,
    518     dims=dims,
    519     sample_dims=sample_dims,
    520     model=model,
    521     save_warmup=save_warmup,
    522     include_transformed=include_transformed,
    523 ).to_inference_data()

File /opt/conda/envs/python3/lib/python3.9/site-packages/pymc/backends/arviz.py:433, in InferenceDataConverter.to_inference_data(self)
    423 def to_inference_data(self):
    424     """Convert all available data to an InferenceData object.
    425 
    426     Note that if groups can not be created (e.g., there is no `trace`, so
    427     the `posterior` and `sample_stats` can not be extracted), then the InferenceData
    428     will not have those groups.
    429     """
    430     id_dict = {
    431         "posterior": self.posterior_to_xarray(),
    432         "sample_stats": self.sample_stats_to_xarray(),
--> 433         "posterior_predictive": self.posterior_predictive_to_xarray(),
    434         "predictions": self.predictions_to_xarray(),
    435         **self.priors_to_xarray(),
    436         "observed_data": self.observed_data_to_xarray(),
    437     }
    438     if self.predictions:
    439         id_dict["predictions_constant_data"] = self.constant_data_to_xarray()

File /opt/conda/envs/python3/lib/python3.9/site-packages/arviz/data/base.py:65, in requires.__call__.<locals>.wrapped(cls)
     63     if all((getattr(cls, prop_i) is None for prop_i in prop)):
     64         return None
---> 65 return func(cls)

File /opt/conda/envs/python3/lib/python3.9/site-packages/pymc/backends/arviz.py:344, in InferenceDataConverter.posterior_predictive_to_xarray(self)
    342 data = self.posterior_predictive
    343 dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data}
--> 344 return dict_to_dataset(
    345     data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims
    346 )

File /opt/conda/envs/python3/lib/python3.9/site-packages/arviz/data/base.py:307, in dict_to_dataset(data, attrs, library, coords, dims, default_dims, index_origin, skip_event_dims)
    305 data_vars = {}
    306 for key, values in data.items():
--> 307     data_vars[key] = numpy_to_data_array(
    308         values,
    309         var_name=key,
    310         coords=coords,
    311         dims=dims.get(key),
    312         default_dims=default_dims,
    313         index_origin=index_origin,
    314         skip_event_dims=skip_event_dims,
    315     )
    316 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))

File /opt/conda/envs/python3/lib/python3.9/site-packages/arviz/data/base.py:254, in numpy_to_data_array(ary, var_name, coords, dims, default_dims, index_origin, skip_event_dims)
    252 # filter coords based on the dims
    253 coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
--> 254 return xr.DataArray(ary, coords=coords, dims=dims)

File /opt/conda/envs/python3/lib/python3.9/site-packages/xarray/core/dataarray.py:450, in DataArray.__init__(self, data, coords, dims, name, attrs, indexes, fastpath)
    448 data = _check_data_shape(data, coords, dims)
    449 data = as_compatible_data(data)
--> 450 coords, dims = _infer_coords_and_dims(data.shape, coords, dims)
    451 variable = Variable(dims, data, attrs, fastpath=True)
    453 if not isinstance(coords, Coordinates):

File /opt/conda/envs/python3/lib/python3.9/site-packages/xarray/core/dataarray.py:173, in _infer_coords_and_dims(shape, coords, dims)
    171     dims = tuple(dims)
    172 elif len(dims) != len(shape):
--> 173     raise ValueError(
    174         "different number of dimensions on data "
    175         f"and dims: {len(shape)} vs {len(dims)}"
    176     )
    177 else:
    178     for d in dims:

ValueError: different number of dimensions on data and dims: 3 vs 2

Thank you in advance for your help.