nki.language — AWS Neuron Documentation (original) (raw)

import numpy as np import ml_dtypes

bfloat16 = np.dtype('bfloat16') r"""16-bit floating-point number (1S,8E,7M)"""

bool_ = np.bool_ r"""Boolean type (True or False), stored as a byte. Same as numpy.bool_."""

float16 = np.float16 r"""Half-precision floating-point number type. Same as numpy.float16."""

float32 = np.float32 r"""Single-precision floating-point number type, compatible with C float. Same as numpy.float32."""

float8_e4m3 = np.dtype('float8_e4m3') r"""8-bit floating-point number (1S,4E,3M)"""

float8_e5m2 = np.dtype('float8_e5m2') r"""8-bit floating-point number (1S,5E,2M)"""

[docs]class fp32: r""" FP32 Constants"""

@property def min(self): r"""FP32 Bit pattern (0xff7fffff) representing the minimum (or maximum negative) FP32 value""" ...

hbm = ... r"""HBM - Alias of private_hbm"""

int16 = np.int16 r"""Signed integer type, compatible with C short. Same as numpy.int16."""

int32 = np.int32 r"""Signed integer type, compatible with C int. Same as numpy.int32."""

int8 = np.int8 r"""Signed integer type, compatible with C char. Same as numpy.int8."""

mgrid = ... r""" Same as NumPy mgrid: "An instance which returns a dense (or fleshed out) mesh-grid when indexed, so that each returned argument has the same shape. The dimensions and number of the output arrays are equal to the number of indexing dimensions."

Complex numbers are not supported in the step length.

((Similar to numpy.mgrid <https://numpy.org/doc/stable/reference/generated/numpy.mgrid.html>_))

.. nki_example:: ../../test/test_nki_nl_mgrid.py :language: python 📑 NKI_EXAMPLE_8

.. nki_example:: ../../test/test_nki_nl_mgrid.py :language: python 📑 NKI_EXAMPLE_9

"""

nc = ... r""" Create a logical neuron core dimension in launch grid.

The instances of spmd kernel will be distributed to different physical neuron cores on the annotated dimension.

.. code-block:: python

# Let compiler decide how to distribute the instances of spmd kernel
c = kernel[2, 2](a, b)

import neuronxcc.nki.language as nl

# Distribute the kernel to physical neuron cores around the first dimension
# of the spmd grid.
c = kernel[nl.nc(2), 2](a, b)
# This means:
# Physical NC [0]: kernel[0, 0], kernel[0, 1]
# Physical NC [1]: kernel[1, 0], kernel[1, 1]

Sometimes the size of a spmd dimension is bigger than the number of available physical neuron cores. We can control the distribution with the following syntax:

.. nki_example:: ../../test/test_nki_spmd_grid.py :language: python 📑 NKI_EXAMPLE_0

"""

par_dim = ... r""" Mark a dimension explicitly as a partition dimension. """

private_hbm = ... r"""HBM - Only visible to each individual kernel instance in the SPMD grid"""

psum = ... r"""PSUM - Only visible to each individual kernel instance in the SPMD grid, alias of nki.compiler.psum.auto_alloc()"""

sbuf = ... r"""State Buffer - Only visible to each individual kernel instance in the SPMD grid, alias of nki.compiler.sbuf.auto_alloc()"""

shared_hbm = ... r"""Shared HBM - Visible to all kernel instances in the SPMD grid"""

spmd_dim = ... r""" Create a dimension in the SPMD launch grid of a NKI kernel with sub-dimension tiling.

A key use case for spmd_dim is to shard an existing NKI kernel over multiple NeuronCores without modifying the internal kernel implementation. Suppose we have a kernel, nki_spmd_kernel, which is launched with a 2D SPMD grid, (4, 2). We can shard the first dimension of the launch grid (size 4) over two physical NeuronCores by directly manipulating the launch grid as follows:

.. nki_example:: ../../test/test_nki_spmd_grid.py :language: python 📑 NKI_EXAMPLE_0

"""

tfloat32 = np.dtype('|V4') r"""32-bit floating-point number (1S,8E,10M)"""

[docs]class tile_size: r""" Tile size constants. """

@property def bn_stats_fmax(self): r"""Maximum free dimension of BN_STATS""" ...

@property def gemm_moving_fmax(self): r"""Maximum free dimension of the moving operand of General Matrix Multiplication on Tensor Engine.""" ...

@property def gemm_stationary_fmax(self): r"""Maximum free dimension of the stationary operand of General Matrix Multiplication on Tensor Engine.""" ...

@property def pmax(self): r"""Maximum partition dimension of a tile.""" ...

@property def psum_fmax(self): r"""Maximum free dimension of a tile on PSUM buffer.""" ...

@property def psum_min_align(self): r"""The minimum byte alignment requirement for PSUM free dimension address.""" ...

@property def sbuf_min_align(self): r"""The minimum byte alignment requirement for SBUF free dimension address.""" ...

uint16 = np.uint16 r"""Unsigned integer type, compatible with C unsigned short. Same as numpy.uint16."""

uint32 = np.uint32 r"""Unsigned integer type, compatible with C unsigned int. Same as numpy.uint32."""

uint8 = np.uint8 r"""Unsigned integer type, compatible with C unsigned char. Same as numpy.uint8."""

[docs]def abs(x, *, dtype=None, mask=None, **kwargs): r""" Absolute value of the input, element-wise.

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has absolute values of x. """ ...

[docs]def add(x, y, *, dtype=None, mask=None, **kwargs): r""" Add the inputs, element-wise.

((Similar to numpy.add <https://numpy.org/doc/stable/reference/generated/numpy.add.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has x + y, element-wise.

Examples:

.. nki_example:: ../../test/test_nki_nl_add.py :language: python 📑 NKI_EXAMPLE_20

.. note:: Broadcasting in the partition dimension is generally more expensive than broadcasting in free dimension. It is recommended to align your data to perform free dimension broadcast whenever possible.

""" ...

[docs]def affine_range(*args, **kwargs): r""" Create a sequence of numbers for use as parallel loop iterators in NKI. affine_range should be the default loop iterator choice, when there is no loop carried dependency. Note, associative reductions are not considered loop carried dependencies in this context. A concrete example of associative reduction is multiple :doc:nl.matmul <nki.language.matmul> or :doc:nisa.nc_matmul <nki.isa.nc_matmul> calls accumulating into the same output buffer defined outside of this loop level (see code example #2 below).

When the above conditions are not met, we recommend using :doc:sequential_range <nki.language.sequential_range> instead.

Notes:

.. code-block:: :linenos:

import neuronxcc.nki.language as nl

#######################################################################
# Example 1: No loop carried dependency
# Input/Output tensor shape: [128, 2048]
# Load one tile ([128, 512]) at a time, square the tensor element-wise,
# and store it into output tile
#######################################################################

# Every loop instance works on an independent input/output tile.
# No data dependency between loop instances.
for i_input in nl.affine_range(input.shape[1] // 512):
  offset = i_input * 512
  input_sb = nl.load(input[0:input.shape[0], offset:offset+512])
  result = nl.multiply(input_sb, input_sb)
  nl.store(output[0:input.shape[0], offset:offset+512], result)

#######################################################################
# Example 2: Matmul output buffer accumulation, a type of associative reduction
# Input tensor shapes for nl.matmul: xT[K=2048, M=128] and y[K=2048, N=128]
# Load one tile ([128, 128]) from both xT and y at a time, matmul and
# accumulate into the same output buffer
#######################################################################

result_psum = nl.zeros((128, 128), dtype=nl.float32, buffer=nl.psum)
for i_K in nl.affine_range(xT.shape[0] // 128):
  offset = i_K * 128
  xT_sbuf = nl.load(offset:offset+128, 0:xT.shape[1]])
  y_sbuf = nl.load(offset:offset+128, 0:y.shape[1]])

  result_psum += nl.matmul(xT_sbuf, y_sbuf, transpose_x=True)

""" ...

[docs]def all(x, axis, *, dtype=bool, mask=None, **kwargs): r""" Whether all elements along the specified axis (or axes) evaluate to True.

((Similar to numpy.all <https://numpy.org/doc/stable/reference/generated/numpy.all.html>_))

:param x: a tile. :param axis: int or tuple/list of ints. The axis (or axes) along which to operate; must be free dimensions, not partition dimension (0); can only be the last contiguous dim(s) of the tile: [1], [1,2], [1,2,3], [1,2,3,4] :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a boolean tile with the result. This return tile will have a shape of the input tile's shape with the specified axes removed. """ ...

[docs]def all_reduce(x, op, program_axes, *, dtype=None, mask=None, parallel_reduce=True, asynchronous=False, **kwargs): r""" Apply reduce operation over multiple SPMD programs.

:param x: a tile. :param op: numpy ALU operator to use to reduce over the input tile. :param program_axes: a single axis or a tuple of axes along which the reduction operation is performed. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :param parallel_reduce: optional boolean parameter whether to turn on parallel reduction. Enable parallel reduction consumes additional memory. :param asynchronous: Defaults to False. If True, caller should synchronize before reading final result, e.g. using nki.sync_thread. :return: the reduced resulting tile """ ...

[docs]def arange(*args): r""" Return contiguous values within a given interval, used for indexing a tensor to define a tile.

((Similar to numpy.arange <https://numpy.org/doc/stable/reference/generated/numpy.arange.html>_))

arange can be called as: - arange(stop): Values are generated within the half-open interval [0, stop) (the interval including zero, excluding stop). - arange(start, stop): Values are generated within the half-open interval [start, stop) (the interval including start, excluding stop). """ ...

[docs]def arctan(x, *, dtype=None, mask=None, **kwargs): r""" Inverse tangent of the input, element-wise.

((Similar to numpy.arctan <https://numpy.org/doc/stable/reference/generated/numpy.arctan.html>_))

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has inverse tangent values of x. """ ...

[docs]def atomic_rmw(dst, value, op, *, mask=None, **kwargs): r""" Perform an atomic read-modify-write operation on HBM data dst = op(dst, value)

:param dst: HBM tensor with subscripts, only supports indirect dynamic indexing currently. :param value: tile or scalar value that is the operand to op. :param op: atomic operation to perform, only supports np.add currently. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return:

.. nki_example:: ../../test/test_nki_nl_atomic_rmw.py
:language: python 📑 NKI_EXAMPLE_18

""" ...

[docs]def bitwise_and(x, y, *, dtype=None, mask=None, **kwargs): r""" Bitwise AND of the two inputs, element-wise.

((Similar to numpy.bitwise_and <https://numpy.org/doc/stable/reference/generated/numpy.bitwise_and.html>_))

Computes the bit-wise AND of the underlying binary representation of the integers in the input tiles. This function implements the C/Python operator &

:param x: a tile or a scalar value of integer type. :param y: a tile or a scalar value of integer type. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has values x & y. """ ...

[docs]def bitwise_or(x, y, *, dtype=None, mask=None, **kwargs): r""" Bitwise OR of the two inputs, element-wise.

((Similar to numpy.bitwise_or <https://numpy.org/doc/stable/reference/generated/numpy.bitwise_or.html>_))

Computes the bit-wise OR of the underlying binary representation of the integers in the input tiles. This function implements the C/Python operator |

:param x: a tile or a scalar value of integer type. :param y: a tile or a scalar value of integer type. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has values x | y. """ ...

[docs]def bitwise_xor(x, y, *, dtype=None, mask=None, **kwargs): r""" Bitwise XOR of the two inputs, element-wise.

((Similar to numpy.bitwise_xor <https://numpy.org/doc/stable/reference/generated/numpy.bitwise_xor.html>_))

Computes the bit-wise XOR of the underlying binary representation of the integers in the input tiles. This function implements the C/Python operator ^

:param x: a tile or a scalar value of integer type. :param y: a tile or a scalar value of integer type. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has values x ^ y. """ ...

[docs]def broadcast_to(src, *, shape, **kwargs): r""" Broadcast the src tile to a new shape based on numpy broadcast rules. The src may also be a tensor object which may be implicitly converted to a tile. A tensor can be implicitly converted to a tile if the partition dimension is the outermost dimension.

:param src: the source of broadcast, a tile in SBUF or PSUM. May also be a tensor object. :param shape: the target shape for broadcasting. :return: a new tile broadcast along the partition dimension of src, this new tile will be in SBUF, but can be also assigned to a PSUM tensor.

.. nki_example:: ../../test/test_nki_nl_broadcast.py :language: python 📑 NKI_EXAMPLE_5

""" ...

