one_hot
- class brainpy.math.one_hot(x, num_classes, *, dtype=None, axis=-1)
One-hot encodes the given indices.
Each index in the input x is encoded as a vector of zeros of length num_classes with the element at index set to one:
>>> import jax.numpy as jnp
>>> one_hot(jnp.array([0, 1, 2]), 3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)
Indices outside the range [0, num_classes) will be encoded as zeros:
>>> import jax.numpy as jnp
>>> one_hot(jnp.array([-1, 3]), 3)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
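One way to see why out-of-range indices come out as all zeros is to treat one-hot encoding as an equality comparison against the class ids. The following is a minimal sketch in plain jax.numpy; the helper name `one_hot_sketch` is hypothetical and is not claimed to be how brainpy implements the function:

```python
import jax.numpy as jnp


def one_hot_sketch(x, num_classes, dtype=jnp.float32):
    # Hypothetical helper for illustration only (not part of brainpy).
    x = jnp.asarray(x)
    classes = jnp.arange(num_classes)
    # Add a trailing axis to x and compare against every class id.
    # An index matching no class id yields a row of False, i.e. zeros.
    return (x[..., None] == classes).astype(dtype)


# -1 and 3 lie outside [0, 3), so their rows contain no ones.
encoded = one_hot_sketch(jnp.array([0, 2, -1, 3]), 3)
print(encoded)
```

Under this view no bounds check or error handling is needed: an index that matches no class simply never produces a one.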
- Parameters:
x – A tensor of indices.
num_classes – Number of classes in the one-hot dimension.
dtype – Optional float dtype for the returned values (default: float64 if jax_enable_x64 is true, otherwise float32).
axis – The axis or axes along which the function should be computed.
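The effect of the dtype and axis parameters can be sketched as follows. This example calls `jax.nn.one_hot`, which has the same signature as the one documented here; that `brainpy.math.one_hot` behaves identically is an assumption based on the matching signature and description, not something this page confirms.

```python
import jax
import jax.numpy as jnp

labels = jnp.array([0, 2, 1, 2])

# Default: the one-hot dimension is appended as the last axis.
default = jax.nn.one_hot(labels, 3)

# axis=0 places the one-hot (class) dimension first instead.
classes_first = jax.nn.one_hot(labels, 3, axis=0)

# dtype selects the float type of the returned values.
half = jax.nn.one_hot(labels, 3, dtype=jnp.float16)

print(default.shape, classes_first.shape, half.dtype)
```

For a one-dimensional input, moving the class dimension to axis 0 is equivalent to transposing the default result.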