
Stochastic

class flax.nnx.Dropout(self, /, *args, **kwargs)

Create a dropout layer.

To enable dropout, call the train() method on the layer or on a module that contains it (or pass deterministic=False to the constructor or at call time).

To disable dropout, call the eval() method (or pass deterministic=True to the constructor or at call time).
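As a rough picture of what the deterministic switch controls, here is a NumPy sketch of a dropout forward pass. It assumes the standard inverted-dropout scheme (keep each element with probability 1 - rate and divide the survivors by 1 - rate), which the example output below is consistent with. dropout_forward is a hypothetical helper, not the Flax implementation, and a NumPy Generator stands in for the layer's rngs.

```python
import numpy as np


def dropout_forward(x, rate, deterministic, rng):
    """Hypothetical inverted-dropout forward pass (illustration only).

    x: input array; rate: probability of zeroing each element;
    deterministic: if True, behave like eval mode and return x unchanged;
    rng: a NumPy Generator that supplies the randomness in train mode.
    """
    # Eval mode (or a zero rate): dropout is the identity.
    if deterministic or rate == 0.0:
        return x
    # Dropping everything: return zeros instead of dividing by keep_prob == 0.
    if rate == 1.0:
        return np.zeros_like(x)
    keep_prob = 1.0 - rate
    # Keep each element independently with probability keep_prob.
    mask = rng.random(x.shape) < keep_prob
    # Rescale survivors by 1 / keep_prob so the expected value matches eval mode.
    return np.where(mask, x / keep_prob, 0.0)


rng = np.random.default_rng(0)
x = np.ones((2, 4), dtype=np.float32)
print(dropout_forward(x, 0.5, deterministic=False, rng=rng))  # each entry is 0.0 or 2.0
print(dropout_forward(x, 0.5, deterministic=True, rng=rng))   # identical to x
```

Scaling at train time is what lets eval mode simply return its input: the expected value of each output element is the same in both modes.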

Example usage:

```python
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class MLP(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(in_features=3, out_features=4, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear(x)
...     x = self.dropout(x)
...     return x

>>> model = MLP(rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 3))

>>> model.train()  # use dropout
>>> model(x)
Array([[ 0.       ,  0.       , -1.592019 , -2.5238838]], dtype=float32)

>>> model.eval()  # don't use dropout
>>> model(x)
Array([[ 1.0533503, -1.2679932, -0.7960095, -1.2619419]], dtype=float32)
```
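Because the model and input are unchanged between the two calls, the linear layer produces the same values both times, so the printed outputs show what dropout did. The entries that survive in train mode are exactly twice their eval-mode counterparts, i.e. scaled by 1 / (1 - 0.5). The check below only uses the numbers printed above as plain Python literals (no Flax needed); the variable names are illustrative.

```python
import math

# Outputs printed by the example above (dropout rate 0.5).
train_out = [0.0, 0.0, -1.592019, -2.5238838]
eval_out = [1.0533503, -1.2679932, -0.7960095, -1.2619419]

keep_prob = 1.0 - 0.5  # probability of keeping an element at rate 0.5

for t, e in zip(train_out, eval_out):
    # Each train-mode entry is either dropped (0.0) or the eval-mode value / keep_prob.
    dropped = t == 0.0
    kept_and_scaled = math.isclose(t, e / keep_prob, rel_tol=1e-6)
    print(f"eval={e:+.7f} train={t:+.7f} dropped={dropped} scaled={kept_and_scaled}")
```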

Parameters