GitHub - danielward27/paramax: A small library of parameterizations and parameter constraints for PyTrees.

Paramax

Parameterizations and constraints for JAX PyTrees

Paramax allows applying custom constraints or behaviors to PyTree components, using unwrappable placeholders. This can be used for

Some benefits of the unwrappable pattern:

If you find the package useful, please consider giving it a star on GitHub. If you create AbstractUnwrappables that may be of interest to others, a pull request would be much appreciated!

Documentation

Documentation available here.

Installation

Example

    import paramax
    import jax.numpy as jnp

    scale = paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3)))  # Enforce positivity
    paramax.unwrap(("abc", 1, scale))
    # ('abc', 1, Array([1., 1., 1.], dtype=float32))

Alternative parameterization patterns

Using properties to access parameterized model components is common but has drawbacks: