GitHub - danielward27/paramax: A small library of parameterizations and parameter constraints for PyTrees.
Paramax
Parameterizations and constraints for JAX PyTrees
Paramax allows applying custom constraints or behaviors to PyTree components, using unwrappable placeholders. This can be used for:
- Enforcing positivity (e.g., scale parameters)
- Structured matrices (triangular, symmetric, etc.)
- Applying tricks like weight normalization
- Marking components as non-trainable
Some benefits of the unwrappable pattern:
- It allows parameterizations to be computed once for a model (e.g. at the top of the loss function).
- It is flexible, e.g. allowing custom parameterizations to be applied to PyTrees from external libraries
- It is concise
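To make the pattern concrete, here is a toy, pure-Python sketch of the idea behind unwrappable placeholders. It is not Paramax's implementation: `ToyParameterize`, `toy_unwrap`, and `exp_all` are hypothetical names invented for this illustration, and the sketch only handles lists, tuples, and dicts rather than general JAX PyTrees.

```python
import math
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyParameterize:
    """Hypothetical placeholder: stores a function and its unconstrained input."""

    fn: Callable[[Any], Any]
    raw: Any

    def unwrap(self) -> Any:
        # Apply the stored transformation to recover the constrained value.
        return self.fn(self.raw)


def toy_unwrap(tree: Any) -> Any:
    """Recursively replace every ToyParameterize inside a nested container."""
    if isinstance(tree, ToyParameterize):
        # Unwrap the result as well, so nested placeholders are resolved.
        return toy_unwrap(tree.unwrap())
    if isinstance(tree, (list, tuple)):
        return type(tree)(toy_unwrap(item) for item in tree)
    if isinstance(tree, dict):
        return {key: toy_unwrap(value) for key, value in tree.items()}
    return tree  # Leaves (numbers, strings, ...) pass through unchanged.


def exp_all(values):
    # Element-wise exp, standing in for jnp.exp in this stdlib-only sketch.
    return [math.exp(v) for v in values]


# Store unconstrained values; the constrained scale is exp(raw) > 0.
scale = ToyParameterize(exp_all, [0.0, 0.0, 0.0])
params = ("abc", 1, scale)


def loss(params):
    # Unwrap once at the top, then work with plain values below.
    _, offset, scale_values = toy_unwrap(params)
    return sum((s - offset) ** 2 for s in scale_values)


unwrapped = toy_unwrap(params)
```

Because the placeholder is just another node in the tree, the same approach can in principle wrap components of structures defined elsewhere, which is the flexibility point made in the list above.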
If you found the package useful, please consider giving it a star on GitHub, and if you create AbstractUnwrappables that may be of interest to others, a pull request would be much appreciated!
Documentation
Documentation available here.
Installation
Example
>>> import paramax
>>> import jax.numpy as jnp
>>> scale = paramax.Parameterize(jnp.exp, jnp.log(jnp.ones(3)))  # Enforce positivity
>>> paramax.unwrap(("abc", 1, scale))
('abc', 1, Array([1., 1., 1.], dtype=float32))
Alternative parameterization patterns
Using properties to access parameterized model components is common but has drawbacks:
- Parameterizations are tied to the class definition, which limits flexibility (e.g. they cannot be applied to PyTrees from external libraries)
- It can become verbose with many parameters
- It often leads to repeatedly computing the parameterization
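The last drawback can be seen in a small sketch of the property-based pattern. This is an illustrative example only: `ScaleModel`, its `n_evaluations` counter, and the `loss` function are hypothetical and exist purely to count how often the transformation runs.

```python
import math
from typing import Sequence


class ScaleModel:
    """Hypothetical model exposing a positive scale through a property."""

    def __init__(self, log_scale: float):
        self._log_scale = log_scale  # Unconstrained stored value.
        self.n_evaluations = 0  # Counts how often the transform runs.

    @property
    def scale(self) -> float:
        # The transformation is re-run on every attribute access.
        self.n_evaluations += 1
        return math.exp(self._log_scale)


def loss(model: ScaleModel, data: Sequence[float]) -> float:
    # Each `model.scale` below triggers another exp evaluation.
    return sum((x / model.scale) ** 2 for x in data) + math.log(model.scale)


model = ScaleModel(log_scale=0.0)
value = loss(model, [1.0, 2.0, 3.0])
print(model.n_evaluations)  # One evaluation per access to `model.scale`.
```

Here the exponential is evaluated once per data point plus once for the log term. Computing the constrained value a single time at the top of the loss function, as the unwrappable pattern encourages, avoids this repetition.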