jax.numpy.left_shift - JAX documentation
jax.numpy.left_shift
jax.numpy.left_shift(x, y, /)
Shift the bits of x to the left by the amount specified in y, element-wise.
JAX implementation of numpy.left_shift.
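One way to see what "JAX implementation of numpy.left_shift" means in practice is to run both on the same inputs. The sketch below assumes NumPy is installed alongside JAX and uses small non-negative int32 values, where the two should agree. It also checks the `<<` operator, which JAX arrays route to this function.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4], dtype=jnp.int32)
y = jnp.array([0, 1, 2, 3], dtype=jnp.int32)

# JAX function, the << operator on JAX arrays, and NumPy's reference ufunc.
jax_result = jnp.left_shift(x, y)
op_result = x << y
np_result = np.left_shift(np.asarray(x), np.asarray(y))

print(jax_result.tolist())  # [1, 4, 12, 32]
print(np_result.tolist())   # [1, 4, 12, 32]
```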
Parameters:
- x (ArrayLike) – Input array, must be integer-typed.
- y (ArrayLike) – The number of bits by which to shift each element of x to the left; must be integer-typed. x and y must either have the same shape or be broadcast-compatible.
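The broadcasting requirement on x and y follows the usual NumPy rules. In the minimal sketch below (the input values are arbitrary), a column of shape (3, 1) is shifted by a row of shift amounts of shape (4,), giving a (3, 4) result in which each row is one value of x shifted by every amount in y.

```python
import jax.numpy as jnp

x = jnp.array([[1], [2], [3]])  # shape (3, 1)
y = jnp.array([0, 1, 2, 3])     # shape (4,)

# Both operands broadcast to the common shape (3, 4).
out = jnp.left_shift(x, y)
print(out.shape)     # (3, 4)
print(out.tolist())  # [[1, 2, 4, 8], [2, 4, 8, 16], [3, 6, 12, 24]]
```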
Returns:
An array containing the elements of x left-shifted by the amounts specified in y, with shape equal to the broadcast shape of x and y.
Return type:
Array
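The result is a JAX array whose dtype comes from JAX's type-promotion rules applied to x and y. The sketch below assumes default settings (64-bit mode off). A uint8 array shifted by a Python int stays uint8, since Python scalars are weakly typed, while mixing int8 and int32 arrays promotes the result to int32.

```python
import jax
import jax.numpy as jnp

# Weakly typed Python scalar: the uint8 dtype of `a` is preserved.
a = jnp.array([1, 2, 3], dtype=jnp.uint8)
r1 = jnp.left_shift(a, 3)

# Two concrete integer dtypes: int8 and int32 promote to int32.
b = jnp.array([1, 2, 3], dtype=jnp.int8)
s = jnp.array([1, 1, 1], dtype=jnp.int32)
r2 = jnp.left_shift(b, s)

print(isinstance(r1, jax.Array), r1.dtype, r1.tolist())  # True uint8 [8, 16, 24]
print(r2.dtype, r2.tolist())                             # int32 [2, 4, 6]
```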
Note
Left-shifting x by y is equivalent to x * (2**y), provided the result fits within the range of the dtypes involved.
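The equivalence in the note can be checked directly on small values. This sketch computes 2**y with `jnp.power` in integer arithmetic and compares the product against the shift. It also shifts an int8 value whose exact result, 128, does not fit in int8. The sketch only checks that the result keeps the int8 dtype and therefore cannot equal 128. It does not assert a particular out-of-range value.

```python
import jax.numpy as jnp

x = jnp.arange(5)               # [0, 1, 2, 3, 4], int32 by default
y = jnp.array([0, 1, 2, 3, 4])

shifted = jnp.left_shift(x, y)
multiplied = x * jnp.power(2, y)  # x * 2**y in integer arithmetic

print(shifted.tolist())     # [0, 2, 8, 24, 64]
print(multiplied.tolist())  # [0, 2, 8, 24, 64]

# Out of range: 64 * 2 = 128 exceeds the int8 maximum of 127.
big = jnp.array([64], dtype=jnp.int8)
overflowed = jnp.left_shift(big, 1)
print(overflowed.dtype, overflowed.tolist())
```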
See also
- jax.numpy.right_shift() and jax.numpy.bitwise_right_shift(): Shift the bits of x1 to the right by the amount specified in x2, element-wise.
- jax.numpy.bitwise_left_shift(): Alias of jax.numpy.left_shift().
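The related functions listed above can be tied together in a short sketch. It assumes a JAX version recent enough to provide `jnp.bitwise_left_shift`. For small non-negative values that lose no set bits when shifted left, shifting right by the same amounts recovers the input. The alias produces the same result as `jnp.left_shift`.

```python
import jax.numpy as jnp

x = jnp.array([1, 5, 7, 12])
y = jnp.array([1, 2, 3, 2])

left = jnp.left_shift(x, y)
# Undo the shift; exact here because no set bits were pushed out.
roundtrip = jnp.right_shift(left, y)
# Array API style alias of jnp.left_shift.
alias = jnp.bitwise_left_shift(x, y)

print(left.tolist())       # [2, 20, 56, 48]
print(roundtrip.tolist())  # [1, 5, 7, 12]
```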
Examples
>>> import jax.numpy as jnp
>>> def print_binary(x):
...   return [bin(int(val)) for val in x]
>>> x1 = jnp.arange(5)
>>> x1
Array([0, 1, 2, 3, 4], dtype=int32)
>>> print_binary(x1)
['0b0', '0b1', '0b10', '0b11', '0b100']
>>> x2 = 1
>>> result = jnp.left_shift(x1, x2)
>>> result
Array([0, 2, 4, 6, 8], dtype=int32)
>>> print_binary(result)
['0b0', '0b10', '0b100', '0b110', '0b1000']
>>> x3 = 4
>>> print_binary([x3])
['0b100']
>>> x4 = jnp.array([1, 2, 3, 4])
>>> result1 = jnp.left_shift(x3, x4)
>>> result1
Array([ 8, 16, 32, 64], dtype=int32)
>>> print_binary(result1)
['0b1000', '0b10000', '0b100000', '0b1000000']
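Because jnp.left_shift is an ordinary traceable JAX function, it can also run inside jax.jit. In the sketch below, `shift_all` is a hypothetical wrapper name used only for illustration. It reuses the first example's input, jnp.arange(5), with a shift of 2, and the compiled result should match the eager call.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: a jit-compiled wrapper around jnp.left_shift.
@jax.jit
def shift_all(x, y):
    return jnp.left_shift(x, y)

out = shift_all(jnp.arange(5), 2)
eager = jnp.left_shift(jnp.arange(5), 2)

print(out.tolist())  # [0, 4, 8, 12, 16]
```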