Base — Sonnet documentation

Module

class sonnet.Module(name=None)[source]

Base class for Sonnet modules.

A Sonnet module is a lightweight container for variables and other modules. Modules typically define one or more “forward” methods (e.g. __call__) which apply operations combining user input and module parameters. For example:

>>> class MultiplyModule(snt.Module):
...   def __call__(self, x):
...     if not hasattr(self, 'w'):
...       self.w = tf.Variable(2., name='w')
...     return x * self.w

>>> mod = MultiplyModule()
>>> mod(1.)
<tf.Tensor: ... numpy=2.0>

Sonnet modules are a layer on top of tf.Module, implementing automatic name scoping as described in the original RFC [1].
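For example, name scoping means that variables created inside a module's methods are named under the module's scope. A minimal sketch (the ':0' suffix is TensorFlow's usual variable naming convention):

mod = MultiplyModule(name='multiply')
mod(1.)
assert mod.w.name == 'multiply/w:0'  # Created under the module's name scope.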

__init__(name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

property variables

Sequence of tf.Variables owned by this module and its submodules.

See tf.Module.variables for implementation details.

NOTE: Most Sonnet modules create variables lazily (e.g. the first time they are called). As such, just after construction there are typically no variables. To mitigate a common error (calling .variables or .trainable_variables before any variables are created) these properties will raise an exception if their result is empty. See allow_empty_variables() if you want to suppress this error.

Returns

A sequence of variables for the current module (sorted by attribute name) followed by variables from all submodules recursively (breadth first).
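For illustration, a short sketch of the lazy-creation behaviour described above, using snt.Linear (w and b are the variables Linear creates on first call):

mod = snt.Linear(4)
# Accessing mod.variables here would raise: no variables exist yet.
_ = mod(tf.ones([2, 3]))       # First call creates `w` and `b` lazily.
assert len(mod.variables) == 2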

property trainable_variables

Sequence of tf.Variables owned by this module and its submodules.

See tf.Module.trainable_variables for implementation details.

NOTE: Most Sonnet modules create variables lazily (e.g. the first time they are called). As such, just after construction there are typically no variables. To mitigate a common error (calling .variables or .trainable_variables before any variables are created) these properties will raise an exception if their result is empty. See allow_empty_variables() if you want to suppress this error.

Returns

A sequence of variables for the current module (sorted by attribute name) followed by variables from all submodules recursively (breadth first).

once

sonnet.once(f)[source]

Decorator which ensures a wrapped method is only ever run once.

>>> @snt.once
... def f():
...   print('Hello, world!')
>>> f()
Hello, world!
>>> f()
>>> f()

If f is a method then it will be evaluated once per instance:

>>> class MyObject:
...   @snt.once
...   def f(self):
...     print('Hello, world!')

>>> o = MyObject()
>>> o.f()
Hello, world!
>>> o.f()

>>> o2 = MyObject()
>>> o2.f()
Hello, world!
>>> o.f()
>>> o2.f()

If an error is raised during execution of f it will be raised to the user. Next time the method is run, it will be treated as not having run before.

Parameters

f – A function to wrap which should only be called once.

Returns

Wrapped version of f which will only evaluate f the first time it is called.
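A common use of once inside a module is lazy parameter creation on the first call, when input shapes become known. A minimal sketch (MyLinear and _initialize are illustrative names, not part of Sonnet):

class MyLinear(snt.Module):
  def __init__(self, output_size, name=None):
    super().__init__(name=name)
    self.output_size = output_size

  @snt.once
  def _initialize(self, x):
    # Runs exactly once per instance, on the first call.
    input_size = x.shape[-1]
    self.w = tf.Variable(tf.random.normal([input_size, self.output_size]), name='w')

  def __call__(self, x):
    self._initialize(x)
    return tf.matmul(x, self.w)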

no_name_scope

sonnet.no_name_scope(method)[source]

Decorator to wrap a method, preventing automatic name scope wrapping.

By default, any method on a module is considered a forward function, and so any variables / modules created by the method will be scoped as belonging to the module. In some cases this is undesirable, for example when implementing .clone() / .transpose(), as in those cases we want the new module to have the scope of wherever the .transpose() call is made. To allow this, decorate any such methods with no_name_scope.

Parameters

method (TypeVar(T)) – the method to wrap.

Return type

TypeVar(T)

Returns

The method, with a flag indicating no name scope wrapping should occur.
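A sketch of the pattern described above (the clone method here is illustrative, not part of the base API):

class MyModule(snt.Module):
  @snt.no_name_scope
  def clone(self, name=None):
    # Without the decorator, the new module's variables would be
    # scoped under this module rather than under the caller.
    return MyModule(name=name)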

Deferred

class sonnet.Deferred(*args, **kwargs)[source]

Defers the construction of another module until the first call.

Deferred can be used to declare modules that depend on computed properties of other modules before those modules are defined. This allows users to separate the declaration and use of modules. For example at the start of your program you can declare two modules which are coupled:

encoder = snt.Linear(64)
decoder = snt.Deferred(lambda: snt.Linear(encoder.input_size))

Later you can use these naturally (note that using decoder first would cause an error, since encoder.input_size is only defined after encoder has been called):

x = tf.ones([8, 32])
y = encoder(x)
z = decoder(y)  # Constructs the inner Linear module by calling the lambda.

The result will satisfy the following conditions:

assert x.shape == z.shape
assert y.shape == [8, 64]
assert decoder.input_size == encoder.output_size
assert decoder.output_size == encoder.input_size

__init__(constructor, call_methods=('__call__',), name=None)[source]

Initializes the Deferred module.

Parameters

property target

Returns the target module.

If the constructor has not already run this will trigger construction. Subsequent calls to target will return the same instance.

Returns

A Module instance as created by self.constructor().

__call__(*args, **kwargs)[source]

Call self as a function.

__str__()[source]

Return str(self).

__repr__()[source]

Return repr(self).

__delattr__(name)[source]

Implement delattr(self, name).

Linear modules

Linear

class sonnet.Linear(output_size, with_bias=True, w_init=None, b_init=None, name=None)[source]

Linear module, optionally including bias.

__init__(output_size, with_bias=True, w_init=None, b_init=None, name=None)[source]

Constructs a Linear module.

Parameters

__call__(*args, **kwargs)[source]

Call self as a function.
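A minimal usage sketch: the weight w (shaped [input_size, output_size]) and the bias b are created lazily on the first call, once the input size is known:

mod = snt.Linear(output_size=4)
x = tf.ones([2, 3])   # Batch of 2, input size 3.
y = mod(x)            # Creates w ([3, 4]) and b ([4]) on first call.
assert y.shape == [2, 4]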

Bias

class sonnet.Bias(output_size=None, bias_dims=None, b_init=None, name=None)[source]

Bias module.

Example Usage:

N, H, W, C = 1, 2, 3, 4
x = tf.random.normal([N, H, W, C])

scalar_bias = snt.Bias(bias_dims=[])
scalar_bias_output = scalar_bias(x)
assert scalar_bias.b.shape == []

Create a bias over all non-minibatch dimensions:

all_bias = snt.Bias()
all_bias_output = all_bias(x)
assert all_bias.b.shape == [H, W, C]

Create a bias over the last non-minibatch dimension:

last_bias = snt.Bias(bias_dims=[-1])
last_bias_output = last_bias(x)
assert last_bias.b.shape == [C]

Create a bias over the first non-minibatch dimension:

first_bias = snt.Bias(bias_dims=[1])
first_bias_output = first_bias(x)
assert first_bias.b.shape == [H, 1, 1]

Subtract and later add the same learned bias:

bias = snt.Bias()
h1 = bias(x, multiplier=-1)
h2 = bias(x)
h3 = bias(x, multiplier=-1)
reconstructed_x = bias(h3)
assert tf.reduce_all(tf.equal(x, reconstructed_x))

__init__(output_size=None, bias_dims=None, b_init=None, name=None)[source]

Constructs a Bias module that supports broadcasting.

Parameters

__call__(*args, **kwargs)[source]

Call self as a function.

Convolutional modules

Conv1D

class sonnet.Conv1D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', name=None)[source]

Conv1D module.

__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', name=None)[source]

Constructs a Conv1D module.

Parameters

Conv2D

class sonnet.Conv2D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]

Conv2D module.

__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]

Constructs a Conv2D module.

Parameters
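A usage sketch, assuming the default SAME padding (with stride 2 the spatial dimensions are halved):

mod = snt.Conv2D(output_channels=16, kernel_shape=3, stride=2)
x = tf.random.normal([8, 32, 32, 3])   # NHWC.
y = mod(x)
assert y.shape == [8, 16, 16, 16]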

Conv3D

class sonnet.Conv3D(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', name=None)[source]

Conv3D module.

__init__(output_channels, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', name=None)[source]

Constructs a Conv3D module.

Parameters

Conv1DTranspose

class sonnet.Conv1DTranspose(output_channels, kernel_shape, output_shape=None, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', name=None)[source]

A 1D transpose convolutional module.

__init__(output_channels, kernel_shape, output_shape=None, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', name=None)[source]

Constructs a Conv1DTranspose module.

Parameters

Conv2DTranspose

class sonnet.Conv2DTranspose(output_channels, kernel_shape, output_shape=None, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]

A 2D transpose convolutional module.

__init__(output_channels, kernel_shape, output_shape=None, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]

Constructs a Conv2DTranspose module.

Parameters
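A usage sketch mirroring the Conv2D example above; with SAME padding and stride 2 the transpose convolution doubles the spatial dimensions:

mod = snt.Conv2DTranspose(output_channels=3, kernel_shape=3, stride=2)
x = tf.random.normal([8, 16, 16, 16])
y = mod(x)
assert y.shape == [8, 32, 32, 3]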

Conv3DTranspose

class sonnet.Conv3DTranspose(output_channels, kernel_shape, output_shape=None, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', name=None)[source]

A 3D transpose convolutional module.

__init__(output_channels, kernel_shape, output_shape=None, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NDHWC', name=None)[source]

Constructs a Conv3DTranspose module.

Parameters

DepthwiseConv2D

class sonnet.DepthwiseConv2D(kernel_shape, channel_multiplier=1, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]

Spatial depth-wise 2D convolution module, including bias.

This acts as a light wrapper around the TensorFlow op tf.nn.depthwise_conv2d, abstracting away variable creation and sharing.

__init__(kernel_shape, channel_multiplier=1, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', name=None)[source]

Constructs a DepthwiseConv2D module.

Parameters

__call__(*args, **kwargs)[source]

Call self as a function.

Normalization modules

LayerNorm

class sonnet.LayerNorm(axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Normalizes inputs along the given axes.

This is a generic implementation of normalization along specific axes of the input. InstanceNorm is a subclass of this module that normalizes over the spatial dimensions.

It transforms the input x into:

\[\text{outputs} = \text{scale} \dfrac{x - \mu}{\sigma + \epsilon} + \text{offset}\]

Where \(\mu\) and \(\sigma\) are respectively the mean and standard deviation of x.

There are many different variations for how users want to manage scale and offset, if they require them at all. The module exposes the following attributes:

scale

If create_scale=True, a trainable tf.Variable holding the current scale.

offset

If create_offset=True, a trainable tf.Variable holding the current offset.

__init__(axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Constructs a LayerNorm module.

Parameters

__call__(*args, **kwargs)[source]

Call self as a function.
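A usage sketch normalizing over the final (channel) axis:

ln = snt.LayerNorm(axis=-1, create_scale=True, create_offset=True)
x = tf.random.normal([8, 32])
y = ln(x)   # Normalized per example; learned scale and offset applied.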

InstanceNorm

class sonnet.InstanceNorm(create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Normalizes inputs along the spatial dimensions.

See LayerNorm for more details.

scale

If create_scale=True, a trainable tf.Variable holding the current scale.

offset

If create_offset=True, a trainable tf.Variable holding the current offset.

__init__(create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Constructs an InstanceNorm module.

This method creates a module which normalizes over the spatial dimensions.

Parameters

BaseBatchNorm

class sonnet.BaseBatchNorm(create_scale, create_offset, moving_mean, moving_variance, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Batch normalization module.

This implements normalization across the batch and spatial dimensions. It maintains moving averages of the mean and variance which can be used to normalize at test time. The constructor is generic and requires the user to pass in objects to compute these.

At training time we use the batch statistics for that batch and these are then used to update the moving averages.

At test time we can either use the moving averages of the batch statistics (test_local_stats=False) or we can use the local statistics (test_local_stats=True).

It transforms the input x into:

\[\text{outputs} = \text{scale} \dfrac{x - \mu}{\sigma + \epsilon} + \text{offset}\]

Where \(\mu\) and \(\sigma\) are respectively the mean and standard deviation of x. Note that this module automatically uses the fused batch norm op if the data format is NHWC.

There are many different variations for how users want to manage scale and offset, if they require them at all. The module exposes the following attributes:

scale

If create_scale, a trainable tf.Variable holding the current scale after the module is connected for the first time.

offset

If create_offset, a trainable tf.Variable holding the current offset after the module is connected for the first time.

__init__(create_scale, create_offset, moving_mean, moving_variance, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Constructs a BaseBatchNorm module.

Parameters

__call__(inputs, is_training, test_local_stats=False, scale=None, offset=None)[source]

Returns normalized inputs.

Parameters

Returns

An n-d tensor of the same shape as inputs that has been normalized.

BatchNorm

class sonnet.BatchNorm(create_scale, create_offset, decay_rate=0.999, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Batch normalization with exponential moving average for test statistics.

See BaseBatchNorm for details.

scale

If create_scale=True, a trainable tf.Variable holding the current scale after the module is connected for the first time.

offset

If create_offset, a trainable tf.Variable holding the current offset after the module is connected for the first time.

__init__(create_scale, create_offset, decay_rate=0.999, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Constructs a BatchNorm module.

Parameters
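A usage sketch showing the train/test distinction described above:

bn = snt.BatchNorm(create_scale=True, create_offset=True)
x = tf.random.normal([8, 32, 32, 3])
y_train = bn(x, is_training=True)    # Batch statistics; moving averages updated.
y_test = bn(x, is_training=False)    # Moving-average statistics used.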

CrossReplicaBatchNorm

class sonnet.distribute.CrossReplicaBatchNorm(create_scale, create_offset, moving_mean, moving_variance, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Cross-replica Batch Normalization.

At every step the full batch is used to calculate the batch statistics, even within a distributed setting (note: this is only supported with snt.Replicator or snt.TpuReplicator).

See BaseBatchNorm for details.

scale

If create_scale=True, a trainable tf.Variable holding the current scale after the module is connected for the first time.

offset

If create_offset, a trainable tf.Variable holding the current offset after the module is connected for the first time.

__init__(create_scale, create_offset, moving_mean, moving_variance, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Constructs a CrossReplicaBatchNorm module.

Parameters

GroupNorm

class sonnet.GroupNorm(groups, axis=slice(1, None, None), create_scale=True, create_offset=True, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Group normalization module.

This applies group normalization to the inputs. This involves splitting the channels into groups before calculating the mean and variance. The default behaviour is to compute the mean and variance over the spatial dimensions and the grouped channels. The mean and variance will never be computed over the created groups axis.

It transforms the input x into:

\[\text{outputs} = \text{scale} \dfrac{x - \mu}{\sigma + \epsilon} + \text{offset}\]

Where \(\mu\) and \(\sigma\) are respectively the mean and standard deviation of x.

There are many different variations for how users want to manage scale and offset, if they require them at all. The module exposes the following attributes:

scale

If create_scale=True, a trainable tf.Variable holding the current scale.

offset

If create_offset=True, a trainable tf.Variable holding the current offset.

__init__(groups, axis=slice(1, None, None), create_scale=True, create_offset=True, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', name=None)[source]

Constructs a GroupNorm module.

Parameters

__call__(*args, **kwargs)[source]

Call self as a function.

Recurrent modules

RNNCore

class sonnet.RNNCore(name=None)[source]

Base class for Recurrent Neural Network cores.

This class defines the basic functionality that every core should implement: initial_state(), used to construct an example of the core state; and __call__() which applies the core parameterized by a previous state to an input.

Cores are typically used with dynamic_unroll() and static_unroll() to iteratively construct an output sequence from the given input sequence.

__call__(*args, **kwargs)[source]

Call self as a function.
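A single-step usage sketch of the interface described above:

core = snt.LSTM(hidden_size=16)
state = core.initial_state(batch_size=4)
x = tf.random.normal([4, 8])
output, state = core(x, state)   # Apply the core once, carrying the state.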

UnrolledRNN

class sonnet.UnrolledRNN(name=None)[source]

Base class for unrolled Recurrent Neural Networks.

This class is a generalization of RNNCore which operates on an input sequence as opposed to a single time step.

__call__(*args, **kwargs)[source]

Call self as a function.

TrainableState

class sonnet.TrainableState(initial_values, mask=None, name=None)[source]

Trainable state for an RNNCore.

The state can be constructed manually from a nest of initial values:

state = snt.TrainableState((tf.zeros([16]), tf.zeros([16])))

or automatically for a given RNNCore:

core = snt.LSTM(hidden_size=16)
state = snt.TrainableState.for_core(core)

classmethod for_core(core, mask=None, name=None)[source]

Constructs a trainable state for a given RNNCore.

Parameters

Returns

A TrainableState.

__init__(initial_values, mask=None, name=None)[source]

Constructs a trainable state from initial values.

Parameters

__call__(*args, **kwargs)[source]

Call self as a function.
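A usage sketch, assuming __call__ takes a batch size and broadcasts the learned initial values accordingly:

core = snt.LSTM(hidden_size=16)
trainable_state = snt.TrainableState.for_core(core)
initial_state = trainable_state(4)   # Learned initial state for a batch of 4.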

dynamic_unroll

sonnet.dynamic_unroll(core, input_sequence, initial_state, sequence_length=None, parallel_iterations=1, swap_memory=False)[source]

Performs a dynamic unroll of an RNN.

core = snt.LSTM(hidden_size=16)
batch_size = 3
input_sequence = tf.random.uniform([1, batch_size, 2])
output_sequence, final_state = snt.dynamic_unroll(
    core,
    input_sequence,
    core.initial_state(batch_size))

An unroll corresponds to calling the core on each element of the input sequence in a loop, carrying the state through:

state = initial_state
for t in range(len(input_sequence)):
  outputs, state = core(input_sequence[t], state)

A dynamic unroll preserves the loop structure when executed within tf.function. See static_unroll() for an unroll function which replaces a loop with its body repeated multiple times.

Parameters

Returns

Return type

A tuple with two elements

Raises

ValueError – If input_sequence is empty.

static_unroll

sonnet.static_unroll(core, input_sequence, initial_state, sequence_length=None)[source]

Performs a static unroll of an RNN.

core = snt.LSTM(hidden_size=16)
batch_size = 3
input_sequence = tf.random.uniform([1, batch_size, 2])
output_sequence, final_state = snt.static_unroll(
    core,
    input_sequence,
    core.initial_state(batch_size))

An unroll corresponds to calling the core on each element of the input sequence in a loop, carrying the state through:

state = initial_state
for t in range(len(input_sequence)):
  outputs, state = core(input_sequence[t], state)

A static unroll replaces a loop with its body repeated multiple times when executed inside tf.function:

state = initial_state
outputs0, state = core(input_sequence[0], state)
outputs1, state = core(input_sequence[1], state)
outputs2, state = core(input_sequence[2], state)
...

See dynamic_unroll() for a loop-preserving unroll function.

Parameters

Returns

Return type

A tuple with two elements

Raises

ValueError – If input_sequence is empty or its leading dimension is not known statically.

VanillaRNN

class sonnet.VanillaRNN(hidden_size, activation=tf.tanh, w_i_init=None, w_h_init=None, b_init=None, dtype=tf.float32, name=None)[source]

Basic fully-connected RNN core.

Given \(x_t\) and the previous hidden state \(h_{t-1}\) the core computes

\[h_t = w_i x_t + w_h h_{t-1} + b\]

input_to_hidden

Input-to-hidden weights \(w_i\), a tensor of shape [input_size, hidden_size].

hidden_to_hidden

Hidden-to-hidden weights \(w_h\), a tensor of shape [hidden_size, hidden_size].

b

Bias \(b\), a tensor of shape [hidden_size].

__init__(hidden_size, activation=tf.tanh, w_i_init=None, w_h_init=None, b_init=None, dtype=tf.float32, name=None)[source]

Constructs a vanilla RNN core.

Parameters

__call__(*args, **kwargs)[source]

Call self as a function.

DeepRNN

class sonnet.DeepRNN(layers, name=None)[source]

Linear chain of RNNCores or callables.

The core takes (input, prev_state) as input and passes the input through each internal module in the order they were presented, using elements from prev_state as necessary for internal RNN cores.

deep_rnn = snt.DeepRNN([
    snt.LSTM(hidden_size=16),
    snt.LSTM(hidden_size=16),
])

Note that the state of a DeepRNN is always a tuple, which will contain the same number of elements as there are internal RNN cores. If no internal modules are RNN cores, the state of the DeepRNN as a whole is an empty tuple.

Wrapping non-recurrent modules into a DeepRNN can be useful to produce something API compatible with a “real” recurrent module, simplifying code that handles the cores.

__init__(layers, name=None)[source]

Constructs a DeepRNN.

Parameters
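A sketch combining recurrent and non-recurrent callables; only the two LSTM cores contribute entries to the state tuple:

deep_rnn = snt.DeepRNN([
    snt.LSTM(hidden_size=16),
    tf.nn.relu,                      # Stateless callable.
    snt.LSTM(hidden_size=16),
])
state = deep_rnn.initial_state(batch_size=4)   # A 2-tuple of LSTM states.
y, state = deep_rnn(tf.random.normal([4, 8]), state)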

sonnet.deep_rnn_with_skip_connections(layers, concat_final_output=True, name='deep_rnn_with_skip_connections')[source]

Constructs a DeepRNN with skip connections.

Skip connections alter the dependency structure within a DeepRNN. Specifically, input to the i-th layer (i > 0) is given by a concatenation of the core’s inputs and the outputs of the (i-1)-th layer.

outputs0, ... = layers[0](inputs, ...)
outputs1, ... = layers[1](tf.concat([inputs, outputs0], axis=1), ...)
outputs2, ... = layers[2](tf.concat([inputs, outputs1], axis=1), ...)
...

This allows the layers to learn decoupled features.

Parameters

Return type

RNNCore

Returns

A DeepRNN with skip connections.

Raises

ValueError – If any of the layers is not an RNNCore.

sonnet.deep_rnn_with_residual_connections(layers, name='deep_rnn_with_residual_connections')[source]

Constructs a DeepRNN with residual connections.

Residual connections alter the dependency structure in a DeepRNN. Specifically, the input to the i-th intermediate layer is a sum of the original core’s inputs and the outputs of all the preceding layers (<i).

outputs0, ... = layers[0](inputs, ...)
outputs0 += inputs
outputs1, ... = layers[1](outputs0, ...)
outputs1 += outputs0
outputs2, ... = layers[2](outputs1, ...)
outputs2 += outputs1
...

This allows the layers to learn specialized features that compose incrementally.

Parameters

Return type

RNNCore

Returns

A DeepRNN with residual connections.

Raises

ValueError – If any of the layers is not an RNNCore.

LSTM

class sonnet.LSTM(hidden_size, projection_size=None, projection_init=None, w_i_init=None, w_h_init=None, b_init=None, forget_bias=1.0, dtype=tf.float32, name=None)[source]

Long short-term memory (LSTM) RNN core.

The implementation is based on [2]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes

\[\begin{array}{ll} i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\ f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\ o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

Where \(i_t\), \(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.

Notes

Forget gate initialization:

Following [3] we add a constant forget_bias (defaults to 1.0) to \(b_f\) after initialization in order to reduce the scale of forgetting at the beginning of training.

Recurrent projections:

The hidden state can be projected (via the projection_size parameter) to reduce the number of parameters and speed up computation. For more details see [4].

input_to_hidden

Input-to-hidden weights \(W_{ii}\), \(W_{if}\), \(W_{ig}\) and \(W_{io}\) concatenated into a tensor of shape [input_size, 4 * hidden_size].

hidden_to_hidden

Hidden-to-hidden weights \(W_{hi}\), \(W_{hf}\), \(W_{hg}\) and \(W_{ho}\) concatenated into a tensor of shape [hidden_size, 4 * hidden_size].

b

Biases \(b_i\), \(b_f\), \(b_g\) and \(b_o\) concatenated into a tensor of shape [4 * hidden_size].

__init__(hidden_size, projection_size=None, projection_init=None, w_i_init=None, w_h_init=None, b_init=None, forget_bias=1.0, dtype=tf.float32, name=None)[source]

Constructs an LSTM.

Parameters

__call__(*args, **kwargs)[source]

Call self as a function.

class sonnet.LSTMState(hidden, cell)
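LSTMState is the namedtuple used for the LSTM state; a short sketch:

core = snt.LSTM(hidden_size=16)
state = core.initial_state(batch_size=4)
out, next_state = core(tf.random.normal([4, 8]), state)
next_state.hidden   # h_t, shape [4, 16].
next_state.cell     # c_t, shape [4, 16].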

lstm_with_recurrent_dropout

sonnet.lstm_with_recurrent_dropout(hidden_size, dropout=0.5, seed=None, **kwargs)[source]

Constructs an LSTM with recurrent dropout.

The implementation is based on [5]. Dropout is applied on the previous hidden state \(h_{t-1}\) during the computation of gate activations:

\[\begin{array}{ll} i_t = \sigma(W_{ii} x_t + W_{hi} d(h_{t-1}) + b_i) \\ f_t = \sigma(W_{if} x_t + W_{hf} d(h_{t-1}) + b_f) \\ g_t = \tanh(W_{ig} x_t + W_{hg} d(h_{t-1}) + b_g) \\ o_t = \sigma(W_{io} x_t + W_{ho} d(h_{t-1}) + b_o) \end{array}\]

Parameters

Returns

Return type

A tuple of two elements

Raises

ValueError – If dropout is not in [0, 1).
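A usage sketch, assuming the two returned elements are a training core (with recurrent dropout) and a matching evaluation core (without), in that order:

train_core, test_core = snt.lstm_with_recurrent_dropout(hidden_size=16, dropout=0.25)
# Unroll train_core during training and test_core at evaluation time.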

UnrolledLSTM

class sonnet.UnrolledLSTM(hidden_size, w_i_init=None, w_h_init=None, b_init=None, forget_bias=1.0, dtype=tf.float32, name=None)[source]

Unrolled long short-term memory (LSTM).

The implementation uses efficient device-specialized ops, e.g. CuDNN-RNN on a CUDA-enabled GPU, and can be an order of magnitude faster than snt.*_unroll with an LSTM core.

__init__(hidden_size, w_i_init=None, w_h_init=None, b_init=None, forget_bias=1.0, dtype=tf.float32, name=None)[source]

Construct an unrolled LSTM.

Parameters

__call__(*args, **kwargs)[source]

Call self as a function.

Conv1DLSTM

class sonnet.Conv1DLSTM(input_shape, output_channels, kernel_shape, data_format='NWC', w_i_init=None, w_h_init=None, b_init=None, forget_bias=1.0, dtype=tf.float32, name=None)[source]

1-D convolutional LSTM.

The implementation is based on [6]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes

\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

where \(*\) denotes the convolution operator; \(i_t\),\(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.

Notes

Forget gate initialization:

Following [3] we add a constant forget_bias (defaults to 1.0) to \(b_f\) after initialization in order to reduce the scale of forgetting at the beginning of training.

input_to_hidden

Input-to-hidden convolution weights \(W_{ii}\), \(W_{if}\), \(W_{ig}\) and \(W_{io}\) concatenated into a single tensor of shape [kernel_shape*, input_channels, 4 * output_channels] where kernel_shape is repeated once.

hidden_to_hidden

Hidden-to-hidden convolution weights \(W_{hi}\), \(W_{hf}\), \(W_{hg}\) and \(W_{ho}\) concatenated into a single tensor of shape [kernel_shape*, output_channels, 4 * output_channels] where kernel_shape is repeated once.

b

Biases \(b_i\), \(b_f\), \(b_g\) and \(b_o\) concatenated into a tensor of shape [4 * output_channels].

__init__(input_shape, output_channels, kernel_shape, data_format='NWC', w_i_init=None, w_h_init=None, b_init=None, forget_bias=1.0, dtype=tf.float32, name=None)[source]

Constructs a 1-D convolutional LSTM.

Parameters

Conv2DLSTM

class sonnet.Conv2DLSTM(input_shape, output_channels, kernel_shape, data_format='NHWC', w_i_init=None, w_h_init=None, b_init=None, forget_bias=1.0, dtype=tf.float32, name=None)[source]

2-D convolutional LSTM.

The implementation is based on [6]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes

\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

where \(*\) denotes the convolution operator; \(i_t\),\(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.

Notes

Forget gate initialization:

Following [3] we add a constant forget_bias (defaults to 1.0) to \(b_f\) after initialization in order to reduce the scale of forgetting at the beginning of training.

input_to_hidden

Input-to-hidden convolution weights \(W_{ii}\), \(W_{if}\), \(W_{ig}\) and \(W_{io}\) concatenated into a single tensor of shape [kernel_shape*, input_channels, 4 * output_channels] where kernel_shape is repeated twice.

hidden_to_hidden

Hidden-to-hidden convolution weights \(W_{hi}\), \(W_{hf}\), \(W_{hg}\) and \(W_{ho}\) concatenated into a single tensor of shape [kernel_shape*, output_channels, 4 * output_channels] where kernel_shape is repeated twice.

b

Biases \(b_i\), \(b_f\), \(b_g\) and \(b_o\) concatenated into a tensor of shape [4 * output_channels].

__init__(input_shape, output_channels, kernel_shape, data_format='NHWC', w_i_init=None, w_h_init=None, b_init=None, forget_bias=1.0, dtype=tf.float32, name=None)[source]

Constructs a 2-D convolutional LSTM.

Parameters

Conv3DLSTM

class sonnet.Conv3DLSTM(input_shape, output_channels, kernel_shape, data_format='NDHWC', w_i_init=None, w_h_init=None, b_init=None, forget_bias=1.0, dtype=tf.float32, name=None)[source]

3-D convolutional LSTM.

The implementation is based on [6]. Given \(x_t\) and the previous state \((h_{t-1}, c_{t-1})\) the core computes

\[\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\]

where \(*\) denotes the convolution operator; \(i_t\),\(f_t\), \(o_t\) are input, forget and output gate activations, and \(g_t\) is a vector of cell updates.

Notes

Forget gate initialization:

Following [3] we add a constant forget_bias (defaults to 1.0) to \(b_f\) after initialization in order to reduce the scale of forgetting at the beginning of training.

input_to_hidden

Input-to-hidden convolution weights \(W_{ii}\), \(W_{if}\), \(W_{ig}\) and \(W_{io}\) concatenated into a single tensor of shape [kernel_shape*, input_channels, 4 * output_channels] where kernel_shape is repeated three times.

hidden_to_hidden

Hidden-to-hidden convolution weights \(W_{hi}\), \(W_{hf}\), \(W_{hg}\) and \(W_{ho}\) concatenated into a single tensor of shape [kernel_shape*, output_channels, 4 * output_channels] where kernel_shape is repeated three times.

b

Biases \(b_i\), \(b_f\), \(b_g\) and \(b_o\) concatenated into a tensor of shape [4 * output_channels].

__init__(input_shape, output_channels, kernel_shape, data_format='NDHWC', w_i_init=None, w_h_init=None, b_init=None, forget_bias=1.0, dtype=tf.float32, name=None)[source]

Constructs a 3-D convolutional LSTM.

Parameters

GRU

class sonnet.GRU(hidden_size, w_i_init=None, w_h_init=None, b_init=None, dtype=tf.float32, name=None)[source]

Gated recurrent unit (GRU) RNN core.

The implementation is based on [7]. Given \(x_t\) and the previous state \(h_{t-1}\) the core computes

\[\begin{array}{ll} z_t &= \sigma(W_{iz} x_t + W_{hz} h_{t-1} + b_z) \\ r_t &= \sigma(W_{ir} x_t + W_{hr} h_{t-1} + b_r) \\ a_t &= \tanh(W_{ia} x_t + W_{ha} (r_t h_{t-1}) + b_a) \\ h_t &= (1 - z_t) h_{t-1} + z_t a_t \end{array}\]

where \(z_t\) and \(r_t\) are update and reset gates.

input_to_hidden

Input-to-hidden weights \(W_{iz}\), \(W_{ir}\) and \(W_{ia}\) concatenated into a tensor of shape [input_size, 3 * hidden_size].

hidden_to_hidden

Hidden-to-hidden weights \(W_{hz}\), \(W_{hr}\) and \(W_{ha}\) concatenated into a tensor of shape [hidden_size, 3 * hidden_size].

b

Biases \(b_z\), \(b_r\) and \(b_a\) concatenated into a tensor of shape [3 * hidden_size].

__init__(hidden_size, w_i_init=None, w_h_init=None, b_init=None, dtype=tf.float32, name=None)[source]

Constructs a GRU.

Parameters

__call__(*args, **kwargs)[source]

Call self as a function.

Batch

reshape

sonnet.reshape(inputs, output_shape, preserve_dims=1, name=None)[source]

A shortcut for applying Reshape to the inputs.

Return type

Tensor

Reshape

class sonnet.Reshape(output_shape, preserve_dims=1, name=None)[source]

Reshapes input Tensor, preserving the batch dimension.

For example, given an input tensor with shape [B, H, W, C, D]:

B, H, W, C, D = range(1, 6)
x = tf.ones([B, H, W, C, D])

The default behavior when output_shape is (-1, D) is to flatten all dimensions between B and D:

mod = snt.Reshape(output_shape=(-1, D))
assert mod(x).shape == [B, H*W*C, D]

You can change the number of preserved leading dimensions via preserve_dims:

mod = snt.Reshape(output_shape=(-1, D), preserve_dims=2)
assert mod(x).shape == [B, H, W*C, D]

mod = snt.Reshape(output_shape=(-1, D), preserve_dims=3)
assert mod(x).shape == [B, H, W, C, D]

mod = snt.Reshape(output_shape=(-1, D), preserve_dims=4)
assert mod(x).shape == [B, H, W, C, 1, D]

__init__(output_shape, preserve_dims=1, name=None)[source]

Constructs a Reshape module.

Parameters

Raises

ValueError – If preserve_dims is not positive.

__call__(*args, **kwargs)[source]

Call self as a function.

reversed(name=None)[source]

Returns inverse batch reshape.

Return type

Reshape
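A sketch of reversed(): once this module has been connected, the returned module undoes the batch reshape performed by this one:

x = tf.ones([2, 3, 4])
mod = snt.Reshape(output_shape=(-1,))
flat = mod(x)                          # Shape [2, 12].
assert mod.reversed()(flat).shape == x.shape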

flatten

sonnet.flatten(inputs, name='flatten')[source]

A shortcut for applying Flatten to the inputs.

Return type

Tensor

Flatten

class sonnet.Flatten(preserve_dims=1, name=None)[source]

Flattens the input Tensor, preserving the batch dimension(s).

Flatten reshapes input tensors to combine all trailing dimensions apart from the first. Additional leading dimensions can be preserved by setting the preserve_dims parameter.

See Reshape for more details.

__init__(preserve_dims=1, name=None)[source]

Constructs a Flatten module.

Parameters

BatchApply

class sonnet.BatchApply(module, num_dims=2, name=None)[source]

Merges a number of leading dimensions of an input tensor to manipulate it.

Merges a number of leading dimensions of a tensor into a single dimension, connects the provided module, then splits the leading dimension of the result to match the input.

Input tensors whose rank is smaller than the number of dimensions to collapse (e.g. all scalar values, which are tensors of rank 0), are passed unaltered to the provided module.

This is useful for applying some module to each timestep of a Time x Batch x N tensor. If a module is hard coded to only support 2D (Batch x N) then the full 3D Tensor cannot be provided. BatchApply will ‘merge’ the first two dimensions of the sequence tensor by reshaping to a (Time * Batch) x N Tensor, and then the internal module can be applied. The result of that operation is reshaped such that its first dimensions are split to match the leading dimensions of the input.
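A usage sketch of the Time x Batch x N case described above:

mod = snt.BatchApply(snt.Linear(4))
x = tf.random.normal([3, 2, 5])   # Time x Batch x N.
y = mod(x)                        # Reshaped to [6, 5] internally, then split back.
assert y.shape == [3, 2, 4]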

__init__(module, num_dims=2, name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

__call__(*args, **kwargs)[source]

Call self as a function.

Embedding modules

Embed

class sonnet.Embed(vocab_size=None, embed_dim=None, existing_vocab=None, densify_gradients=False, initializer=None, trainable=True, dtype=tf.float32, name=None)[source]

Module for embedding tokens in a low-dimensional space.

__init__(vocab_size=None, embed_dim=None, existing_vocab=None, densify_gradients=False, initializer=None, trainable=True, dtype=tf.float32, name=None)[source]

Constructs an Embed module.

Parameters

Raises

ValueError – if neither one of vocab_size or existing_vocab is provided, or if existing_vocab is provided along with vocab_size, embed_dim or initializer (as these should be inferred).

__call__(*args, **kwargs)[source]

Call self as a function.
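A minimal usage sketch with a fixed vocabulary size:

embed = snt.Embed(vocab_size=1000, embed_dim=16)
ids = tf.constant([[1, 2, 3]])
vectors = embed(ids)
assert vectors.shape == [1, 3, 16]   # One embedding row per input id.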

Optimizers

Sonnet optimizers built for TensorFlow 2.

All optimizers implement the snt.Optimizer interface.
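The interface boils down to apply(updates, parameters). A minimal training-step sketch:

mod = snt.Linear(1)
opt = snt.optimizers.Adam(learning_rate=1e-3)

x, y = tf.random.normal([8, 4]), tf.random.normal([8, 1])
with tf.GradientTape() as tape:
  loss = tf.reduce_mean(tf.square(mod(x) - y))
grads = tape.gradient(loss, mod.trainable_variables)
opt.apply(grads, mod.trainable_variables)   # Applies the updates in place.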

Adam

class sonnet.optimizers.Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, name=None)[source]

Adaptive Moment Estimation (Adam) optimizer.

Adam is an algorithm for first-order gradient-based optimization of stochastic objective functions, based on adaptive estimates of lower-order moments. See [8] for more details.

Note: default parameter values have been taken from the paper.

learning_rate

Step size (alpha in the paper).

beta1

Exponential decay rate for first moment estimate.

beta2

Exponential decay rate for second moment estimate.

epsilon

Small value to avoid zero denominator.

step

Step count.

m

Biased first moment estimate (a list with one value per parameter).

v

Biased second raw moment estimate (a list with one value per parameter).

__init__(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, name=None)[source]

Constructs an Adam module.

Parameters

Momentum

class sonnet.optimizers.Momentum(learning_rate, momentum, use_nesterov=False, name=None)[source]

SGD with Momentum module.

learning_rate

Learning rate.

momentum

Momentum scalar.

use_nesterov

True if using Nesterov momentum.

accumulated_momentum

Accumulated momentum for each parameter.

__init__(learning_rate, momentum, use_nesterov=False, name=None)[source]

Constructs a Momentum module.

Parameters

RMSProp

class sonnet.optimizers.RMSProp(learning_rate, decay=0.9, momentum=0.0, epsilon=1e-10, centered=False, name=None)[source]

RMSProp module.

See: http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf

Maintains a moving (discounted) average of the square of each update and divides the update by the root of this average:

ms <- decay * ms + (1 - decay) * update^2
mom <- momentum * mom + learning_rate * update / sqrt(ms + epsilon)
parameter <- parameter - mom

This implementation of RMSprop uses plain momentum, not Nesterov momentum.

The centered version additionally maintains a moving average of the gradients, and uses that average to estimate the variance:

mg <- decay * mg + (1 - decay) * update
ms <- decay * ms + (1 - decay) * update^2
mom <- momentum * mom + learning_rate * update / sqrt(ms - mg^2 + epsilon)
parameter <- parameter - mom

learning_rate

Learning rate.

decay

Learning rate decay over each update.

momentum

Momentum scalar.

epsilon

Small value to avoid zero denominator.

centered

True if centered.

mom

Accumulated mom for each parameter.

ms

Accumulated ms for each parameter.

mg

Accumulated mg for each parameter.

__init__(learning_rate, decay=0.9, momentum=0.0, epsilon=1e-10, centered=False, name=None)[source]

Constructs an RMSProp module.

Parameters

SGD

class sonnet.optimizers.SGD(learning_rate, name=None)[source]

Stochastic Gradient Descent (SGD) module.

learning_rate

Learning rate.

__init__(learning_rate, name=None)[source]

Constructs an SGD module.

Parameters

Initializers

Initializers.
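All initializers are callables taking (shape, dtype); modules accept them via their *_init constructor arguments. A brief sketch:

w_init = snt.initializers.TruncatedNormal(stddev=0.01)
mod = snt.Linear(4, w_init=w_init)   # Called with (shape, dtype) on first use.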

Initializer

class sonnet.initializers.Initializer[source]

Initializer base class; all initializers must implement a __call__ method.

abstract __call__(shape, dtype)[source]

Returns a tensor of the given shape and dtype.

Return type

Tensor

Constant

class sonnet.initializers.Constant(value)[source]

Initializer that generates tensors initialized to the given value.

__init__(value)[source]

__call__(shape, dtype)[source]

Returns a tensor of the given shape and dtype.

Return type

Tensor

Identity

class sonnet.initializers.Identity(gain=1.0)[source]

Initializer that generates the identity matrix.

Constructs a 2D identity matrix or batches of these.

__init__(gain=1.0)[source]

Constructs an identity initializer.

Parameters

gain (float) – Multiplicative factor to apply to the identity matrix.

__call__(shape, dtype)[source]

Returns a tensor of the given shape and dtype.

Return type

Tensor

Ones

class sonnet.initializers.Ones[source]

Initializer that generates tensors initialized to 1.

__call__(shape, dtype)[source]

Returns a tensor of the given shape and dtype.

Return type

Tensor

Orthogonal

class sonnet.initializers.Orthogonal(gain=1.0, seed=None)[source]

Initializer that generates an orthogonal matrix.

NOTE: Does not support 1D tensors.

The implementation is based on [9].

If the shape of the tensor to initialize is two-dimensional, it is initialized with an orthogonal matrix obtained from the QR decomposition of a matrix of random numbers drawn from a normal distribution. If the matrix has fewer rows than columns then the output will have orthogonal rows. Otherwise, the output will have orthogonal columns.

If the shape of the tensor to initialize is more than two-dimensional, a matrix of shape (shape[0] * ... * shape[n - 2], shape[n - 1]) is initialized, where n is the length of the shape vector. The matrix is subsequently reshaped to give a tensor of the desired shape.

__init__(gain=1.0, seed=None)[source]

Constructs an orthogonal initializer.

Parameters

__call__(shape, dtype)[source]

Returns a tensor of the given shape and dtype.

Return type

Tensor

RandomNormal

class sonnet.initializers.RandomNormal(mean=0.0, stddev=1.0, seed=None)[source]

Initializer that generates tensors with a normal distribution.

__init__(mean=0.0, stddev=1.0, seed=None)[source]

Constructs a random normal initializer.

Parameters

__call__(shape, dtype)[source]

Returns a tensor of the given shape and dtype.

Return type

Tensor

RandomUniform

class sonnet.initializers.RandomUniform(minval=0, maxval=1, seed=None)[source]

Initializer that generates tensors with a uniform distribution.

The generated values follow a uniform distribution in the range [minval, maxval).

__init__(minval=0, maxval=1, seed=None)[source]

Constructs a random uniform initializer.

Parameters

__call__(shape, dtype)[source]

Returns a tensor of the given shape and dtype.

TruncatedNormal

class sonnet.initializers.TruncatedNormal(mean=0.0, stddev=1.0, seed=None)[source]

Initializer that generates a truncated normal distribution.

These values follow a normal distribution except that values more than two standard deviations from the mean are discarded and re-drawn. This is the recommended initializer for neural network weights and filters.

__init__(mean=0.0, stddev=1.0, seed=None)[source]

Constructs a truncated normal initializer.

Parameters

__call__(shape, dtype)[source]

Returns a tensor of the given shape and dtype.

VarianceScaling

class sonnet.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='truncated_normal', seed=None)[source]

Initializer capable of adapting its scale to the shape of weights tensors.

With distribution="truncated_normal" or "normal", samples are drawn from a distribution with a mean of zero and a standard deviation (after truncation, if used) stddev = sqrt(scale / n), where n is:

- the number of input units in the weight tensor, if mode = "fan_in"
- the number of output units, if mode = "fan_out"
- the average of the numbers of input and output units, if mode = "fan_avg"

Note that for transposed convolutions the mode selected should be reversed: for the number of input units use fan_out, and for the number of output units use fan_in.

With distribution="uniform", samples are drawn from a uniform distribution within [-limit, limit], with limit = sqrt(3 * scale / n).

The variance scaling initializer can be configured to generate other standard initializers using the scale, mode and distribution arguments. Here are some example configurations:

Name            Parameters
glorot_uniform  scale=1.0, mode='fan_avg', distribution='uniform'
glorot_normal   scale=1.0, mode='fan_avg', distribution='truncated_normal'
lecun_uniform   scale=1.0, mode='fan_in',  distribution='uniform'
lecun_normal    scale=1.0, mode='fan_in',  distribution='truncated_normal'
he_uniform      scale=2.0, mode='fan_in',  distribution='uniform'
he_normal       scale=2.0, mode='fan_in',  distribution='truncated_normal'
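For example, the he_normal row of the table corresponds to:

he_normal = snt.initializers.VarianceScaling(
    scale=2.0, mode='fan_in', distribution='truncated_normal')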

__init__(scale=1.0, mode='fan_in', distribution='truncated_normal', seed=None)[source]

Constructs a variance scaling initializer.

Parameters

Raises

ValueError – In case of an invalid value for the scale, mode ordistribution arguments.

__call__(shape, dtype)[source]

Returns a tensor of the given shape and dtype.

Return type

Tensor

Zeros

class sonnet.initializers.Zeros[source]

Initializer that generates tensors initialized to 0.

__call__(shape, dtype)[source]

Returns a tensor of the given shape and dtype.

Return type

Tensor

Regularizers

Regularizers.

Regularizer

class sonnet.regularizers.Regularizer[source]

Base regularizer class.

abstract __call__(tensors)[source]

Apply a regularizer.

Parameters

tensors (Sequence[Tensor]) – A sequence of tensors to regularize.

Return type

Tensor

Returns

Combined regularization loss for the given tensors.
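A typical pattern (sketch) is to add the combined penalty for a module's weights to the task loss:

mod = snt.Linear(4)
_ = mod(tf.ones([2, 3]))                 # Connect the module to create variables.
reg = snt.regularizers.L2(0.01)
reg_loss = reg(mod.trainable_variables)  # Scalar penalty to add to the task loss.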

L1

class sonnet.regularizers.L1(scale)[source]

L1 regularizer.

>>> reg = snt.regularizers.L1(0.01)
>>> reg([tf.constant([1.0, 2.0, 3.0])])
<tf.Tensor: ...>

__init__(scale)[source]

Create an L1 regularizer.

Parameters

scale (Union[float, floating, ndarray, Tensor, Variable]) – A non-negative regularization factor.

Raises

ValueError – if scale is <0.

__call__(tensors)[source]

See base class.

Return type

Tensor

L2

class sonnet.regularizers.L2(scale)[source]

L2 regularizer.

>>> reg = snt.regularizers.L2(0.01)
>>> reg([tf.constant([1.0, 2.0, 3.0])])
<tf.Tensor: ...>

__init__(scale)[source]

Create an L2 regularizer.

Parameters

scale (Union[float, floating, ndarray, Tensor, Variable]) – float or scalar tensor; regularization factor.

Raises

ValueError – if scale is <0.

__call__(tensors)[source]

See base class.

Return type

Tensor

OffDiagonalOrthogonal

class sonnet.regularizers.OffDiagonalOrthogonal(scale)[source]

Off-diagonal orthogonal regularizer.

The implementation is based on https://arxiv.org/abs/1809.11096. Given a rank N >= 2 tensor, the regularizer computes the sum of the off-diagonal entries of \((W^T W)^2\), where \(W\) is the input tensor reshaped to a matrix by collapsing all leading dimensions.

NB: that is equivalent to computing the off-diagonal sum of (W^T W - I)^2, as off-diagonal entries of I are 0.

For example,

>>> t = tf.reshape(tf.range(8, dtype=tf.float32), [2, 2, 2])
>>> reg = snt.regularizers.OffDiagonalOrthogonal(0.01)
>>> reg([t])
<tf.Tensor: ...>

corresponds to computing

>>> w = tf.reshape(t, [-1, 2])
>>> w_gram_sq = tf.square(tf.matmul(tf.transpose(w), w))
>>> 0.01 * (tf.reduce_sum(w_gram_sq) - tf.linalg.trace(w_gram_sq))
<tf.Tensor: ...>

__init__(scale)[source]

Create an off-diagonal orthogonal regularizer.

Parameters

scale (Union[float, floating, ndarray, Tensor, Variable]) – A non-negative regularization factor.

Raises

ValueError – if scale is <0.

__call__(tensors)[source]

See base class.

Return type

Tensor

Paddings

Paddings.

causal

sonnet.pad.causal(effective_kernel_size)[source]

Pre-padding such that output has no dependence on the future.
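For instance, causal padding is useful for autoregressive convolutions. A sketch, assuming the convolution modules accept these padding callables via their padding argument:

mod = snt.Conv1D(output_channels=8, kernel_shape=3, padding=snt.pad.causal)
x = tf.random.normal([4, 10, 2])   # NWC.
y = mod(x)                         # Output at step t depends only on inputs up to t.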

create

sonnet.pad.create(padding, kernel, rate, n, channel_index)[source]

Generates the padding required for a given padding algorithm.

Parameters

Returns

A list of length n+2 containing the padding for each element. These are of the form [pad_before, pad_after].

full

sonnet.pad.full(effective_kernel_size)[source]

Maximal padding whilst not convolving over just padded elements.

reverse_causal

sonnet.pad.reverse_causal(effective_kernel_size)[source]

Post-padding such that output has no dependence on the past.

same

sonnet.pad.same(effective_kernel_size)[source]

Pads such that the output size matches input size for stride=1.

valid

sonnet.pad.valid(effective_kernel_size)[source]

No padding.

Distribution

Utilities for using Sonnet with TensorFlow Distribution Strategy.

Replicator

class sonnet.distribute.Replicator(devices=None, cross_device_ops=None)[source]

Replicates input, parameters and compute over multiple accelerators.

Replicator is a TensorFlow “Distribution Strategy” implementing the programming model described in the TF-Replicator paper [10] and TensorFlow RFC [11]. Replicator enables data-parallel training across multiple accelerators on a single machine; it supports eager execution and tf.function.

To get started create a Replicator instance:

replicator = snt.distribute.Replicator()

Replicator provides a scope inside which any new tf.Variables will be replicated across all local devices:

with replicator.scope():
  mod = snt.Linear(32)

Additionally, Replicator provides utility functions to apply a module in parallel on multiple devices. First we need to define some computation that runs on each GPU. The “replica context” object provides us a way to communicate between replicas (e.g. to perform an all_reduce):

def forward():
  # Compute a random output on each GPU.
  x = tf.random.normal([8, 28 * 28])
  y = mod(x)
  # Synchronize the value of y between all GPUs.
  ctx = tf.distribute.get_replica_context()
  y = ctx.all_reduce("mean", y)
  return y

Finally we use the run API to apply forward in parallel on all accelerator devices:

per_replica_y = replicator.run(forward)

scope()[source]

Context manager to make the strategy current and distribute variables.

This method returns a context manager, and is used as follows:

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

Variable created inside scope:

>>> with strategy.scope():
...   mirrored_variable = tf.Variable(1.)
>>> mirrored_variable
MirroredVariable:{
  0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
  1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
}

Variable created outside scope:

>>> regular_variable = tf.Variable(1.)
>>> regular_variable
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>

_What happens when Strategy.scope is entered?_

Note: Entering a scope does not automatically distribute a computation, except in the case of a high level training framework like Keras model.fit. If you’re not using model.fit, you need to use the strategy.run API to explicitly distribute that computation. See an example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training).

_What should be in scope and what should be outside?_

There are a number of requirements on what needs to happen inside the scope. However, in places where we have information about which strategy is in use, we often enter the scope for the user, so they don’t have to do it explicitly (i.e. calling those either inside or outside the scope is OK).

Returns

A context manager.

TpuReplicator

class sonnet.distribute.TpuReplicator(tpu_cluster_resolver=None, experimental_device_assignment=None, experimental_spmd_xla_partitioning=False)[source]

Replicates input, parameters and compute over multiple TPUs.

TpuReplicator is a TensorFlow “Distribution Strategy” implementing the programming model described in the TF-Replicator paper [10] and TensorFlow RFC [11]. TpuReplicator enables data-parallel training across multiple TPUs on one or more machines; it supports tf.function.

To get started create a TpuReplicator instance:

replicator = snt.distribute.TpuReplicator()

This provides a scope inside which any new tf.Variables will be replicated across all TPU cores:

with replicator.scope():
  mod = snt.Linear(32)

Additionally, TpuReplicator provides utility functions to apply a module in parallel on multiple devices. First we need to define some computation that runs on each TPU. The “replica context” object provides us a way to communicate between replicas:

def forward():
  # Compute a random output on each TPU core.
  x = tf.random.normal([8, 28 * 28])
  y = mod(x)
  # Synchronize the value of y between all replicas.
  ctx = tf.distribute.get_replica_context()
  y = ctx.all_reduce("mean", y)
  return y

Finally we use the run API to apply forward in parallel on all TPU devices. This must be run as part of a tf.function, since TpuReplicator uses XLA to compile and replicate our function to run in parallel over all TPU cores:

@tf.function(autograph=False)
def all_forward():
  return replicator.run(forward)

per_replica_y = all_forward()

scope()[source]

Context manager to make the strategy current and distribute variables.

This method returns a context manager, and is used as follows:

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

Variable created inside scope:

>>> with strategy.scope():
...   mirrored_variable = tf.Variable(1.)
>>> mirrored_variable
MirroredVariable:{
  0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
  1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
}

Variable created outside scope:

>>> regular_variable = tf.Variable(1.)
>>> regular_variable
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>

_What happens when Strategy.scope is entered?_

Note: Entering a scope does not automatically distribute a computation, except in the case of a high level training framework like Keras model.fit. If you’re not using model.fit, you need to use the strategy.run API to explicitly distribute that computation. See an example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training).

_What should be in scope and what should be outside?_

There are a number of requirements on what needs to happen inside the scope. However, in places where we have information about which strategy is in use, we often enter the scope for the user, so they don’t have to do it explicitly (i.e. calling those either inside or outside the scope is OK).

Returns

A context manager.

Metrics

Metric

class sonnet.Metric(name=None)[source]

Metric base class.

property value

Returns the current value of the metric.

__call__(*args, **kwargs)[source]

Call self as a function.

Mean

class sonnet.Mean(name=None)[source]

Calculates the element-wise mean of the given values.

__init__(name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

property value

See base class.
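A usage sketch of the running mean:

mean = snt.Mean()
mean(tf.constant(2.0))
mean(tf.constant(4.0))
assert mean.value.numpy() == 3.0   # Mean of all values seen so far.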

Sum

class sonnet.Sum(name=None)[source]

Calculates the element-wise sum of the given values.

__init__(name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

property value

See base class.

Nets

Common network architectures implemented as Sonnet modules.

MLP

class sonnet.nets.MLP(output_sizes, w_init=None, b_init=None, with_bias=True, activation=tf.nn.relu, dropout_rate=None, activate_final=False, name=None)[source]

A multi-layer perceptron module.

__init__(output_sizes, w_init=None, b_init=None, with_bias=True, activation=tf.nn.relu, dropout_rate=None, activate_final=False, name=None)[source]

Constructs an MLP.

Parameters

Raises

ValueError – If with_bias is False and b_init is not None.
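A usage sketch: output_sizes gives one Linear layer per entry, with the activation applied between layers (and after the last one only if activate_final=True):

mlp = snt.nets.MLP([64, 64, 10])
x = tf.random.normal([8, 28 * 28])
logits = mlp(x)
assert logits.shape == [8, 10]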

__call__(*args, **kwargs)[source]

Call self as a function.

Cifar10ConvNet

class sonnet.nets.Cifar10ConvNet(num_classes=10, w_init=None, b_init=None, data_format='NHWC', output_channels=(64, 64, 128, 128, 128, 256, 256, 256, 512, 512, 512), strides=(1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1), name=None)[source]

Convolutional network designed for Cifar10.

Approximately equivalent to “VGG, minus max pooling, plus BatchNorm”. For best results the input data should be scaled to be between -1 and 1 when using the standard initializers.

__init__(num_classes=10, w_init=None, b_init=None, data_format='NHWC', output_channels=(64, 64, 128, 128, 128, 256, 256, 256, 512, 512, 512), strides=(1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1), name=None)[source]

Initializes the current module with the given name.

Subclasses should call this constructor before creating other modules or variables such that those modules are named correctly.

Parameters

name (Optional[str]) – An optional string name for the class. Must be a valid Python identifier. If name is not provided then the class name for the current instance is converted to lower_snake_case and used instead.

__call__(*args, **kwargs)[source]

Call self as a function.
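A minimal usage sketch (the batch size is illustrative; is_training is required because the network uses BatchNorm, and the forward pass is assumed to return a dict containing the 'logits', as in the current Sonnet implementation):

import sonnet as snt
import tensorflow as tf

net = snt.nets.Cifar10ConvNet()
# Scale inputs to [-1, 1] as recommended above.
images = tf.random.uniform([4, 32, 32, 3], minval=-1., maxval=1.)
outputs = net(images, is_training=True)
logits = outputs['logits']  # shape [4, 10]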

ResNet

class sonnet.nets.ResNet(blocks_per_group_list, num_classes, bn_config=None, resnet_v2=False, channels_per_group_list=(256, 512, 1024, 2048), name=None)[source]

ResNet model.

__init__(blocks_per_group_list, num_classes, bn_config=None, resnet_v2=False, channels_per_group_list=(256, 512, 1024, 2048), name=None)[source]

Constructs a ResNet model.

Parameters

__call__(*args, **kwargs)[source]

Call self as a function.

ResNet50

class sonnet.nets.ResNet50(num_classes, bn_config=None, resnet_v2=False, name=None)[source]

ResNet50 module.

__init__(num_classes, bn_config=None, resnet_v2=False, name=None)[source]

Constructs a ResNet model.

Parameters
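A minimal usage sketch (ResNet50 uses the standard [3, 4, 6, 3] blocks-per-group layout; the input shape is illustrative and is_training is assumed to gate the BatchNorm layers):

import sonnet as snt
import tensorflow as tf

model = snt.nets.ResNet50(num_classes=1000)
images = tf.random.normal([2, 224, 224, 3])
logits = model(images, is_training=True)  # shape [2, 1000]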

VectorQuantizer

class sonnet.nets.VectorQuantizer(embedding_dim, num_embeddings, commitment_cost, dtype=tf.float32, name='vector_quantizer')[source]

Sonnet module representing the VQ-VAE layer.

Implements the algorithm presented in ‘Neural Discrete Representation Learning’ by van den Oord et al. (https://arxiv.org/abs/1711.00937).

The input can be any tensor to be quantized. The last dimension is used as the space in which to quantize; all other dimensions are flattened and treated as separate examples to quantize.

The output tensor will have the same shape as the input.

For example a tensor with shape [16, 32, 32, 64] will be reshaped into [16384, 64] and all 16384 vectors (each of 64 dimensions) will be quantized independently.

embedding_dim

integer representing the dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.

num_embeddings

integer, the number of vectors in the quantized space.

commitment_cost

scalar which controls the weighting of the loss terms (see equation 4 in the paper - this variable is Beta).

__init__(embedding_dim, num_embeddings, commitment_cost, dtype=tf.float32, name='vector_quantizer')[source]

Initializes a VQ-VAE module.

Parameters

__call__(*args, **kwargs)[source]

Call self as a function.
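A minimal usage sketch (shapes are illustrative; following the Sonnet VQ-VAE example, __call__ is assumed to take an is_training flag and return a dict that includes the quantized tensor under 'quantize' and the training loss under 'loss'):

import sonnet as snt
import tensorflow as tf

vq = snt.nets.VectorQuantizer(
    embedding_dim=64, num_embeddings=512, commitment_cost=0.25)
z = tf.random.normal([16, 32, 32, 64])  # encoder output; last dim == embedding_dim
out = vq(z, is_training=True)
quantized = out['quantize']  # same shape as z
loss = out['loss']           # add this to the training objective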

VectorQuantizerEMA

class sonnet.nets.VectorQuantizerEMA(*args, **kwargs)[source]

Sonnet module representing the VQ-VAE layer.

Implements a slightly modified version of the algorithm presented in ‘Neural Discrete Representation Learning’ by van den Oord et al. (https://arxiv.org/abs/1711.00937).

The difference between VectorQuantizerEMA and VectorQuantizer is that this module uses exponential moving averages to update the embedding vectors instead of an auxiliary loss. This has the advantage that the embedding updates are independent of the choice of optimizer (SGD, RMSProp, Adam, K-Fac, …) used for the encoder, decoder and other parts of the architecture. For most experiments the EMA version trains faster than the non-EMA version.

The input can be any tensor to be quantized. The last dimension is used as the space in which to quantize; all other dimensions are flattened and treated as separate examples to quantize.

The output tensor will have the same shape as the input.

For example a tensor with shape [16, 32, 32, 64] will be reshaped into [16384, 64] and all 16384 vectors (each of 64 dimensions) will be quantized independently.

embedding_dim

integer representing the dimensionality of the tensors in the quantized space. Inputs to the modules must be in this format as well.

num_embeddings

integer, the number of vectors in the quantized space.

commitment_cost

scalar which controls the weighting of the loss terms (see equation 4 in the paper).

decay

float, decay for the moving averages.

epsilon

small float constant to avoid numerical instability.

__init__(embedding_dim, num_embeddings, commitment_cost, decay, epsilon=1e-05, dtype=tf.float32, name='vector_quantizer_ema')[source]

Initializes a VQ-VAE EMA module.

Parameters

__call__(*args, **kwargs)[source]

Call self as a function.
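Construction differs from VectorQuantizer only in the EMA-specific arguments; a brief sketch (values are illustrative, and is_training is assumed to control whether the moving-average embedding updates are applied):

import sonnet as snt
import tensorflow as tf

vq_ema = snt.nets.VectorQuantizerEMA(
    embedding_dim=64, num_embeddings=512, commitment_cost=0.25, decay=0.99)
out = vq_ema(tf.random.normal([16, 32, 32, 64]), is_training=True)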

Mixed Precision

Sonnet mixed precision built for TensorFlow 2.

modes

sonnet.mixed_precision.modes(valid_types)[source]

Decorate a function to cast inputs/outputs to different precision.

>>> support_modes = snt.mixed_precision.modes([tf.float32, tf.float16])
>>> snt.Linear.__call__ = support_modes(snt.Linear.__call__)
>>> mod = snt.Linear(10)
>>> snt.mixed_precision.enable(tf.float16)
>>> y = mod(tf.ones([1, 1]))  # First call will be done in F32.
>>> y = mod(tf.ones([1, 1]))  # MatMul/Add will be done in F16.
>>> snt.mixed_precision.disable()

Parameters

Returns

A decorator that will cast the inputs and outputs of the decorated function according to the global mixed precision policy and the function's eligibility for mixed precision.

enable

sonnet.mixed_precision.enable(dtype)[source]

Set the mixed precision mode.

Parameters

dtype – type to cast to.

disable

sonnet.mixed_precision.disable()[source]

Disable mixed precision training.

scope

sonnet.mixed_precision.scope(dtype)[source]

Temporarily set the global mixed precision type to dtype.

The global type is reset to its original value when the context is exited:

>>> snt.mixed_precision.enable(tf.float32)
>>> support_modes = snt.mixed_precision.modes([tf.float32, tf.float16])
>>> snt.Linear.__call__ = support_modes(snt.Linear.__call__)
>>> mod = snt.Linear(10)

>>> with snt.mixed_precision.scope(tf.float16):
...   y = mod(tf.ones([1, 1]))  # First call will be done in F32.
...   y = mod(tf.ones([1, 1]))  # MatMul/Add will be done in F16.
>>> y = mod(tf.ones([1, 1]))  # Outside the scope will be done in F32.

Parameters

dtype (DType) – type to set the mixed precision mode to.

Yields

Nothing. This is required for contextlib.contextmanager.

References

1

Ashish Agarwal, David Berthelot, Tom Hennigan, Alex Passos, and Malcolm Reynolds. Stateful containers with tf.Module. TensorFlow Community RFCs, Google / DeepMind, 2019. URL: https://github.com/tensorflow/community/pull/56.

2

Wojciech Zaremba, Ilya Sutskever, and Oriol Vinyals. Recurrent neural network regularization. arXiv preprint arXiv:1409.2329, 2014. URL: https://arxiv.org/abs/1409.2329.

3

Rafal Jozefowicz, Wojciech Zaremba, and Ilya Sutskever. An empirical exploration of recurrent network architectures. In International Conference on Machine Learning, 2342–2350. 2015.

4

Haşim Sak, Andrew Senior, and Françoise Beaufays. Long short-term memory based recurrent neural network architectures for large vocabulary speech recognition. arXiv preprint arXiv:1402.1128, 2014. URL: https://arxiv.org/abs/1402.1128.

5

Yarin Gal and Zoubin Ghahramani. A theoretically grounded application of dropout in recurrent neural networks. In Advances in neural information processing systems, 1019–1027. 2016.

6

Xingjian Shi, Zhourong Chen, Hao Wang, Dit-Yan Yeung, Wai-Kin Wong, and Wang-chun Woo. Convolutional LSTM network: a machine learning approach for precipitation nowcasting. In Advances in neural information processing systems, 802–810. 2015.

7

Junyoung Chung, Caglar Gulcehre, KyungHyun Cho, and Yoshua Bengio. Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555, 2014. URL: https://arxiv.org/abs/1412.3555.

8

Diederik P. Kingma and Jimmy Ba. Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014. URL: https://arxiv.org/abs/1412.6980.

9

Andrew M Saxe, James L McClelland, and Surya Ganguli. Exact solutions to the nonlinear dynamics of learning in deep linear neural networks. arXiv preprint arXiv:1312.6120, 2013. URL: https://arxiv.org/abs/1312.6120.

10

Peter Buchlovsky, David Budden, Dominik Grewe, Chris Jones, John Aslanides, Frederic Besse, Andy Brock, Aidan Clark, Sergio Gómez Colmenarejo, Aedan Pope, and others. TF-Replicator: Distributed machine learning for researchers. arXiv preprint arXiv:1902.00465, 2019. URL: https://arxiv.org/abs/1902.00465.

11

Peter Buchlovsky, Dominik Grewe, Priya Gupta, Tom Hennigan, Jonathan Hseu, Chris Jones, and Josh Levenberg. Distribution Strategy - Revised API. TensorFlow Community RFCs, Google / DeepMind, 2018. URL: https://github.com/tensorflow/community/pull/25.