Other NUTS Samplers

In this notebook we show how to fit a CLV model with NUTS samplers other than PyMC's default one. These alternative samplers can be significantly faster, and some of them can also sample on a GPU.

Note

You need to install the corresponding packages (numpyro, blackjax, and nutpie) in your Python environment before using these samplers.
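Before choosing a sampler, it can help to check which of these optional packages are importable in the current environment. The sketch below uses only the standard library. The helper name available_nuts_samplers is hypothetical, and it assumes that the import names numpyro, blackjax, and nutpie match the installed package names.

```python
import importlib.util

# Optional NUTS backends that can be selected through nuts_sampler.
# "pymc" (the default sampler) is assumed to be present once PyMC is installed.
OPTIONAL_SAMPLERS = ("numpyro", "blackjax", "nutpie")


def available_nuts_samplers() -> list[str]:
    """Return the nuts_sampler values whose packages can be imported (hypothetical helper)."""
    available = ["pymc"]
    for name in OPTIONAL_SAMPLERS:
        # find_spec returns None when the top-level package is not installed
        if importlib.util.find_spec(name) is not None:
            available.append(name)
    return available


print(available_nuts_samplers())
```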

Tip

You can pass the exact same nuts_sampler argument to the MMM models.

Tip

GPU support in PyMC only works with samplers that use the JAX backend, such as numpyro, blackjax, and nutpie.

Make sure the GPU is registered by following the instructions here.
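A quick way to confirm the setup is to ask JAX which devices it can see. The following is a minimal sketch, assuming that a registered CUDA GPU shows up in jax.devices() with platform "gpu". The helper name jax_sees_gpu is hypothetical. It returns False instead of failing when JAX is not installed.

```python
import importlib.util


def jax_sees_gpu() -> bool:
    """Return True if JAX is installed and lists at least one GPU device (hypothetical helper)."""
    if importlib.util.find_spec("jax") is None:
        # JAX is not installed, so the JAX-based samplers cannot use a GPU
        return False
    import jax

    # jax.devices() lists the devices of the default backend; when a GPU is
    # registered, the default backend is the GPU and device.platform == "gpu"
    return any(device.platform == "gpu" for device in jax.devices())


print(jax_sees_gpu())
```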

For illustration, we use the same data (the CDNOW summary dataset) and model (BetaGeoModel) as in the other CLV notebooks.

import arviz as az
import matplotlib.pyplot as plt
from lifetimes.datasets import load_cdnow_summary

from pymc_marketing import clv

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"

df = (
    load_cdnow_summary(index_col=[0])
    .reset_index()
    .rename(columns={"ID": "customer_id"})
)
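The rename is presumably needed because the model expects a customer_id column alongside the frequency, recency, and T columns of the CDNOW summary. Below is a quick sanity check, assuming those four columns are what BetaGeoModel requires for this data. The helper missing_clv_columns is hypothetical and is not part of pymc_marketing.

```python
# Columns assumed to be required by BetaGeoModel for this summary data
REQUIRED_COLUMNS = {"customer_id", "frequency", "recency", "T"}


def missing_clv_columns(columns) -> set[str]:
    """Return the assumed-required columns absent from `columns` (hypothetical helper)."""
    return REQUIRED_COLUMNS - set(columns)


# With the notebook's data frame this would be: missing_clv_columns(df.columns)
print(missing_clv_columns(["customer_id", "frequency", "recency", "T"]))  # set()
```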

We can pass the keyword argument nuts_sampler to the fit method of the CLV model to specify which NUTS sampler to use. Any other keyword arguments are forwarded to pymc.sample through the model builder layer. For example, to use the numpyro sampler:

sampler_kwargs = {
    "draws": 2_000,
    "target_accept": 0.9,
    "chains": 5,
    "random_seed": 42,
}

model = clv.BetaGeoModel(data=df)
idata_numpyro = model.fit(nuts_sampler="numpyro", **sampler_kwargs)
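Conceptually, fit combines nuts_sampler with the remaining keyword arguments and passes them on to pymc.sample. The sketch below mimics that merge with a plain function. The function build_sample_kwargs is hypothetical and is not the pymc_marketing implementation. The sketch also makes the arithmetic explicit. With 5 chains of 2,000 draws each, every sampler returns 10,000 posterior draws per parameter, because pymc.sample discards the tuning draws by default.

```python
def build_sample_kwargs(nuts_sampler: str = "pymc", **kwargs) -> dict:
    """Merge the sampler choice with user kwargs (hypothetical, simplified)."""
    valid = {"pymc", "numpyro", "blackjax", "nutpie"}
    if nuts_sampler not in valid:
        raise ValueError(f"Unknown nuts_sampler: {nuts_sampler!r}")
    # Everything else (draws, chains, target_accept, ...) is forwarded unchanged
    return {"nuts_sampler": nuts_sampler, **kwargs}


sampler_kwargs = {"draws": 2_000, "target_accept": 0.9, "chains": 5, "random_seed": 42}
sample_kwargs = build_sample_kwargs(nuts_sampler="numpyro", **sampler_kwargs)

# Posterior draws kept per parameter, summed over all chains
total_draws = sample_kwargs["draws"] * sample_kwargs["chains"]
print(total_draws)  # 10000
```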

Similarly, we can use the blackjax sampler as:

idata_blackjax = model.fit(nuts_sampler="blackjax", **sampler_kwargs)

Finally, we can use nutpie, a Rust implementation of NUTS:

idata_nutpie = model.fit(nuts_sampler="nutpie", **sampler_kwargs)
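Speed is the main motivation for these samplers, so a simple way to compare them is to time each fit call. The helper below is a hypothetical sketch built on time.perf_counter. It times any zero-argument callable, so it can wrap model.fit without assuming anything else about the model. Keep in mind that the first call with a given sampler may include compilation time.

```python
import time
from typing import Any, Callable


def time_call(fn: Callable[[], Any]) -> tuple[Any, float]:
    """Run fn once and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn()
    elapsed = time.perf_counter() - start
    return result, elapsed


# Usage with the notebook's objects (not executed here):
# for name in ["numpyro", "blackjax", "nutpie"]:
#     _, seconds = time_call(lambda: model.fit(nuts_sampler=name, **sampler_kwargs))
#     print(f"{name}: {seconds:.1f} s")

# Self-contained demo with a stand-in workload
value, seconds = time_call(lambda: sum(range(1_000_000)))
print(value, seconds >= 0.0)  # 499999500000 True
```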

The posterior distributions obtained with the three samplers are almost identical:


fig, axes = plt.subplots(
    nrows=2, ncols=2, figsize=(12, 8), sharex=False, sharey=False, layout="constrained"
)

# Flatten the 2x2 grid so each parameter gets its own axis
axes = axes.ravel()

# Overlay the posterior of each parameter for the three samplers
for i, var_name in enumerate(["a", "b", "alpha", "r"]):
    for j, (idata, label) in enumerate(
        zip(
            [idata_blackjax, idata_nutpie, idata_numpyro],
            ["blackjax", "nutpie", "numpyro"],
            strict=False,
        )
    ):
        az.plot_posterior(
            data=idata,
            var_names=[var_name],
            color=f"C{j}",
            point_estimate=None,
            hdi_prob="hide",
            label=label,
            ax=axes[i],
        )

fig.suptitle(
    "Posterior distributions of model parameters",
    fontsize=18,
    fontweight="bold",
    y=1.05,
);

Figure: posterior distributions of the parameters a, b, alpha, and r, overlaid for the blackjax, nutpie, and numpyro samplers.
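You can also compare the samplers numerically. A minimal sketch follows. It assumes that the draws for one parameter have already been flattened into plain lists, for example via idata.posterior["r"].values.ravel().tolist(). The helper name and the standardized-gap criterion are illustrative choices and are not part of pymc_marketing.

```python
import statistics


def max_standardized_mean_gap(draws_by_sampler: dict[str, list[float]]) -> float:
    """Largest gap between per-sampler posterior means, in units of the pooled sd."""
    means = [statistics.fmean(draws) for draws in draws_by_sampler.values()]
    # Pool all draws to get a common scale for the comparison
    pooled = [x for draws in draws_by_sampler.values() for x in draws]
    scale = statistics.stdev(pooled)
    return (max(means) - min(means)) / scale


# Toy draws for illustration only (not results from this notebook)
toy_draws = {"numpyro": [1.0, 2.0, 3.0], "nutpie": [1.1, 2.1, 3.1]}
gap = max_standardized_mean_gap(toy_draws)
print(round(gap, 3))
```

A small value means the samplers agree closely relative to the posterior spread.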

%load_ext watermark
%watermark -n -u -v -iv -w -p blackjax,numpyro,nutpie,pymc

Last updated: Sat Mar 09 2024

Python implementation: CPython
Python version       : 3.11.3
IPython version      : 8.20.0

blackjax: 0.0.0
numpyro : 0.14.0
nutpie  : 0.9.2
pymc    : 5.10.4

arviz         : 0.15.1
matplotlib    : 3.7.1
pymc_marketing: 0.4.0

Watermark: 2.4.3