[docs]def ceil(x, *, dtype=None, mask=None, **kwargs): r""" Ceiling of the input, element-wise.

((Similar to numpy.ceil <https://numpy.org/doc/stable/reference/generated/numpy.ceil.html>_))

The ceil of the scalar x is the smallest integer i, such that i >= x.

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has ceiling values of x. """ ...

[docs]def copy(src, *, mask=None, dtype=None, **kwargs): r""" Create a copy of the src tile.

:param src: the source of copy, must be a tile in SBUF or PSUM. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :return: a new tile with the same layout as src, this new tile will be in SBUF, but can be also assigned to a PSUM tensor. """ ...

[docs]def cos(x, *, dtype=None, mask=None, **kwargs): r""" Cosine of the input, element-wise.

((Similar to numpy.cos <https://numpy.org/doc/stable/reference/generated/numpy.cos.html>_))

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has cosine values of x. """ ...

[docs]def device_print(prefix, x, *, mask=None, **kwargs): r""" Print a message with a String prefix followed by the value of a tile x. Printing is currently only supported in kernel simulation mode (see :doc:nki.simulate_kernel <nki.simulate_kernel> for a code example).

:param prefix: prefix of the print message :param x: data to print out :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: None """ ...

[docs]def divide(x, y, *, dtype=None, mask=None, **kwargs): r""" Divide the inputs, element-wise.

((Similar to numpy.divide <https://numpy.org/doc/stable/reference/generated/numpy.divide.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has x / y, element-wise. """ ...

[docs]def dropout(x, rate, *, dtype=None, mask=None, **kwargs): r""" Randomly zeroes some of the elements of the input tile given a probability rate.

:param x: a tile. :param rate: a scalar value or a tile with 1 element, with the probability rate. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile with randomly zeroed elements of x. """ ...

[docs]def ds(start, size): r""" Construct a dynamic slice for simple tensor indexing.

.. nki_example:: ../../test/test_nki_nl_dslice.py :language: python 📑 NKI_EXAMPLE_1

""" ...

[docs]def empty_like(a, dtype=None, *, buffer=None, name="", **kwargs): r""" Create a new tensor with the same shape and type as a given tensor.

((Similar to numpy.empty_like <https://numpy.org/doc/stable/reference/generated/numpy.empty_like.html>_))

:param a: the tensor. :param dtype: the data type of the tensor (see :ref:nki-dtype for more information). :param buffer: the specific buffer (ie, :doc:sbuf<nki.language.sbuf>, :doc:psum<nki.language.psum>, :doc:hbm<nki.language.hbm>), defaults to :doc:sbuf<nki.language.sbuf>. :param name: the name of the tensor. :return: a tensor with the same shape and type as a given tensor. """ ...

[docs]def equal(x, y, *, dtype=bool, mask=None, **kwargs): r""" Element-wise boolean result of x == y.

((Similar to numpy.equal <https://numpy.org/doc/stable/reference/generated/numpy.equal.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile with boolean result of x == y element-wise. """ ...

[docs]def erf(x, *, dtype=None, mask=None, **kwargs): r""" Error function of the input, element-wise.

((Similar to torch.erf <https://pytorch.org/docs/master/generated/torch.erf.html>_))

erf(x) = 2/sqrt(pi)*integral(exp(-t**2), t=0..x) .

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has erf of x. """ ...

[docs]def erf_dx(x, *, dtype=None, mask=None, **kwargs): r""" Derivative of the Error function (erf) on the input, element-wise.

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has erf_dx of x. """ ...

[docs]def exp(x, *, dtype=None, mask=None, **kwargs): r""" Exponential of the input, element-wise.

((Similar to numpy.exp <https://numpy.org/doc/stable/reference/generated/numpy.exp.html>_))

The exp(x) is e^x where e is the Euler's number = 2.718281...

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has exponential values of x. """ ...

[docs]def expand_dims(data, axis): r""" Expand the shape of a tile. Insert a new axis that will appear at the axis position in the expanded tile shape. Currently only supports expanding dimensions after the last index of the tile.

((Similar to numpy.expand_dims <https://numpy.org/doc/stable/reference/generated/numpy.expand_dims.html>_))

:param data: a tile input :param axis: int or tuple/list of ints. Position in the expanded axes where the new axis (or axes) is placed; must be free dimensions, not partition dimension (0); Currently only supports axis (or axes) after the last index. :return: a tile with view of input data with the number of dimensions increased. """ ...

def finfo(dtype): ...

[docs]def floor(x, *, dtype=None, mask=None, **kwargs): r""" Floor of the input, element-wise.

((Similar to numpy.floor <https://numpy.org/doc/stable/reference/generated/numpy.floor.html>_))

The floor of the scalar x is the largest integer i, such that i <= x.

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has floor values of x. """ ...

[docs]def fmod(x, y, dtype=None, mask=None, **kwargs): r""" Floor-mod of x / y, element-wise.

The remainder has the same sign as the dividend x. It is equivalent to the Matlab(TM) rem function and should not be confused with the Python modulus operator x % y.

((Similar to numpy.fmod <https://numpy.org/doc/stable/reference/generated/numpy.fmod.html>_))

:param x: a tile. If x is a scalar value it will be broadcast to the shape of y. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has values x fmod y. """ ...

[docs]def full(shape, fill_value, dtype, *, buffer=None, name="", **kwargs): r""" Create a new tensor of given shape and dtype on the specified buffer, filled with initial value.

((Similar to numpy.full <https://numpy.org/doc/stable/reference/generated/numpy.full.html>_))

:param shape: the shape of the tensor. :param fill_value: the initial value of the tensor. :param dtype: the data type of the tensor (see :ref:nki-dtype for more information). :param buffer: the specific buffer (ie, :doc:sbuf<nki.language.sbuf>, :doc:psum<nki.language.psum>, :doc:hbm<nki.language.hbm>), defaults to :doc:sbuf<nki.language.sbuf>. :param name: the name of the tensor. :return: a new tensor allocated on the buffer. """ ...

