
PyMB Example

Import Module

In [1]:

import sys
sys.path.append('../..')
import PyMB
from PyMB import magic  # enable %%PyMB cell magic
import numpy as np

Define the model

$$Y \sim \mathcal{N}(\hat{Y}, \sigma), \qquad \hat{Y} = \alpha + B x$$
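For reference, the negative log-likelihood this model implies can be written out in plain NumPy. This is a sketch for checking the formula by hand, not part of PyMB; neg_log_lik is a hypothetical helper name. The standard deviation is parameterized on the log scale to mirror logSigma in the TMB template, so any real value of log_sigma gives a positive sigma.

```python
import numpy as np

def neg_log_lik(alpha, beta, log_sigma, x, y):
    """Hypothetical helper: -sum(log N(y | alpha + beta * x, sigma)).

    sigma = exp(log_sigma), mirroring the logSigma parameter, so the
    standard deviation stays positive for any real log_sigma.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    sigma = np.exp(log_sigma)
    y_hat = alpha + beta * x          # linear predictor
    z = (y - y_hat) / sigma           # standardized residuals
    # Each normal log-density term, negated and summed
    return np.sum(0.5 * z**2 + np.log(sigma) + 0.5 * np.log(2 * np.pi))
```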

In [2]:

%%PyMB LinearRegression

// DATA
DATA_VECTOR(Y);
DATA_VECTOR(x);

// PARAMETERS
PARAMETER(alpha);
PARAMETER(Beta);
PARAMETER(logSigma);

// MODEL
vector<Type> Y_hat = alpha + Beta*x;
REPORT(Y_hat);
Type nll = -sum(dnorm(Y, Y_hat, exp(logSigma), true));
return nll;

Created model LinearRegression. Using tmb_tmp/LinearRegression.cpp. Compiled in 24.3s.

Simulate data

In [3]:

LinearRegression.data = {
    'x': np.arange(10),
    'Y': np.random.normal(np.arange(10))  # Y_i ~ N(i, 1): default scale is 1
}
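np.random.normal(np.arange(10)) draws each Y_i from a normal distribution with mean i and the default standard deviation of 1, so the data are generated with alpha = 0, Beta = 1 and logSigma = 0. Because the cell is unseeded, rerunning the notebook gives numbers different from the outputs shown below. A seeded variant using NumPy's Generator API is sketched here; the seed 42 and the name true_params are arbitrary choices for illustration.

```python
import numpy as np

# Seeded variant of the simulation cell: same data-generating process,
# but reproducible across runs (seed 42 is an arbitrary choice).
rng = np.random.default_rng(42)
x = np.arange(10)
Y = rng.normal(loc=x, scale=1.0)  # Y_i ~ N(i, 1)

# Parameter values implied by this process (sigma = exp(0) = 1)
true_params = {'alpha': 0.0, 'Beta': 1.0, 'logSigma': 0.0}
```

The arrays could then be passed in the same way as above, with LinearRegression.data = {'x': x, 'Y': Y}.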

Set initial parameter values

In [4]:

LinearRegression.init = { 'alpha': 0., 'Beta': 0., 'logSigma': 0. }

Fit the model

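The likelihood is integrated with respect to the random parameters (here alpha and Beta) using TMB's Laplace approximation, and the remaining fixed parameter, logSigma, is estimated by maximizing the resulting marginal likelihood.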
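To make the two-level procedure concrete, here is a NumPy sketch of a Laplace-approximated marginal likelihood for this particular model. It is an illustration under stated assumptions, not PyMB's or TMB's implementation: laplace_marginal_nll is a hypothetical name, the data are made up, and a grid search stands in for a real optimizer. Because the negative log-likelihood is quadratic in alpha and Beta, the inner mode is the least-squares fit, the Hessian is X^T X / sigma^2, and the Laplace approximation is exact.

```python
import numpy as np

def laplace_marginal_nll(log_sigma, x, y):
    """Hypothetical helper: -log integral of exp(-nll) over (alpha, beta).

    Laplace: -log integral exp(-f(u)) du ~= f(u_hat) + 0.5*log det(H) - (k/2)*log(2*pi),
    with u_hat the minimizer of f and H the Hessian of f at u_hat.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(y)
    X = np.column_stack([np.ones_like(x), x])  # design matrix for (alpha, beta)
    k = X.shape[1]
    sigma2 = np.exp(2.0 * log_sigma)

    # Inner step: the mode in (alpha, beta) is the least-squares fit
    u_hat, *_ = np.linalg.lstsq(X, y, rcond=None)
    rss = np.sum((y - X @ u_hat) ** 2)
    f_hat = 0.5 * rss / sigma2 + n * log_sigma + 0.5 * n * np.log(2 * np.pi)

    # Hessian of the nll in (alpha, beta); constant because the nll is quadratic
    _, logdet = np.linalg.slogdet(X.T @ X / sigma2)
    return f_hat + 0.5 * logdet - 0.5 * k * np.log(2 * np.pi)

# Outer step: minimize over log_sigma (a fine grid stands in for an optimizer)
x = np.arange(10.0)
y = np.array([0.3, 1.1, 1.7, 3.4, 3.9, 5.2, 5.8, 7.3, 7.6, 9.4])  # made-up data
grid = np.linspace(-2.0, 2.0, 4001)
best = grid[np.argmin([laplace_marginal_nll(s, x, y) for s in grid])]
print(best)
```

Setting the derivative with respect to log_sigma to zero shows that, for this model, the optimum is sigma^2 = RSS / (n - 2), the REML-type estimate, rather than the maximum-likelihood RSS / n: integrating out alpha and Beta accounts for the two degrees of freedom they use.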

In [5]:

LinearRegression.optimize(random=['alpha','Beta'])

Matching hessian patterns... Done

Model optimization complete in 0.1s.


Simulated 100 draws in 0.2s.

alpha:
  mean [ 0.31609895]
  sd [ 0.73192141]
  draws [[-0.72558336 -0.71637844 ..., 1.25651001 -0.43803027]]
  shape (1, 100)
Beta:
  mean [ 0.88899844]
  sd [ 0.13710144]
  draws [[ 1.09499943 1.00378394 ..., 0.55369137 0.89730876]]
  shape (1, 100)

In [6]:

print(LinearRegression.report('Y_hat'))

[ 0.31609895 1.20509738 2.09409582 2.98309425 3.87209269 4.76109113 5.65008956 6.539088 7.42808643 8.31708487]
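REPORT(Y_hat) in the template exposes the fitted line, so the printed values should equal alpha + Beta * x evaluated at the means reported by optimize(). A quick check in plain NumPy, with the numbers copied from the outputs above (no PyMB needed):

```python
import numpy as np

# Point estimates copied from the optimize() output above
alpha_hat = 0.31609895
beta_hat = 0.88899844
x = np.arange(10)

# Y_hat = alpha + Beta * x, evaluated at the estimates
y_hat = alpha_hat + beta_hat * x

# Values printed by LinearRegression.report('Y_hat')
reported = np.array([0.31609895, 1.20509738, 2.09409582, 2.98309425, 3.87209269,
                     4.76109113, 5.65008956, 6.539088, 7.42808643, 8.31708487])

print(np.allclose(y_hat, reported, atol=1e-6))  # True
```

The remaining differences, of order 1e-8, come from the rounding of the printed digits.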

Examine joint density

In [7]:

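%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# One column per parameter, holding its simulated draws
df = pd.DataFrame({
    k: v['draws'][0] for k, v in LinearRegression.parameters.items()
})

g = sns.PairGrid(df, diag_sharey=False)
g.map_lower(sns.kdeplot, cmap='Blues_d')  # bivariate density below the diagonal
g.map_upper(plt.scatter)                  # raw draws above the diagonal
g.map_diag(sns.kdeplot, lw=3)             # marginal densities on the diagonal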

/usr/bin/anaconda/lib/python2.7/site-packages/matplotlib/axes/_axes.py:476: UserWarning: No labelled objects found. Use label='...' kwarg on individual plots. warnings.warn("No labelled objects found. "
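The shape of the joint density can be anticipated analytically. Assuming the reported sd values come from a Gaussian approximation at the optimum (an assumption about PyMB's internals, not something the notebook states), the covariance of (alpha, Beta) is sigma^2 times the inverse of X^T X for x = 0 to 9. The unknown sigma cancels in correlations and sd ratios, so both can be computed from the design matrix alone. The sketch below predicts a strong negative correlation, which would show up as a tilted ellipse in the alpha/Beta panels, and an sd ratio that matches the one implied by the reported sds.

```python
import numpy as np

x = np.arange(10.0)
X = np.column_stack([np.ones_like(x), x])  # columns for alpha and Beta

# Cov(alpha, Beta) is proportional to inv(X^T X); sigma^2 cancels in ratios
XtX_inv = np.linalg.inv(X.T @ X)
sd_unscaled = np.sqrt(np.diag(XtX_inv))
corr = XtX_inv[0, 1] / (sd_unscaled[0] * sd_unscaled[1])

sd_ratio = sd_unscaled[0] / sd_unscaled[1]   # predicted sd(alpha) / sd(Beta)
reported_ratio = 0.73192141 / 0.13710144     # from the optimize() output

print(round(corr, 3))                               # -0.843
print(round(sd_ratio, 3), round(reported_ratio, 3))  # 5.339 5.339
```

Intercept and slope are negatively correlated here because every x is non-negative: a steeper fitted line must start lower to pass through the middle of the data.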