JAX documentation

jax.numpy.unpackbits

jax.numpy.unpackbits(a, axis=None, count=None, bitorder='big')

Unpack the bits in a uint8 array.

JAX implementation of numpy.unpackbits().
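Because jnp.unpackbits mirrors NumPy's function, one way to sanity-check that claim is to run both on the same input. This is a minimal sketch that assumes NumPy is installed alongside JAX. The input values and the argument combinations tried are arbitrary choices made for illustration.

```python
import numpy as np
import jax.numpy as jnp

# Arbitrary 2-D uint8 input chosen for illustration.
x = np.array([[27, 154], [49, 255]], dtype=np.uint8)

# Compare JAX and NumPy for several argument combinations.
for kwargs in [{}, {"axis": 0}, {"axis": 1}, {"bitorder": "little"}, {"count": 13}]:
    expected = np.unpackbits(x, **kwargs)
    actual = np.asarray(jnp.unpackbits(jnp.asarray(x), **kwargs))
    assert actual.dtype == np.uint8
    assert np.array_equal(actual, expected), kwargs
```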

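Parameters:

a – N-dimensional array of type uint8.
axis – optional axis along which to unpack. If not specified, a is flattened first.
count – number of bits to unpack along the axis (if positive) or number of trailing bits to trim (if negative). If not specified, all 8 bits of every byte are returned.
bitorder – "big" (default) or "little": whether each byte is unpacked most-significant bit first or least-significant bit first.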
Returns:

a uint8 array of unpacked bits.

Return type:

Array
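The shape of the result follows from axis and count. The sketch below checks this on an all-zeros (2, 3) input. The input is chosen only for illustration, because the bit values do not affect the shapes.

```python
import jax.numpy as jnp

# Illustrative (2, 3) uint8 input; only its shape matters here.
a = jnp.zeros((2, 3), dtype=jnp.uint8)

# axis=None: the input is flattened, so 2 * 3 bytes become 6 * 8 = 48 bits.
assert jnp.unpackbits(a).shape == (48,)

# With an axis, only that dimension grows by a factor of 8.
assert jnp.unpackbits(a, axis=0).shape == (16, 3)
assert jnp.unpackbits(a, axis=1).shape == (2, 24)

# count sets (positive) or trims (negative) the length along the unpacked axis.
assert jnp.unpackbits(a, axis=1, count=20).shape == (2, 20)
assert jnp.unpackbits(a, axis=1, count=-4).shape == (2, 20)
```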

Examples

Unpacking bits from a scalar:

>>> jnp.unpackbits(jnp.uint8(27))  # big-endian by default
Array([0, 0, 0, 1, 1, 0, 1, 1], dtype=uint8)
>>> jnp.unpackbits(jnp.uint8(27), bitorder="little")
Array([1, 1, 0, 1, 1, 0, 0, 0], dtype=uint8)

Compare this to the Python binary representation:

>>> 0b00011011
27
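The same comparison can be done in code, using Python's format(value, "08b") for the zero-padded 8-digit binary string. The block also includes a hypothetical pure-Python helper, unpack_byte_sketch (not part of JAX), that derives each bit with a shift and a mask. Big-endian bit i is (v >> (7 - i)) & 1, and little-endian bit i is (v >> i) & 1. It is a sketch of the idea only; JAX's own implementation may differ.

```python
import jax.numpy as jnp

value = 27
binary = format(value, "08b")  # zero-padded 8-digit binary string
assert binary == "00011011"

def unpack_byte_sketch(v, bitorder="big"):
    """Hypothetical helper: the 8 bits of one byte via shifts and masks.

    Illustration only, not JAX's actual implementation.
    """
    # Big-endian starts from the most significant bit (shift 7).
    shifts = range(7, -1, -1) if bitorder == "big" else range(8)
    return [(v >> s) & 1 for s in shifts]

big = [int(b) for b in jnp.unpackbits(jnp.uint8(value))]
little = [int(b) for b in jnp.unpackbits(jnp.uint8(value), bitorder="little")]

# Big-endian output reads the binary string left to right...
assert big == [int(c) for c in binary] == unpack_byte_sketch(value, "big")
# ...and little-endian output reads it right to left.
assert little == [int(c) for c in reversed(binary)] == unpack_byte_sketch(value, "little")
```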

Unpacking bits along an axis:

>>> vals = jnp.array([[154],
...                   [ 49]], dtype='uint8')
>>> bits = jnp.unpackbits(vals, axis=1)
>>> bits
Array([[1, 0, 0, 1, 1, 0, 1, 0],
       [0, 0, 1, 1, 0, 0, 0, 1]], dtype=uint8)

Using packbits() to invert this:

>>> jnp.packbits(bits, axis=1)
Array([[154],
       [ 49]], dtype=uint8)
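More generally, packbits inverts unpackbits whenever both calls use the same axis and the same bitorder. The sketch below checks this on an arbitrary (2, 3, 4) uint8 array for every axis and for axis=None.

```python
import jax.numpy as jnp

# Arbitrary test data: 0, 11, ..., 253 (all fit in uint8), shaped (2, 3, 4).
x = (jnp.arange(24, dtype=jnp.uint8) * 11).reshape(2, 3, 4)

for bitorder in ("big", "little"):
    # Unpacking and packing along the same axis restores x.
    for axis in range(x.ndim):
        unpacked = jnp.unpackbits(x, axis=axis, bitorder=bitorder)
        assert (jnp.packbits(unpacked, axis=axis, bitorder=bitorder) == x).all()
    # With axis=None both functions flatten, so the result equals x.ravel().
    flat_bits = jnp.unpackbits(x, bitorder=bitorder)
    assert (jnp.packbits(flat_bits, bitorder=bitorder) == x.ravel()).all()
```

If the two bitorder arguments differ, the bits of each byte come back reversed, so the round trip fails for most values.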

The count keyword lets unpackbits serve as an inverse of packbits when the number of bits is not a multiple of 8, so that packbits had to pad the last byte with zeros:

>>> bits = jnp.array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1])  # 11 bits
>>> vals = jnp.packbits(bits)
>>> vals
Array([219,  96], dtype=uint8)
>>> jnp.unpackbits(vals)  # 16 zero-padded bits
Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0], dtype=uint8)
>>> jnp.unpackbits(vals, count=11)  # specify 11 output bits
Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=uint8)
>>> jnp.unpackbits(vals, count=-5)  # specify 5 bits to be trimmed
Array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1], dtype=uint8)
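The packed bytes alone do not record how many bits were meaningful, so the original length has to be kept alongside them. This is a sketch with two hypothetical helpers, pack_with_length and unpack_with_length, which are not JAX functions.

```python
import jax.numpy as jnp

def pack_with_length(bits):
    """Hypothetical helper: pack a 1-D bit array and return its length too."""
    bits = jnp.asarray(bits)
    return jnp.packbits(bits), bits.shape[0]

def unpack_with_length(packed, n_bits):
    """Hypothetical helper: undo pack_with_length by passing count."""
    return jnp.unpackbits(packed, count=n_bits)

bits = jnp.array([1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1])  # 11 bits
packed, n = pack_with_length(bits)
assert packed.shape == (2,)  # ceil(11 / 8) = 2 bytes
assert n == 11

restored = unpack_with_length(packed, n)
assert (restored == bits).all()

# Equivalent negative count: trim the 2 * 8 - 11 = 5 padding bits.
trim = packed.shape[0] * 8 - n
assert (jnp.unpackbits(packed, count=-trim) == restored).all()
```

The positive count is the more robust choice. When the length is already a multiple of 8, there are 0 padding bits, and count=-0 is simply count=0, which returns an empty array rather than all the bits.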