jax.example_libraries.stax module (JAX documentation)
Contents
- AvgPool()
- BatchNorm()
- Conv()
- Conv1DTranspose()
- ConvTranspose()
- Dense()
- Dropout()
- FanInConcat()
- FanOut()
- GeneralConv()
- GeneralConvTranspose()
- MaxPool()
- SumPool()
- elementwise()
- parallel()
- serial()
- shape_dependent()
jax.example_libraries.stax module
Stax is a small but flexible neural net specification library, written from scratch.
You likely do not mean to import this module! Stax is intended as an example library only. There are several other, much more fully featured neural network libraries for JAX, including Flax from Google and Haiku from DeepMind.
jax.example_libraries.stax.AvgPool(window_shape, strides=None, padding='VALID', spec=None)
Layer construction function for a pooling layer.
jax.example_libraries.stax.BatchNorm(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, beta_init=<function zeros>, gamma_init=<function ones>)
Layer construction function for a batch normalization layer.
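The sketch below, which assumes an NHWC input of an arbitrary shape, shows what the default `axis=(0, 1, 2)` means in practice. Reading the stax source, statistics are taken over the batch and both spatial axes, so `beta` and `gamma` each hold one value per channel. The layer also keeps no running averages: `apply_fun` always normalizes with the statistics of the batch it receives.

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

# Default axis=(0, 1, 2): statistics over batch, height and width of NHWC input.
init_fun, apply_fun = stax.BatchNorm()

input_shape = (8, 4, 4, 3)  # (batch, height, width, channels), illustrative
out_shape, (beta, gamma) = init_fun(random.PRNGKey(0), input_shape)
print(out_shape, beta.shape, gamma.shape)  # (8, 4, 4, 3) (3,) (3,)

# beta starts at zeros and gamma at ones, so the output is simply the
# input standardized separately for each channel.
x = 3.0 + 2.0 * random.normal(random.PRNGKey(1), input_shape)
y = apply_fun((beta, gamma), x)
print(jnp.mean(y, axis=(0, 1, 2)))  # close to 0 for every channel
print(jnp.std(y, axis=(0, 1, 2)))   # close to 1 for every channel
```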
jax.example_libraries.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
Layer construction function for a general convolution layer.
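`Conv` shares the docstring of `GeneralConv` because, in the stax source, it is `GeneralConv` with `dimension_numbers` fixed to `('NHWC', 'HWIO', 'NHWC')`. A minimal sketch of the resulting shapes, assuming a single 28x28 grayscale image as input:

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

# 16 output channels, 3x3 filters; strides and padding keep their defaults
# (stride 1 in each spatial dimension, 'VALID' padding).
init_fun, apply_fun = stax.Conv(16, (3, 3))

# Input layout is NHWC: (batch, height, width, channels).
input_shape = (1, 28, 28, 1)
out_shape, (W, b) = init_fun(random.PRNGKey(0), input_shape)
# out_shape == (1, 26, 26, 16): 'VALID' trims filter_size - 1 per spatial axis.
print(W.shape)  # (3, 3, 1, 16): HWIO kernel layout
print(b.shape)  # (1, 1, 1, 16): broadcasts over batch and spatial axes

y = apply_fun((W, b), jnp.ones(input_shape))
print(y.shape)  # (1, 26, 26, 16)

# With padding='SAME' and stride 2, each spatial dimension is halved instead.
init_same, _ = stax.Conv(16, (3, 3), strides=(2, 2), padding='SAME')
same_shape, _ = init_same(random.PRNGKey(0), input_shape)
# same_shape == (1, 14, 14, 16)
```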
jax.example_libraries.stax.Conv1DTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
Layer construction function for a general transposed-convolution layer.
jax.example_libraries.stax.ConvTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
Layer construction function for a general transposed-convolution layer.
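The stax source defines `ConvTranspose` and `Conv1DTranspose` as `GeneralConvTranspose` with fixed layouts, `('NHWC', 'HWIO', 'NHWC')` and `('NHC', 'HIO', 'NHC')` respectively, which is why they share its docstring. This sketch, with arbitrary sizes, uses stride 2 and `'SAME'` padding, a common upsampling setting that doubles each spatial dimension.

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

key = random.PRNGKey(0)

# ConvTranspose uses NHWC inputs, like Conv. With stride 2 and 'SAME'
# padding, each spatial dimension is doubled.
init_2d, apply_2d = stax.ConvTranspose(8, (3, 3), strides=(2, 2), padding='SAME')
shape_2d, params_2d = init_2d(key, (1, 4, 4, 3))
y_2d = apply_2d(params_2d, jnp.ones((1, 4, 4, 3)))
print(y_2d.shape)  # (1, 8, 8, 8)

# Conv1DTranspose is the one-dimensional version: inputs are laid out as
# (batch, length, channels).
init_1d, apply_1d = stax.Conv1DTranspose(4, (3,), strides=(2,), padding='SAME')
shape_1d, params_1d = init_1d(key, (1, 5, 2))
y_1d = apply_1d(params_1d, jnp.ones((1, 5, 2)))
print(y_1d.shape)  # (1, 10, 4)
```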
jax.example_libraries.stax.Dense(out_dim, W_init=<function variance_scaling.<locals>.init>, b_init=<function normal.<locals>.init>)
Layer construction function for a dense (fully-connected) layer.
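Every stax layer is an `(init_fun, apply_fun)` pair: `init_fun(rng, input_shape)` returns `(output_shape, params)`, and `apply_fun(params, inputs, **kwargs)` computes the layer. For `Dense`, `params` is `(W, b)` and the layer computes `jnp.dot(inputs, W) + b`. The sketch passes `-1` as a placeholder batch size, which works here because `Dense` reads only the last entry of `input_shape`; the sizes are arbitrary.

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

init_fun, apply_fun = stax.Dense(3)

# A leading -1 leaves the batch size unspecified; Dense only reads the
# last dimension of input_shape when creating its parameters.
out_shape, (W, b) = init_fun(random.PRNGKey(0), (-1, 10))
print(out_shape)         # (-1, 3)
print(W.shape, b.shape)  # (10, 3) (3,)

x = jnp.ones((4, 10))
y = apply_fun((W, b), x)  # computes x @ W + b
print(y.shape)           # (4, 3)
```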
jax.example_libraries.stax.Dropout(rate, mode='train')
Layer construction function for a dropout layer with given rate.
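A minimal sketch of how `rate` behaves. In the stax source, `rate` is the probability of keeping each element, not of dropping it as in some other libraries, and kept elements are scaled by `1 / rate` so the expected value is unchanged. The source also expects a PRNG key through the `rng` keyword argument of `apply_fun` (`serial` and `parallel` split and forward one when they receive it), so the sketch passes one in both modes.

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

x = jnp.ones((10_000,))

# rate=0.9: each element is kept with probability 0.9 and then scaled by 1 / 0.9.
init_fun, apply_fun = stax.Dropout(0.9)
_, params = init_fun(random.PRNGKey(0), x.shape)  # Dropout has no parameters: ()
y = apply_fun(params, x, rng=random.PRNGKey(1))

kept = y != 0
print(jnp.mean(kept))  # roughly 0.9
print(y[kept][:3])     # kept values are about 1.111 (1 / 0.9)

# With any mode other than 'train', inputs pass through unchanged.
_, apply_eval = stax.Dropout(0.9, mode='test')
y_eval = apply_eval(params, x, rng=random.PRNGKey(1))
```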
jax.example_libraries.stax.FanInConcat(axis=-1)
Layer construction function for a fan-in concatenation layer.
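Fan-in layers take a sequence of inputs, so their `init_fun` expects a sequence of shapes. A minimal sketch with two arbitrary 2-D inputs concatenated along the default last axis:

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

init_fun, apply_fun = stax.FanInConcat()  # axis=-1 by default

# The layer takes a sequence of inputs, so init_fun takes a sequence of shapes.
out_shape, params = init_fun(random.PRNGKey(0), [(2, 3), (2, 5)])
print(out_shape)  # (2, 8)

a, b = jnp.zeros((2, 3)), jnp.ones((2, 5))
y = apply_fun(params, [a, b])
print(y.shape)    # (2, 8)
```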
jax.example_libraries.stax.FanOut(num)
Layer construction function for a fan-out layer.
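`FanOut(num)` has no parameters; it returns `num` copies of its input (references to the same array, in the stax source), typically so that `parallel` can route each copy through a different layer. A minimal sketch with an arbitrary input:

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

init_fun, apply_fun = stax.FanOut(3)

# FanOut repeats both its input shape and its input.
out_shapes, params = init_fun(random.PRNGKey(0), (2, 4))
print(out_shapes)  # [(2, 4), (2, 4), (2, 4)]

x = jnp.arange(8.0).reshape(2, 4)
ys = apply_fun(params, x)
print(len(ys))     # 3
```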
jax.example_libraries.stax.GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
Layer construction function for a general convolution layer.
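`dimension_numbers` is a triple of layout strings for the input, kernel and output, using the same letter conventions as `jax.lax.conv_general_dilated` (`N` batch, `C` feature, `O` and `I` kernel output and input features, other letters spatial). A sketch of a channels-first (NCHW) convolution with arbitrary sizes:

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

# (input spec, kernel spec, output spec): channels-first layout throughout.
dimension_numbers = ('NCHW', 'OIHW', 'NCHW')
init_fun, apply_fun = stax.GeneralConv(
    dimension_numbers, 8, (5, 5), strides=(1, 1), padding='SAME')

input_shape = (2, 3, 32, 32)  # (batch, channels, height, width)
out_shape, (W, b) = init_fun(random.PRNGKey(0), input_shape)
# out_shape == (2, 8, 32, 32): 'SAME' padding with stride 1 keeps H and W.
print(W.shape)  # (8, 3, 5, 5): laid out as OIHW
print(b.shape)  # (1, 8, 1, 1): the bias sits on the output channel axis

y = apply_fun((W, b), jnp.ones(input_shape))
print(y.shape)  # (2, 8, 32, 32)
```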
jax.example_libraries.stax.GeneralConvTranspose(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
Layer construction function for a general transposed-convolution layer.
jax.example_libraries.stax.MaxPool(window_shape, strides=None, padding='VALID', spec=None)
Layer construction function for a pooling layer.
jax.example_libraries.stax.SumPool(window_shape, strides=None, padding='VALID', spec=None)
Layer construction function for a pooling layer.
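`AvgPool`, `MaxPool` and `SumPool` share a signature and differ only in how each window is reduced. The sketch runs all three on a small image whose values are easy to check by hand. Two details come from the stax source: with `spec=None` the batch axis is assumed to come first and the channel axis last, and `strides` defaults to 1 in every spatial dimension rather than to the window size.

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

# One 4x4 single-channel image in NHWC layout, holding the values 0..15.
x = jnp.arange(16.0).reshape(1, 4, 4, 1)

def pool(layer):
    # Pooling layers have no parameters, so params is an empty tuple.
    init_fun, apply_fun = layer
    _, params = init_fun(random.PRNGKey(0), x.shape)
    return apply_fun(params, x)[0, :, :, 0]  # drop the batch and channel axes

# Non-overlapping 2x2 windows.
max_out = pool(stax.MaxPool((2, 2), strides=(2, 2)))  # [[5, 7], [13, 15]]
sum_out = pool(stax.SumPool((2, 2), strides=(2, 2)))  # [[10, 18], [42, 50]]
avg_out = pool(stax.AvgPool((2, 2), strides=(2, 2)))  # [[2.5, 4.5], [10.5, 12.5]]

# Without strides, windows move one step at a time and overlap.
overlap = pool(stax.MaxPool((2, 2)))
print(overlap.shape)  # (3, 3)
```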
jax.example_libraries.stax.elementwise(fun, **fun_kwargs)
Layer that applies a scalar function elementwise on its inputs.
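Although the docstring speaks of a scalar function, in the stax source `apply_fun` calls `fun` once on the whole input array (passing along any `fun_kwargs`), so `fun` should be vectorized, as `jax.numpy` and `jax.nn` functions are. The predefined activation layers in stax, such as `stax.Relu` and `stax.Tanh`, are built this way. A sketch of a leaky ReLU with an arbitrarily chosen slope:

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

# Extra keyword arguments are forwarded to `fun` on every call.
init_fun, apply_fun = stax.elementwise(jax.nn.leaky_relu, negative_slope=0.1)

# No parameters, and the output shape equals the input shape.
out_shape, params = init_fun(random.PRNGKey(0), (3,))
x = jnp.array([-10.0, 0.0, 5.0])
y = apply_fun(params, x)  # leaky_relu(x, negative_slope=0.1): [-1, 0, 5]
```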
jax.example_libraries.stax.parallel(*layers)
Combinator for composing layers in parallel.
The layer resulting from this combinator is often used with the FanOut and FanInSum layers.
Parameters:
*layers – a sequence of layers, each an (init_fun, apply_fun) pair.
Returns:
A new layer, meaning an (init_fun, apply_fun) pair, representing the parallel composition of the given sequence of layers. In particular, the returned layer takes a sequence of inputs and returns a sequence of outputs, each with one entry per argument layer: the i-th layer is applied to the i-th input.
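A minimal sketch of the FanOut/FanInSum pattern: a residual block that copies the input, passes one copy through a `Dense` layer and the other through unchanged, then adds the results. `stax.Identity` and `stax.FanInSum` are module-level layers in stax, already `(init_fun, apply_fun)` pairs rather than constructor functions, which is why they appear without parentheses. Sizes are arbitrary.

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

# Residual block: y = Dense(x) + x.
init_fun, apply_fun = stax.serial(
    stax.FanOut(2),
    stax.parallel(stax.Dense(4), stax.Identity),
    stax.FanInSum,
)

out_shape, params = init_fun(random.PRNGKey(0), (2, 4))
print(out_shape)  # (2, 4)

x = jnp.ones((2, 4))
y = apply_fun(params, x)

# params has one entry per serial layer; the parallel entry holds the Dense
# parameters followed by the Identity's empty tuple.
W, b = params[1][0]
```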
jax.example_libraries.stax.serial(*layers)
Combinator for composing layers in serial.
Parameters:
*layers – a sequence of layers, each an (init_fun, apply_fun) pair.
Returns:
A new layer, meaning an (init_fun, apply_fun) pair, representing the serial composition of the given sequence of layers.
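A sketch of a small MLP classifier built with `serial`, assuming flattened 784-feature inputs (an arbitrary choice). Because `params` is an ordinary pytree (a list with one entry per layer), it works directly with `jax.grad` and `jax.tree_util.tree_map`; the plain gradient step below is only an illustration, not a recommended optimizer.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

# Relu and LogSoftmax are predefined stax layers.
init_fun, apply_fun = stax.serial(
    stax.Dense(128), stax.Relu,
    stax.Dense(10), stax.LogSoftmax,
)

# -1 leaves the batch size unspecified at initialization time.
out_shape, params = init_fun(random.PRNGKey(0), (-1, 784))
print(out_shape)  # (-1, 10)

def loss(params, inputs, targets):
    # Mean cross-entropy against one-hot targets.
    log_probs = apply_fun(params, inputs)
    return -jnp.mean(jnp.sum(log_probs * targets, axis=-1))

inputs = random.normal(random.PRNGKey(1), (32, 784))
targets = jax.nn.one_hot(jnp.arange(32) % 10, 10)

# One illustrative gradient-descent step over the whole params pytree.
grads = jax.grad(loss)(params, inputs, targets)
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
```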
jax.example_libraries.stax.shape_dependent(make_layer)
Combinator to delay constructing a layer's (init_fun, apply_fun) pair until input shapes are known.
Parameters:
make_layer – a one-argument function that takes an input shape (a tuple of positive integers) and returns an (init_fun, apply_fun) pair.
Returns:
A new layer, meaning an (init_fun, apply_fun) pair, representing the same layer as returned by make_layer but with its construction delayed until input shapes are known.
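A sketch of when this helps: a residual block whose main branch must end with as many features as the input has, whatever that number is. `make_main_branch` is a hypothetical helper, and the hidden width of 16 is arbitrary. Note that, in the stax source, `make_layer` is called again inside `apply_fun` with `inputs.shape`, so it must build the same layer at initialization and at application time.

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax

def make_main_branch(input_shape):
    # Hypothetical helper: the last Dense layer matches the input width so
    # that the residual sum below is well defined for any width.
    return stax.serial(stax.Dense(16), stax.Relu, stax.Dense(input_shape[-1]))

main_branch = stax.shape_dependent(make_main_branch)
init_fun, apply_fun = stax.serial(
    stax.FanOut(2),
    stax.parallel(main_branch, stax.Identity),
    stax.FanInSum,
)

# The same block adapts to different input widths.
for width in (3, 7):
    out_shape, params = init_fun(random.PRNGKey(0), (2, width))
    y = apply_fun(params, jnp.ones((2, width)))
    print(out_shape, y.shape)  # (2, 3) (2, 3), then (2, 7) (2, 7)
```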