[docs]def gelu(x, *, dtype=None, mask=None, **kwargs): r""" Gaussian Error Linear Unit activation function on the input, element-wise.

((Similar to torch.nn.functional.gelu <https://pytorch.org/docs/stable/generated/torch.nn.functional.gelu.html>_))

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has gelu of x. """ ...

[docs]def gelu_apprx_tanh(x, *, dtype=None, mask=None, **kwargs): r""" Gaussian Error Linear Unit activation function on the input, element-wise, with tanh approximation.

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has gelu of x. """ ...

[docs]def gelu_dx(x, *, dtype=None, mask=None, **kwargs): r""" Derivative of Gaussian Error Linear Unit (gelu) on the input, element-wise.

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has gelu_dx of x. """ ...

def gemm_grid(): r""" Tile definition for result of Matrix Multiplication on Tensor Engine, it is identical to the Tile definition of moving operand of Matrix Multiplication on Tensor Engine as well.""" ...

def gemm_stationary_grid(): r""" Tile definition for stationary operand of Matrix Multiplication on Tensor Engine.""" ...

[docs]def greater(x, y, *, dtype=bool, mask=None, **kwargs): r""" Element-wise boolean result of x > y.

((Similar to numpy.greater <https://numpy.org/doc/stable/reference/generated/numpy.greater.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile with boolean result of x > y element-wise. """ ...

[docs]def greater_equal(x, y, *, dtype=bool, mask=None, **kwargs): r""" Element-wise boolean result of x >= y.

((Similar to numpy.greater_equal <https://numpy.org/doc/stable/reference/generated/numpy.greater_equal.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile with boolean result of x >= y element-wise. """ ...

[docs]def invert(x, *, dtype=None, mask=None, **kwargs): r""" Bitwise NOT of the input, element-wise.

((Similar to numpy.invert <https://numpy.org/doc/stable/reference/generated/numpy.invert.html>_))

Computes the bit-wise NOT of the underlying binary representation of the integers in the input tile. This ufunc implements the C/Python operator ~

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile with bitwise NOT x element-wise. """ ...

[docs]def left_shift(x, y, *, dtype=None, mask=None, **kwargs): r""" Bitwise left-shift x by y, element-wise.

((Similar to numpy.left_shift <https://numpy.org/doc/stable/reference/generated/numpy.left_shift.html>_))

Computes the bit-wise left shift of the underlying binary representation of the integers in the input tiles. This function implements the C/Python operator <<

:param x: a tile or a scalar value of integer type. :param y: a tile or a scalar value of integer type. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has values x << y. """ ...

[docs]def less(x, y, *, dtype=bool, mask=None, **kwargs): r""" Element-wise boolean result of x < y.

((Similar to numpy.less <https://numpy.org/doc/stable/reference/generated/numpy.less.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile with boolean result of x < y element-wise. """ ...

[docs]def less_equal(x, y, *, dtype=bool, mask=None, **kwargs): r""" Element-wise boolean result of x <= y.

((Similar to numpy.less_equal <https://numpy.org/doc/stable/reference/generated/numpy.less_equal.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile with boolean result of x <= y element-wise. """ ...

[docs]def load(src, *, mask=None, dtype=None, **kwargs): r""" Load a tensor from device memory (HBM) into on-chip memory (SBUF).

See :ref:nki-pm-memory for detailed information.

:param src: HBM tensor to load the data from. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :return: a new tile on SBUF with values from src.

.. nki_example:: ../../test/test_nki_nl_load_store.py :language: python 📑 NKI_EXAMPLE_10

.. note:: Partition dimension size can't exceed the hardware limitation of nki.language.tile_size.pmax, see :ref:nki-tile-size.

Partition dimension has to be the first dimension in the index tuple of a tile. Therefore, data may need to be split into multiple batches to load/store, for example:

.. nki_example:: ../../test/test_nki_nl_load_store.py :language: python 📑 NKI_EXAMPLE_11

Also supports indirect DMA access with dynamic index values:

.. nki_example:: ../../test/test_nki_nl_load_store_indirect.py :language: python 📑 NKI_EXAMPLE_12

.. nki_example:: ../../test/test_nki_nl_load_store_indirect.py :language: python 📑 NKI_EXAMPLE_13 """ ...

[docs]def load_transpose2d(src, *, mask=None, dtype=None, **kwargs): r""" Load a tensor from device memory (HBM) and 2D-transpose the data before storing into on-chip memory (SBUF).

:param src: HBM tensor to load the data from. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :return: a new tile on SBUF with values from src 2D-transposed.

.. nki_example:: ../../test/test_nki_nl_load_transpose2d.py :language: python 📑 NKI_EXAMPLE_19

.. note:: Partition dimension size can't exceed the hardware limitation of nki.language.tile_size.pmax, see :ref:nki-tile-size.

""" ...

[docs]def log(x, *, dtype=None, mask=None, **kwargs): r""" Natural logarithm of the input, element-wise.

((Similar to numpy.log <https://numpy.org/doc/stable/reference/generated/numpy.log.html>_))

It is the inverse of the exponential function, such that: log(exp(x)) = x . The natural logarithm base is e.

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has natural logarithm values of x. """ ...

[docs]def logical_and(x, y, *, dtype=bool, mask=None, **kwargs): r""" Element-wise boolean result of x AND y.

((Similar to numpy.logical_and <https://numpy.org/doc/stable/reference/generated/numpy.logical_and.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile with boolean result of x AND y element-wise. """ ...

[docs]def logical_not(x, *, dtype=bool, mask=None, **kwargs): r""" Element-wise boolean result of NOT x.

((Similar to numpy.logical_not <https://numpy.org/doc/stable/reference/generated/numpy.logical_not.html>_))

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile with boolean result of NOT x element-wise. """ ...

[docs]def logical_or(x, y, *, dtype=bool, mask=None, **kwargs): r""" Element-wise boolean result of x OR y.

((Similar to numpy.logical_or <https://numpy.org/doc/stable/reference/generated/numpy.logical_or.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile with boolean result of x OR y element-wise. """ ...

