jax.nn.one_hot — JAX documentation
jax.nn.one_hot
jax.nn.one_hot(x, num_classes, *, dtype=<class 'jax.numpy.float64'>, axis=-1)
One-hot encodes the given indices.
Each index in the input x is encoded as a vector of zeros of length num_classes, with the element at that index set to one:
>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)
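The same encoding can be reproduced by comparing each index against the range of class ids with broadcasting. The sketch below is an illustrative equivalent for an integer axis of -1, not necessarily how jax.nn.one_hot is implemented; the helper name one_hot_reference is hypothetical.

```python
import jax
import jax.numpy as jnp


def one_hot_reference(x, num_classes, dtype=jnp.float32):
    """Hypothetical helper: one-hot encode x along a new trailing axis."""
    x = jnp.asarray(x)
    # Class ids 0, 1, ..., num_classes - 1.
    classes = jnp.arange(num_classes)
    # x[..., None] broadcasts against classes, adding the trailing one-hot axis.
    # Indices outside [0, num_classes) match no class and give all-zero rows.
    return (x[..., None] == classes).astype(dtype)


labels = jnp.array([0, 1, 2])
ours = one_hot_reference(labels, 3)
ref = jax.nn.one_hot(labels, 3)
print(jnp.array_equal(ours, ref))  # True
```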
Indices outside the range [0, num_classes) will be encoded as zeros:
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
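Because out-of-range indices produce all-zero rows, a sentinel such as -1 can mark padded positions so that they drop out of a loss without a separate mask. A minimal sketch, assuming integer labels that use -1 for padding; masked_cross_entropy is a hypothetical helper, not part of jax.nn.

```python
import jax
import jax.numpy as jnp


def masked_cross_entropy(logits, labels):
    """Hypothetical helper: mean cross-entropy over non-padding labels.

    Assumes labels use -1 for padding; those rows one-hot encode to zeros.
    """
    num_classes = logits.shape[-1]
    targets = jax.nn.one_hot(labels, num_classes)         # padding rows are all zeros
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    per_example = -jnp.sum(targets * log_probs, axis=-1)  # 0 for padding rows
    num_valid = jnp.sum(targets)                          # one 1 per valid label
    # Guard against division by zero when every label is padding.
    return jnp.sum(per_example) / jnp.maximum(num_valid, 1.0)


logits = jnp.array([[2.0, 0.5, -1.0],
                    [0.1, 0.2, 3.0],
                    [5.0, 5.0, 5.0]])  # last row is padding
labels = jnp.array([0, 2, -1])
print(masked_cross_entropy(logits, labels))
```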
Parameters:
- x (Any) – A tensor of indices.
- num_classes (int) – Number of classes in the one-hot dimension.
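- dtype (Any) – optional, a float dtype for the returned values (default jnp.float_, which JAX canonicalizes to float32 unless 64-bit mode is enabled, as in the examples above).
- axis (int | AxisName) – the axis or axes along which the function should be computed: for an integer, the position of the new one-hot dimension in the output (default -1, the last axis); an AxisName refers to a named mapped axis.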
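The dtype and axis parameters change only the element type of the result and where the new one-hot dimension is placed. The sketch below assumes an integer axis (named axes used under mapped transformations are not covered) and shows the resulting shapes for a batch of indices of shape (2, 3).

```python
import jax
import jax.numpy as jnp

x = jnp.array([[0, 1, 2],
               [2, 1, 0]])                       # shape (2, 3)

last = jax.nn.one_hot(x, 4)                      # default axis=-1
first = jax.nn.one_hot(x, 4, axis=0)             # one-hot dimension placed first
as_bf16 = jax.nn.one_hot(x, 4, dtype=jnp.bfloat16)

print(last.shape)     # (2, 3, 4)
print(first.shape)    # (4, 2, 3)
print(as_bf16.dtype)  # bfloat16

# Placing the one-hot axis first matches moving the default last axis to the front.
print(jnp.array_equal(first, jnp.moveaxis(last, -1, 0)))  # True
```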
Return type: