Marginalized Mixture won't begin sampling / throws assertion error

Hi all,

I am running into an error I can't identify while trying to use a Gaussian mixture model. I switched to a marginalized mixture to make it easier to compile, but it won't get off the ground at all. I'm fairly certain the model is well specified: if I feed the data to scikit-learn's GaussianMixture, it immediately spits out the two peaks visible on the hist/KDE plot of the ~160,000 observations. I can't publicly share this specific data, but I can describe its shapes, and the graphviz output indicates everything looks kosher on that front.
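
Since I can't post the data itself, here is a hypothetical stand-in with the same shapes (the values are synthetic; only the names obs_ev_means and mean_in_batters match my real variables), along with roughly the scikit-learn sanity check I mentioned:

import numpy as np
import pymc as pm
from sklearn.mixture import GaussianMixture

# Synthetic stand-in: ~160,000 observations with two peaks, like my hist/KDE
rng = np.random.default_rng(0)
obs_ev_means = np.concatenate([
    rng.normal(70, 5, size=80_000),
    rng.normal(90, 5, size=80_000),
])
mean_in_batters = np.arange(obs_ev_means.size)  # placeholder batter labels

# Sanity check: fit a two-component GMM and print the recovered means
gm = GaussianMixture(n_components=2, random_state=0).fit(obs_ev_means.reshape(-1, 1))
print(gm.means_)  # should land near 70 and 90

And here is the model snippet: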

with pm.Model(coords={"cluster": np.arange(2), "batters": mean_in_batters}) as Basic_Mixture_Model:
    # Ordered transform breaks label switching between the two components;
    # initval starts the means on either side of the prior mean of 80
    mu = pm.Normal(
        "mu",
        mu=80,
        sigma=10,
        initval=[70, 90],
        transform=pm.distributions.transforms.ordered,
        dims="cluster",
    )
    weights = pm.Dirichlet("w", np.ones(2), dims="cluster")
    tau = pm.Gamma("tau", 1.0, 1.0, dims="cluster")
    y = pm.NormalMixture("x", w=weights, mu=mu, tau=tau, observed=obs_ev_means, dims="batters")

    GM_trace = pm.sample_prior_predictive(1000)
    GM_trace.extend(pm.sample(chains=2, cores=4, random_seed=0))
    pm.sample_posterior_predictive(GM_trace, extend_inferencedata=True)

If I just call sample_prior_predictive, it spits out essentially no separation, giving me two identical means. When I call extend, or try creating the trace directly with pm.sample, I get the AssertionError below.
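
For reference, the direct call that produced the traceback was roughly:

with Basic_Mixture_Model:
    GM_trace = pm.sample(draws=1000, tune=1000, chains=2)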

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[72], line 14
     11 tau = pm.Gamma("tau", 1.0, 1.0, dims="cluster")
     12 y = pm.NormalMixture("x", w=weights, mu=mu, tau = tau, observed=obs_ev_means, dims='batters')
---> 14 GM_trace = pm.sample(draws=1000,tune=1000, chains=2)

File ~\anaconda3\Lib\site-packages\pymc\sampling\mcmc.py:714, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    711         auto_nuts_init = False
    713 initial_points = None
--> 714 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    716 if nuts_sampler != "pymc":
    717     if not isinstance(step, NUTS):

File ~\anaconda3\Lib\site-packages\pymc\sampling\mcmc.py:223, in assign_step_methods(model, step, methods, step_kwargs)
    221 if has_gradient:
    222     try:
--> 223         tg.grad(model_logp, var)  # type: ignore
    224     except (NotImplementedError, tg.NullTypeGradError):
    225         has_gradient = False

File ~\anaconda3\Lib\site-packages\pytensor\gradient.py:607, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
    604     if hasattr(g.type, "dtype"):
    605         assert g.type.dtype in pytensor.tensor.type.float_dtypes
--> 607 _rval: Sequence[Variable] = _populate_grad_dict(
    608     var_to_app_to_idx, grad_dict, _wrt, cost_name
    609 )
    611 rval: MutableSequence[Variable | None] = list(_rval)
    613 for i in range(len(_rval)):

File ~\anaconda3\Lib\site-packages\pytensor\gradient.py:1402, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
   1399     # end if cache miss
   1400     return grad_dict[var]
-> 1402 rval = [access_grad_cache(elem) for elem in wrt]
   1404 return rval

File ~\anaconda3\Lib\site-packages\pytensor\gradient.py:1357, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1355 for node in node_to_idx:
   1356     for idx in node_to_idx[node]:
-> 1357         term = access_term_cache(node)[idx]
   1359         if not isinstance(term, Variable):
   1360             raise TypeError(
   1361                 f"{node.op}.grad returned {type(term)}, expected"
   1362                 " Variable instance."
   1363             )

File ~\anaconda3\Lib\site-packages\pytensor\gradient.py:1032, in _populate_grad_dict.<locals>.access_term_cache(node)
   1029 if node not in term_dict:
   1030     inputs = node.inputs
-> 1032     output_grads = [access_grad_cache(var) for var in node.outputs]
   1034     # list of bools indicating if each output is connected to the cost
   1035     outputs_connected = [
   1036         not isinstance(g.type, DisconnectedType) for g in output_grads
   1037     ]

File ~\anaconda3\Lib\site-packages\pytensor\gradient.py:1357, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1355 for node in node_to_idx:
   1356     for idx in node_to_idx[node]:
-> 1357         term = access_term_cache(node)[idx]
   1359         if not isinstance(term, Variable):
   1360             raise TypeError(
   1361                 f"{node.op}.grad returned {type(term)}, expected"
   1362                 " Variable instance."
   1363             )

File ~\anaconda3\Lib\site-packages\pytensor\gradient.py:1032, in _populate_grad_dict.<locals>.access_term_cache(node)
   1029 if node not in term_dict:
   1030     inputs = node.inputs
-> 1032     output_grads = [access_grad_cache(var) for var in node.outputs]
   1034     # list of bools indicating if each output is connected to the cost
   1035     outputs_connected = [
   1036         not isinstance(g.type, DisconnectedType) for g in output_grads
   1037     ]

File ~\anaconda3\Lib\site-packages\pytensor\gradient.py:1357, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1355 for node in node_to_idx:
   1356     for idx in node_to_idx[node]:
-> 1357         term = access_term_cache(node)[idx]
   1359         if not isinstance(term, Variable):
   1360             raise TypeError(
   1361                 f"{node.op}.grad returned {type(term)}, expected"
   1362                 " Variable instance."
   1363             )

File ~\anaconda3\Lib\site-packages\pytensor\gradient.py:1187, in _populate_grad_dict.<locals>.access_term_cache(node)
   1179         if o_shape != g_shape:
   1180             raise ValueError(
   1181                 "Got a gradient of shape "
   1182                 + str(o_shape)
   1183                 + " on an output of shape "
   1184                 + str(g_shape)
   1185             )
-> 1187 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
   1189 if input_grads is None:
   1190     raise TypeError(
   1191         f"{node.op}.grad returned NoneType, expected iterable."
   1192     )

File ~\anaconda3\Lib\site-packages\pytensor\graph\op.py:398, in Op.L_op(self, inputs, outputs, output_grads)
    371 def L_op(
    372     self,
    373     inputs: Sequence[Variable],
    374     outputs: Sequence[Variable],
    375     output_grads: Sequence[Variable],
    376 ) -> list[Variable]:
    377     r"""Construct a graph for the L-operator.
    378 
    379     The L-operator computes a row vector times the Jacobian.
   (...)
    396 
    397     """
--> 398     return self.grad(inputs, output_grads)

File ~\anaconda3\Lib\site-packages\pytensor\tensor\subtensor.py:1897, in IncSubtensor.grad(self, inputs, grads)
   1895         gx = g_output
   1896     gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
-> 1897     gy = _sum_grad_over_bcasted_dims(y, gy)
   1899 return [gx, gy] + [DisconnectedType()()] * len(idx_list)

File ~\anaconda3\Lib\site-packages\pytensor\tensor\subtensor.py:1933, in _sum_grad_over_bcasted_dims(x, gx)
   1931 x_dim_added = gx.ndim - x.ndim
   1932 x_broad = (True,) * x_dim_added + x.broadcastable
-> 1933 assert sum(gx.broadcastable) <= sum(x_broad)
   1934 axis_to_sum = []
   1935 for i in range(gx.ndim):

AssertionError: 

I followed the examples very straightforwardly, so I'm not sure exactly what's going wrong. Thanks in advance for your help.