[docs]def logical_xor(x, y, *, dtype=bool, mask=None, **kwargs): r""" Element-wise boolean result of x XOR y.

((Similar to numpy.logical_xor <https://numpy.org/doc/stable/reference/generated/numpy.logical_xor.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile with boolean result of x XOR y element-wise. """ ...

[docs]def loop_reduce(x, op, loop_indices, *, dtype=None, mask=None, **kwargs): r""" Apply reduce operation over a loop. This is an ideal instruction to compute a high performance reduce_max or reduce_min.

Note: The destination tile is also the rhs input to op. For example,

.. code-block:: python

b = nl.zeros((N_TILE_SIZE, M_TILE_SIZE), dtype=float32, buffer=nl.sbuf)
for k_i in affine_range(NUM_K_BLOCKS):
  
  # Skipping over multiple nested loops here. 
  # a, is a psum tile from a matmul accumulation group.
  b = nl.loop_reduce(a, op=np.add, loop_indices=[k_i], dtype=nl.float32)
 

is the same as:

.. code-block:: python

b = nl.zeros((N_TILE_SIZE, M_TILE_SIZE), dtype=nl.float32, buffer=nl.sbuf)
for k_i in affine_range(NUM_K_BLOCKS):

  # Skipping over multiple nested loops here. 
  # a, is a psum tile from a matmul accumulation group.
  b = nisa.tensor_tensor(data1=b, data2=a, op=np.add, dtype=nl.float32)

If you are trying to use this instruction only for accumulating results on SBUF, consider simply using the += operator instead.

The loop_indices list enables the compiler to recognize which loops this reduction can be optimized across as part of any aggressive loop-level optimizations it may perform.

:param x: a tile. :param op: numpy ALU operator to use to reduce over the input tile. :param loop_indices: a single loop index or a tuple of loop indices along which the reduction operation is performed. Can be numbers or loop_index objects coming from nl.affine_range. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: the reduced resulting tile """ ...

[docs]def matmul(x, y, *, transpose_x=False, mask=None, **kwargs): r""" x @ y matrix multiplication of x and y.

((Similar to numpy.matmul <https://numpy.org/doc/stable/reference/generated/numpy.matmul.html>_))

.. note:: For optimal performance on hardware, use :func:nki.isa.nc_matmul or call nki.language.matmul with transpose_x=True. Use nki.isa.nc_matmul also to access low-level features of the Tensor Engine.

.. note:: Implementation details: nki.language.matmul calls nki.isa.nc_matmul under the hood. nc_matmul is neuron specific customized implementation of matmul that computes x.T @ y, as a result, matmul(x, y) lowers to nc_matmul(transpose(x), y). To avoid this extra transpose instruction being inserted, use x.T and transpose_x=True inputs to this matmul.

:param x: a tile on SBUF (partition dimension <= 128, free dimension <= 128), x's free dimension must match y's partition dimension. :param y: a tile on SBUF (partition dimension <= 128, free dimension <= 512) :param transpose_x: Defaults to False. If True, x is treated as already transposed. If False, an additional transpose will be inserted to make x's partition dimension the contract dimension of the matmul to align with the Tensor Engine. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details)

:return: x @ y or x.T @ y if transpose_x=True """ ...

[docs]def max(x, axis, *, dtype=None, mask=None, keepdims=False, **kwargs): r""" Maximum of elements along the specified axis (or axes) of the input.

((Similar to numpy.max <https://numpy.org/doc/stable/reference/generated/numpy.max.html>_))

:param x: a tile. :param axis: int or tuple/list of ints. The axis (or axes) along which to operate; must be free dimensions, not partition dimension (0); can only be the last contiguous dim(s) of the tile: [1], [1,2], [1,2,3], [1,2,3,4] :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :param keepdims: If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array. :return: a tile with the maximum of elements along the provided axis. This return tile will have a shape of the input tile's shape with the specified axes removed. """ ...

[docs]def maximum(x, y, *, dtype=None, mask=None, **kwargs): r""" Maximum of the inputs, element-wise.

((Similar to numpy.maximum <https://numpy.org/doc/stable/reference/generated/numpy.maximum.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has the maximum of each elements from x and y. """ ...

[docs]def mean(x, axis, *, dtype=None, mask=None, keepdims=False, **kwargs): r""" Arithmetic mean along the specified axis (or axes) of the input.

((Similar to numpy.mean <https://numpy.org/doc/stable/reference/generated/numpy.mean.html>_))

:param x: a tile. :param axis: int or tuple/list of ints. The axis (or axes) along which to operate; must be free dimensions, not partition dimension (0); can only be the last contiguous dim(s) of the tile: [1], [1,2], [1,2,3], [1,2,3,4] :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile with the average of elements along the provided axis. This return tile will have a shape of the input tile's shape with the specified axes removed. float32 intermediate and return values are used for integer inputs. """ ...

[docs]def min(x, axis, *, dtype=None, mask=None, keepdims=False, **kwargs): r""" Minimum of elements along the specified axis (or axes) of the input.

((Similar to numpy.min <https://numpy.org/doc/stable/reference/generated/numpy.min.html>_))

:param x: a tile. :param axis: int or tuple/list of ints. The axis (or axes) along which to operate; must be free dimensions, not partition dimension (0); can only be the last contiguous dim(s) of the tile: [1], [1,2], [1,2,3], [1,2,3,4] :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :param keepdims: If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array. :return: a tile with the minimum of elements along the provided axis. This return tile will have a shape of the input tile's shape with the specified axes removed. """ ...

[docs]def minimum(x, y, *, dtype=None, mask=None, **kwargs): r""" Minimum of the inputs, element-wise.

((Similar to numpy.minimum <https://numpy.org/doc/stable/reference/generated/numpy.minimum.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has the minimum of each elements from x and y. """ ...

[docs]def mish(x, *, dtype=None, mask=None, **kwargs): r""" Mish activation function on the input, element-wise.

Mish: A Self Regularized Non-Monotonic Neural Activation Function is defined as:

.. math:: mish(x) = x * tanh(softplus(x))

see: https://arxiv.org/abs/1908.08681

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has mish of x. """ ...

[docs]def mod(x, y, dtype=None, mask=None, **kwargs): r""" Integer Mod of x / y, element-wise

Computes the remainder complementary to the floor_divide function. It is equivalent to the Python modulus x % y and has the same sign as the divisor y.

((Similar to numpy.mod <https://numpy.org/doc/stable/reference/generated/numpy.mod.html>_))

:param x: a tile. If x is a scalar value it will be broadcast to the shape of y. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has values x mod y. """ ...

[docs]def multiply(x, y, *, dtype=None, mask=None, **kwargs): r""" Multiply the inputs, element-wise.

((Similar to numpy.multiply <https://numpy.org/doc/stable/reference/generated/numpy.multiply.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has x * y, element-wise. """ ...

[docs]def ndarray(shape, dtype, *, buffer=None, name="", **kwargs): r""" Create a new tensor of given shape and dtype on the specified buffer.

((Similar to numpy.ndarray <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html>_))

:param shape: the shape of the tensor. :param dtype: the data type of the tensor (see :ref:nki-dtype for more information). :param buffer: the specific buffer (ie, :doc:sbuf<nki.language.sbuf>, :doc:psum<nki.language.psum>, :doc:hbm<nki.language.hbm>), defaults to :doc:sbuf<nki.language.sbuf>. :param name: the name of the tensor. :return: a new tensor allocated on the buffer. """ ...

[docs]def negative(x, *, dtype=None, mask=None, **kwargs): r""" Numerical negative of the input, element-wise.

((Similar to numpy.negative <https://numpy.org/doc/stable/reference/generated/numpy.negative.html>_))

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has numerical negative values of x. """ ...

[docs]def not_equal(x, y, *, dtype=bool, mask=None, **kwargs): r""" Element-wise boolean result of x != y.

((Similar to numpy.not_equal <https://numpy.org/doc/stable/reference/generated/numpy.not_equal.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile with boolean result of x != y element-wise. """ ...

[docs]def num_programs(axes=None): r""" Number of SPMD programs along the given axes in the launch grid. If axes is not provided, returns the total number of programs.

:param axes: The axes of the ND launch grid. If not provided, returns the total number of programs along the entire launch grid. :return: The number of SPMD(single process multiple data) programs along axes in the launch grid """ ...

[docs]def ones(shape, dtype, *, buffer=None, name="", **kwargs): r""" Create a new tensor of given shape and dtype on the specified buffer, filled with ones.

((Similar to numpy.ones <https://numpy.org/doc/stable/reference/generated/numpy.ones.html>_))

:param shape: the shape of the tensor. :param dtype: the data type of the tensor (see :ref:nki-dtype for more information). :param buffer: the specific buffer (ie, :doc:sbuf<nki.language.sbuf>, :doc:psum<nki.language.psum>, :doc:hbm<nki.language.hbm>), defaults to :doc:sbuf<nki.language.sbuf>. :param name: the name of the tensor. :return: a new tensor allocated on the buffer. """ ...

[docs]def power(x, y, *, dtype=None, mask=None, **kwargs): r""" Elements of x raised to powers of y, element-wise.

((Similar to numpy.power <https://numpy.org/doc/stable/reference/generated/numpy.power.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has values x to the power of y. """ ...

[docs]def prod(x, axis, *, dtype=None, mask=None, keepdims=False, **kwargs): r""" Product of elements along the specified axis (or axes) of the input.

((Similar to numpy.prod <https://numpy.org/doc/stable/reference/generated/numpy.prod.html>_))

:param x: a tile. :param axis: int or tuple/list of ints. The axis (or axes) along which to operate; must be free dimensions, not partition dimension (0); can only be the last contiguous dim(s) of the tile: [1], [1,2], [1,2,3], [1,2,3,4] :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :param keepdims: If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array. :return: a tile with the product of elements along the provided axis. This return tile will have a shape of the input tile's shape with the specified axes removed. """ ...

[docs]def program_id(axis): r""" Index of the current SPMD program along the given axis in the launch grid.

:param axis: The axis of the ND launch grid. :return: The program id along axis in the launch grid """ ...

[docs]def program_ndim(): r""" Number of dimensions in the SPMD launch grid.

:return: The number of dimensions in the launch grid, i.e. the number of axes """ ...

[docs]def rand(shape, dtype=np.float32, **kwargs): r""" Generate a tile of given shape and dtype, filled with random values that are sampled from a uniform distribution between 0 and 1.

:param shape: the shape of the tile. :param dtype: the data type of the tile (see :ref:nki-dtype for more information). :return: a tile with random values. """ ...

[docs]def random_seed(seed, *, mask=None, **kwargs): r""" Sets a seed, specified by user, to the random number generator on HW. Using the same seed will generate the same sequence of random numbers when using together with the random() API

:param seed: a scalar value to use as the seed. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: none """ ...

[docs]def reciprocal(x, *, dtype=None, mask=None, **kwargs): r""" Reciprocal of the the input, element-wise.

((Similar to numpy.reciprocal <https://numpy.org/doc/stable/reference/generated/numpy.reciprocal.html>_))

reciprocal(x) = 1 / x

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has reciprocal values of x. """ ...

[docs]def relu(x, *, dtype=None, mask=None, **kwargs): r""" Rectified Linear Unit activation function on the input, element-wise.

relu(x) = (x)+ = max(0,x)

((Similar to torch.nn.functional.relu <https://pytorch.org/docs/stable/generated/torch.nn.functional.relu.html>_))

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has relu of x. """ ...

[docs]def right_shift(x, y, *, dtype=None, mask=None, **kwargs): r""" Bitwise right-shift x by y, element-wise.

((Similar to numpy.right_shift <https://numpy.org/doc/stable/reference/generated/numpy.right_shift.html>_))

Computes the bit-wise right shift of the underlying binary representation of the integers in the input tiles. This function implements the C/Python operator >>

:param x: a tile or a scalar value of integer type. :param y: a tile or a scalar value of integer type. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has values x >> y. """ ...

[docs]def rms_norm(x, w, axis, n, epsilon=1e-06, *, dtype=None, compute_dtype=None, mask=None, **kwargs): r""" Apply Root Mean Square Layer Normalization.

:param x: input tile :param w: weight tile :param axis: axis along which to compute the root mean square (rms) value :param n: total number of values to calculate rms :param epsilon: epsilon value used by rms calculation to avoid divide-by-zero :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param compute_dtype: (optional) dtype for the internal computation - currently dtype and compute_dtype behave the same, both sets internal compute and return dtype. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: x / RMS(x) * w """ ...

