
Haiku: Sonnet for JAX

Overview | Why Haiku? | Quickstart | Installation | Examples | User manual | Documentation | Citing Haiku


Important

📣 As of July 2023 Google DeepMind recommends that new projects adopt Flax instead of Haiku. Flax is a neural network library originally developed by Google Brain and now by Google DeepMind. 📣

At the time of writing, Flax has a superset of the features available in Haiku, a larger and more active development team, and more adoption among users outside of Alphabet. Flax has more extensive documentation, examples, and an active community creating end-to-end examples.

Haiku will remain supported on a best-effort basis; however, the project will enter maintenance mode, meaning that development efforts will be focused on bug fixes and compatibility with new releases of JAX.

New releases will be made to keep Haiku working with newer versions of Python and JAX; however, we will not be adding (or accepting PRs for) new features.

We have significant usage of Haiku internally at Google DeepMind and currently plan to support Haiku in this mode indefinitely.

What is Haiku?

Haiku is a tool
For building neural networks
Think: "Sonnet for JAX"

Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet, a neural network library for TensorFlow.

Documentation on Haiku can be found at https://dm-haiku.readthedocs.io/.

Disambiguation: if you are looking for Haiku the operating system then please see https://haiku-os.org/.

Overview

JAX is a numerical computing library that combines NumPy, automatic differentiation, and first-class GPU/TPU support.

Haiku is a simple neural network library for JAX that enables users to use familiar object-oriented programming models while allowing full access to JAX's pure function transformations.

Haiku provides two core tools: a module abstraction, hk.Module, and a simple function transformation, hk.transform.

hk.Modules are Python objects that hold references to their own parameters, other modules, and methods that apply functions on user inputs.

hk.transform turns functions that use these object-oriented, functionally "impure" modules into pure functions that can be used with jax.jit, jax.grad, jax.pmap, etc.
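
As a minimal sketch of how these two pieces fit together (the tiny function, input shape, and output size below are made up purely for illustration), an "impure" function that builds a module inside its body becomes a pure init/apply pair:

```python
import jax
import jax.numpy as jnp
import haiku as hk

# An "impure" function: hk.Linear creates and reads parameters internally.
def f(x):
  return hk.Linear(output_size=2)(x)

f_t = hk.transform(f)  # -> a pure (init, apply) pair

# init collects the parameters; apply runs the function with them injected.
params = f_t.init(jax.random.PRNGKey(0), jnp.ones([1, 3]))
out = f_t.apply(params, None, jnp.ones([1, 3]))  # rng may be None: no randomness at apply time
```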

Why Haiku?

There are a number of neural network libraries for JAX. Why should you choose Haiku?

Haiku has been tested by researchers at DeepMind at scale.

Haiku is a library, not a framework.

Haiku does not reinvent the wheel.

Transitioning to Haiku is easy.

Haiku makes other aspects of JAX simpler.

Quickstart

Let's take a look at an example neural network, loss function, and training loop. (For more examples, see our examples directory. The MNIST example is a good place to start.)

```python
import haiku as hk
import jax
import jax.numpy as jnp

def softmax_cross_entropy(logits, labels):
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)

def loss_fn(images, labels):
  mlp = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  logits = mlp(images)
  return jnp.mean(softmax_cross_entropy(logits, labels))

loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)

rng = jax.random.PRNGKey(42)
dummy_images, dummy_labels = next(input_dataset)
params = loss_fn_t.init(rng, dummy_images, dummy_labels)

def update_rule(param, update):
  return param - 0.01 * update

for images, labels in input_dataset:
  grads = jax.grad(loss_fn_t.apply)(params, images, labels)
  params = jax.tree.map(update_rule, params, grads)
```

The core of Haiku is hk.transform. The transform function allows you to write neural network functions that rely on parameters (here the weights of the Linear layers) without requiring you to explicitly write the boilerplate for initialising those parameters. transform does this by transforming the function into a pair of functions that are pure (as required by JAX): init and apply.

init

The init function, with signature params = init(rng, ...) (where ... are the arguments to the untransformed function), allows you to collect the initial value of any parameters in the network. Haiku does this by running your function, keeping track of any parameters requested through hk.get_parameter (called by e.g. hk.Linear) and returning them to you.

The params object returned is a nested data structure of all the parameters in your network, designed for you to inspect and manipulate. Concretely, it is a mapping of module name to module parameters, where each module's parameters are in turn a mapping of parameter name to parameter value. For example:

{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
            'w': ndarray(..., shape=(28, 300), dtype=float32)},
 'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
              'w': ndarray(..., shape=(300, 100), dtype=float32)},
 'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
              'w': ndarray(..., shape=(100, 10), dtype=float32)}}

apply

The apply function, with signature result = apply(params, rng, ...), allows you to inject parameter values into your function. Whenever hk.get_parameter is called, the value returned will come from the params you provide as input to apply:

loss = loss_fn_t.apply(params, rng, images, labels)

Note that since the actual computation performed by our loss function doesn't rely on random numbers, passing in a random number generator is unnecessary, so we could also pass in None for the rng argument. (Note that if your computation does use random numbers, passing in None for rng will cause an error to be raised.) In our example above, we ask Haiku to do this for us automatically with:

loss_fn_t = hk.without_apply_rng(loss_fn_t)
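
Once wrapped with hk.without_apply_rng, the returned apply no longer takes an rng argument at all, so the call above simplifies to:

loss = loss_fn_t.apply(params, images, labels)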

Since apply is a pure function we can pass it to jax.grad (or any of JAX's other transforms):

grads = jax.grad(loss_fn_t.apply)(params, images, labels)

Training

The training loop in this example is very simple. One detail to note is the use of jax.tree.map to apply the update_rule function across all matching entries in params and grads. The result has the same structure as the previous params and can again be used with apply.
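
As a toy illustration of what jax.tree.map does here (with hypothetical nested dicts standing in for real parameter and gradient trees), it applies the update rule leaf by leaf across two structures with the same layout:

```python
import jax

# Hypothetical stand-ins for params and grads with matching structure.
params_demo = {'linear': {'w': 1.0, 'b': 2.0}}
grads_demo = {'linear': {'w': 0.5, 'b': 0.1}}

new_params = jax.tree.map(lambda p, g: p - 0.01 * g, params_demo, grads_demo)
# -> {'linear': {'b': 1.999, 'w': 0.995}} (up to float rounding)
```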

Installation

Haiku is written in pure Python, but depends on C++ code via JAX.

Because JAX installation is different depending on your CUDA version, Haiku does not list JAX as a dependency in requirements.txt.

First, follow these instructions to install JAX with the relevant accelerator support.
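
For example, a CPU-only installation is just the plain package (the exact command for GPU/TPU builds depends on your CUDA version and platform, so follow the JAX instructions above for those):

$ pip install -U jax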

Then, install Haiku using pip:

$ pip install git+https://github.com/deepmind/dm-haiku

Alternatively, you can install via PyPI:

$ pip install -U dm-haiku

Our examples rely on additional libraries (e.g. bsuite). You can install the full set of additional requirements using pip:

$ pip install -r examples/requirements.txt

User manual

Writing your own modules

In Haiku, all modules are a subclass of hk.Module. You can implement any method you like (nothing is special-cased), but typically modules implement __init__ and __call__.

Let's work through implementing a linear layer:

```python
import numpy as np  # used below for np.sqrt

class MyLinear(hk.Module):

  def __init__(self, output_size, name=None):
    super().__init__(name=name)
    self.output_size = output_size

  def __call__(self, x):
    j, k = x.shape[-1], self.output_size
    w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
    w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
    b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
    return jnp.dot(x, w) + b
```

All modules have a name. When no name argument is passed to the module, its name is inferred from the name of the Python class (for example MyLinear becomes my_linear). Modules can have named parameters that are accessed using hk.get_parameter(param_name, ...). We use this API (rather than just using object properties) so that we can convert your code into a pure function using hk.transform.

When using modules you need to define functions and transform them into a pair of pure functions using hk.transform. See our quickstart for more details about the functions returned from transform:

```python
def forward_fn(x):
  model = MyLinear(10)
  return model(x)

# Turn forward_fn into an object with init and apply methods. By default,
# the apply will require an rng (which can be None), to be used with
# hk.next_rng_key.
forward = hk.transform(forward_fn)

x = jnp.ones([1, 1])

# When we run forward.init, Haiku will run forward_fn(x) and collect initial
# parameter values. Haiku requires you pass a RNG key to init, since parameters
# are typically initialized randomly:
key = hk.PRNGSequence(42)
params = forward.init(next(key), x)

# When we run forward.apply, Haiku will run forward_fn(x) and inject parameter
# values from the params that are passed as the first argument. Note that
# models transformed using hk.transform(f) must be called with an additional
# rng argument: forward.apply(params, rng, x). Use
# hk.without_apply_rng(hk.transform(f)) if this is undesirable.
y = forward.apply(params, None, x)
```
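
To see how the inferred module name shows up in the returned params, you can map each parameter to its shape (a small sketch; the shapes follow from the x above):

```python
# Each parameter created inside MyLinear is keyed by the inferred module name.
print(jax.tree.map(lambda p: p.shape, params))
# e.g. {'my_linear': {'b': (10,), 'w': (1, 10)}}
```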

Working with stochastic models

Some models may require random sampling as part of the computation. For example, in variational autoencoders with the reparametrization trick, a random sample from the standard normal distribution is needed. For dropout we need a random mask to drop units from the input. The main hurdle in making this work with JAX is in management of PRNG keys.

In Haiku we provide a simple API for maintaining a PRNG key sequence associated with modules: hk.next_rng_key() (or next_rng_keys() for multiple keys):

```python
class MyDropout(hk.Module):

  def __init__(self, rate=0.5, name=None):
    super().__init__(name=name)
    self.rate = rate

  def __call__(self, x):
    key = hk.next_rng_key()
    p = jax.random.bernoulli(key, 1.0 - self.rate, shape=x.shape)
    return x * p / (1.0 - self.rate)

forward = hk.transform(lambda x: MyDropout()(x))

key1, key2 = jax.random.split(jax.random.PRNGKey(42), 2)
params = forward.init(key1, x)
prediction = forward.apply(params, key2, x)
```

For a more complete look at working with stochastic models, please see our VAE example.

Note: hk.next_rng_key() is not functionally pure, which means you should avoid using it alongside JAX transformations which are inside hk.transform. For more information and possible workarounds, please consult the docs on Haiku transforms and available wrappers for JAX transforms inside Haiku networks.

Working with non-trainable state

Some models may want to maintain some internal, mutable state. For example, in batch normalization a moving average of values encountered during training is maintained.

In Haiku we provide a simple API for maintaining mutable state that is associated with modules: hk.set_state and hk.get_state. When using these functions you need to transform your function using hk.transform_with_state since the signature of the returned pair of functions is different:

```python
def forward(x, is_training):
  net = hk.nets.ResNet50(1000)
  return net(x, is_training)

forward = hk.transform_with_state(forward)

# The init function now returns parameters and state. State contains
# anything that was created using hk.set_state. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params, state = forward.init(rng, x, is_training=True)

# The apply function now takes both params and state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits, state = forward.apply(params, state, rng, x, is_training=True)
```

If you forget to use hk.transform_with_state, don't worry: we will print a clear error pointing you to hk.transform_with_state rather than silently dropping your state.
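
As an illustration of the state API itself, here is a minimal, hypothetical module (not part of the library, just a sketch of hk.get_state and hk.set_state) that keeps a call counter in Haiku state:

```python
class CallCounter(hk.Module):
  """Toy module: counts how many times it has been applied."""

  def __call__(self, x):
    count = hk.get_state("count", shape=[], dtype=jnp.int32, init=jnp.zeros)
    hk.set_state("count", count + 1)
    return x, count

forward = hk.transform_with_state(lambda x: CallCounter()(x))

# init returns the initial state (count == 0); apply returns the updated state.
params, state = forward.init(jax.random.PRNGKey(0), jnp.ones([2]))
(out, count), state = forward.apply(params, state, None, jnp.ones([2]))
```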

Distributed training with jax.pmap

The pure functions returned from hk.transform (or hk.transform_with_state) are fully compatible with jax.pmap. For more details on SPMD programming with jax.pmap, look here.

One common use of jax.pmap with Haiku is for data-parallel training on many accelerators, potentially across multiple hosts. With Haiku, that might look like this:

```python
import numpy as np  # used below to stack host-side batches

def loss_fn(inputs, labels):
  logits = hk.nets.MLP([8, 4, 2])(inputs)
  return jnp.mean(softmax_cross_entropy(logits, labels))

loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)

# Initialize the model on a single device.
rng = jax.random.PRNGKey(428)
sample_image, sample_label = next(input_dataset)
params = loss_fn_t.init(rng, sample_image, sample_label)

# Replicate params onto all devices.
num_devices = jax.local_device_count()
params = jax.tree.map(lambda x: np.stack([x] * num_devices), params)

def make_superbatch():
  """Constructs a superbatch, i.e. one batch of data per device."""
  # Get N batches, then split into list-of-images and list-of-labels.
  superbatch = [next(input_dataset) for _ in range(num_devices)]
  superbatch_images, superbatch_labels = zip(*superbatch)
  # Stack the superbatches to be one array with a leading dimension, rather than
  # a python list. This is what jax.pmap expects as input.
  superbatch_images = np.stack(superbatch_images)
  superbatch_labels = np.stack(superbatch_labels)
  return superbatch_images, superbatch_labels

def update(params, inputs, labels, axis_name='i'):
  """Updates params based on performance on inputs and labels."""
  grads = jax.grad(loss_fn_t.apply)(params, inputs, labels)
  # Take the mean of the gradients across all data-parallel replicas.
  grads = jax.lax.pmean(grads, axis_name)
  # Update parameters using SGD or Adam or ...
  new_params = my_update_rule(params, grads)
  return new_params

# Run several training updates.
for _ in range(10):
  superbatch_images, superbatch_labels = make_superbatch()
  params = jax.pmap(update, axis_name='i')(params, superbatch_images,
                                           superbatch_labels)
```

For a more complete look at distributed Haiku training, take a look at our ResNet-50 on ImageNet example.

Citing Haiku

To cite this repository:

@software{haiku2020github,
  author = {Tom Hennigan and Trevor Cai and Tamara Norman and Lena Martens and Igor Babuschkin},
  title = {{H}aiku: {S}onnet for {JAX}},
  url = {http://github.com/deepmind/dm-haiku},
  version = {0.0.14},
  year = {2020},
}

In this bibtex entry, the version number is intended to be from haiku/__init__.py, and the year corresponds to the project's open-source release.