[docs]def rsqrt(x, *, dtype=None, mask=None, **kwargs): r""" Reciprocal of the square-root of the input, element-wise.

((Similar to torch.rsqrt <https://pytorch.org/docs/master/generated/torch.rsqrt.html>_))

rsqrt(x) = 1 / sqrt(x)

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has reciprocal square-root values of x. """ ...

[docs]def sequential_range(*args, **kwargs): r""" Create a sequence of numbers for use as sequential loop iterators in NKI. sequential_range should be used when there is a loop carried dependency. Note, associative reductions are not considered loop carried dependencies in this context. See :doc:affine_range <nki.language.affine_range> for an example of such associative reduction.

Notes:

.. code-block:: :linenos:

import neuronxcc.nki.language as nl

#######################################################################
# Example 1: Loop carried dependency from tiling tensor_tensor_scan
# Both sbuf tensor input0 and input1 shapes: [128, 2048]
# Perform a scan operation between the two inputs using a tile size of [128, 512]
# Store the scan output to another [128, 2048] tensor
#######################################################################

# Loop iterations communicate through this init tensor
init = nl.zeros((128, 1), dtype=input0.dtype)

# This loop will only produce correct results if the iterations are performed in order
for i_input in nl.sequential_range(input0.shape[1] // 512):
  offset = i_input * 512

  # Depends on scan result from the previous loop iteration
  result = nisa.tensor_tensor_scan(input0[:, offset:offset+512],
                                   input1[:, offset:offset+512],
                                   initial=init,
                                   op0=nl.multiply, op1=nl.add)

  nl.store(output[0:input0.shape[0], offset:offset+512], result)

  # Prepare initial result for scan in the next loop iteration
  init[:, :] = result[:, 511]

""" ...

[docs]def shared_constant(constant, dtype=None, **kwargs): r""" Create a new tensor filled with the data specified by data array.

:param constant: the constant data to be filled into a tensor :return: a tensor which contains the constant data """ ...

[docs]def shared_identity_matrix(n, dtype=np.uint8, **kwargs): r""" Create a new identity tensor with specified data type.

This function has the same behavior to :doc:nki.language.shared_constant <nki.language.shared_constant> but is preferred if the constant matrix is an identity matrix. The compiler will reuse all the identity matrices of the same dtype in the graph to save space.

:param n: the number of rows(and columns) of the returned identity matrix :param dtype: the data type of the tensor, default to be np.uint8 (see :ref:nki-dtype for more information). :return: a tensor which contains the identity tensor """ ...

[docs]def sigmoid(x, *, dtype=None, mask=None, **kwargs): r""" Logistic sigmoid activation function on the input, element-wise.

((Similar to torch.nn.functional.sigmoid <https://pytorch.org/docs/stable/generated/torch.nn.functional.sigmoid.html>_))

sigmoid(x) = 1/(1+exp(-x))

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has sigmoid of x. """ ...

[docs]def sign(x, *, dtype=None, mask=None, **kwargs): r""" Sign of the numbers of the input, element-wise.

((Similar to numpy.sign <https://numpy.org/doc/stable/reference/generated/numpy.sign.html>_))

The sign function returns -1 if x < 0, 0 if x==0, 1 if x > 0.

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has sign values of x. """ ...

[docs]def silu(x, *, dtype=None, mask=None, **kwargs): r""" Sigmoid Linear Unit activation function on the input, element-wise.

((Similar to torch.nn.functional.silu <https://pytorch.org/docs/stable/generated/torch.nn.functional.silu.html>_))

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has silu of x. """ ...

[docs]def silu_dx(x, *, dtype=None, mask=None, **kwargs): r""" Derivative of Sigmoid Linear Unit activation function on the input, element-wise.

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has silu_dx of x. """ ...

[docs]def sin(x, *, dtype=None, mask=None, **kwargs): r""" Sine of the input, element-wise.

((Similar to numpy.sin <https://numpy.org/doc/stable/reference/generated/numpy.sin.html>_))

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has sine values of x. """ ...

[docs]def softmax(x, axis, *, dtype=None, compute_dtype=None, mask=None, **kwargs): r""" Softmax activation function on the input, element-wise.

((Similar to torch.nn.functional.softmax <https://pytorch.org/docs/stable/generated/torch.nn.functional.softmax.html>_))

:param x: a tile. :param axis: int or tuple/list of ints. The axis (or axes) along which to operate; must be free dimensions, not partition dimension (0); can only be the last contiguous dim(s) of the tile: [1], [1,2], [1,2,3], [1,2,3,4] :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param compute_dtype: (optional) dtype for the internal computation - currently dtype and compute_dtype behave the same, both sets internal compute and return dtype. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has softmax of x. """ ...

[docs]def softplus(x, *, dtype=None, mask=None, **kwargs): r""" Softplus activation function on the input, element-wise.

Softplus is a smooth approximation to the ReLU activation, defined as:

softplus(x) = log(1 + exp(x))

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has softplus of x. """ ...

[docs]def sqrt(x, *, dtype=None, mask=None, **kwargs): r""" Non-negative square-root of the input, element-wise.

((Similar to numpy.sqrt <https://numpy.org/doc/stable/reference/generated/numpy.sqrt.html>_))

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has square-root values of x. """ ...

[docs]def square(x, *, dtype=None, mask=None, **kwargs): r""" Square of the input, element-wise.

((Similar to numpy.square <https://numpy.org/doc/stable/reference/generated/numpy.square.html>_))

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has square of x. """ ...

def static_cast(arr, dtype): ...

[docs]def static_range(*args): r""" Create a sequence of numbers for use as loop iterators in NKI, resulting in a fully unrolled loop. Unlike :doc:affine_range <nki.language.affine_range> or :doc:sequential_range <nki.language.sequential_range>, Neuron compiler will fully unroll the loop during NKI kernel tracing.

Notes:

""" ...

[docs]def store(dst, value, *, mask=None, **kwargs): r""" Store into a tensor on device memory (HBM) from on-chip memory (SBUF).

See :ref:nki-pm-memory for detailed information.

:param dst: HBM tensor to store the data into. :param value: An SBUF tile that contains the values to store. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return:

.. nki_example:: ../../test/test_nki_nl_load_store.py :language: python 📑 NKI_EXAMPLE_14

.. note:: Partition dimension size can't exceed the hardware limitation of nki.language.tile_size.pmax, see :ref:nki-tile-size.

Partition dimension has to be the first dimension in the index tuple of a tile. Therefore, data may need to be split into multiple batches to load/store, for example:

.. nki_example:: ../../test/test_nki_nl_load_store.py :language: python 📑 NKI_EXAMPLE_15

Also supports indirect DMA access with dynamic index values:

.. nki_example:: ../../test/test_nki_nl_load_store_indirect.py :language: python 📑 NKI_EXAMPLE_16

.. nki_example:: ../../test/test_nki_nl_load_store_indirect.py :language: python 📑 NKI_EXAMPLE_17

""" ...

[docs]def subtract(x, y, *, dtype=None, mask=None, **kwargs): r""" Subtract the inputs, element-wise.

((Similar to numpy.subtract <https://numpy.org/doc/stable/reference/generated/numpy.subtract.html>_))

:param x: a tile or a scalar value. :param y: a tile or a scalar value. x.shape and y.shape must be broadcastable <https://numpy.org/doc/stable/user/basics.broadcasting.html>__ to a common shape, that will become the shape of the output. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has x - y, element-wise. """ ...

[docs]def sum(x, axis, *, dtype=None, mask=None, keepdims=False, **kwargs): r""" Sum of elements along the specified axis (or axes) of the input.

((Similar to numpy.sum <https://numpy.org/doc/stable/reference/generated/numpy.sum.html>_))

:param x: a tile. :param axis: int or tuple/list of ints. The axis (or axes) along which to operate; must be free dimensions, not partition dimension (0); can only be the last contiguous dim(s) of the tile: [1], [1,2], [1,2,3], [1,2,3,4] :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :param keepdims: If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array. :return: a tile with the sum of elements along the provided axis. This return tile will have a shape of the input tile's shape with the specified axes removed. """ ...

[docs]def tan(x, *, dtype=None, mask=None, **kwargs): r""" Tangent of the input, element-wise.

((Similar to numpy.tan <https://numpy.org/doc/stable/reference/generated/numpy.tan.html>_))

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has tangent values of x. """ ...

[docs]def tanh(x, *, dtype=None, mask=None, **kwargs): r""" Hyperbolic tangent of the input, element-wise.

((Similar to numpy.tanh <https://numpy.org/doc/stable/reference/generated/numpy.tanh.html>_))

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has hyperbolic tangent values of x. """ ...

[docs]def transpose(x, *, dtype=None, mask=None, **kwargs): r""" Transposes a 2D tile between its partition and free dimension.

:param x: 2D input tile :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has the values of the input tile with its partition and free dimensions swapped. """ ...

[docs]def trunc(x, *, dtype=None, mask=None, **kwargs): r""" Truncated value of the input, element-wise.

((Similar to numpy.trunc <https://numpy.org/doc/stable/reference/generated/numpy.trunc.html>_))

The truncated value of the scalar x is the nearest integer i which is closer to zero than x is. In short, the fractional part of the signed number x is discarded.

:param x: a tile. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile that has truncated values of x. """ ...

[docs]def var(x, axis, *, dtype=None, mask=None, **kwargs): r""" Variance along the specified axis (or axes) of the input.

((Similar to numpy.var <https://numpy.org/doc/stable/reference/generated/numpy.var.html>_))

:param x: a tile. :param axis: int or tuple/list of ints. The axis (or axes) along which to operate; must be free dimensions, not partition dimension (0); can only be the last contiguous dim(s) of the tile: [1], [1,2], [1,2,3], [1,2,3,4] :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tile. :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile with the variance of the elements along the provided axis. This return tile will have a shape of the input tile's shape with the specified axes removed. """ ...

[docs]def where(condition, x, y, *, dtype=None, mask=None, **kwargs): r""" Return elements chosen from x or y depending on condition.

((Similar to numpy.where <https://numpy.org/doc/stable/reference/generated/numpy.where.html>_))

:param condition: if True, yield x, otherwise yield y. :param x: a tile with values from which to choose if condition is True. :param y: a tile or a numerical value from which to choose if condition is False. :param dtype: (optional) data type to cast the output type to (see :ref:nki-dtype for more information); if not specified, it will default to be the same as the data type of the input tiles, or whichever input type has the highest precision (see :ref:nki-type-promotion for more information); :param mask: (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see :ref:nki-mask for details) :return: a tile with elements from x where condition is True, and elements from y otherwise. """ ...

[docs]def zeros(shape, dtype, *, buffer=None, name="", **kwargs): r""" Create a new tensor of given shape and dtype on the specified buffer, filled with zeros.

((Similar to numpy.zeros <https://numpy.org/doc/stable/reference/generated/numpy.zeros.html>_))

:param shape: the shape of the tensor. :param dtype: the data type of the tensor (see :ref:nki-dtype for more information). :param buffer: the specific buffer (ie, :doc:sbuf<nki.language.sbuf>, :doc:psum<nki.language.psum>, :doc:hbm<nki.language.hbm>), defaults to :doc:sbuf<nki.language.sbuf>. :param name: the name of the tensor. :return: a new tensor allocated on the buffer. """ ...

[docs]def zeros_like(a, dtype=None, *, buffer=None, name="", **kwargs): r""" Create a new tensor of zeros with the same shape and type as a given tensor.

((Similar to numpy.zeros_like <https://numpy.org/doc/stable/reference/generated/numpy.zeros_like.html>_))

:param a: the tensor. :param dtype: the data type of the tensor (see :ref:nki-dtype for more information). :param buffer: the specific buffer (ie, :doc:sbuf<nki.language.sbuf>, :doc:psum<nki.language.psum>, :doc:hbm<nki.language.hbm>), defaults to :doc:sbuf<nki.language.sbuf>. :param name: the name of the tensor. :return: a tensor of zeros with the same shape and type as a given tensor. """ ...