nki.language - AWS Neuron Documentation

This document is relevant for: Inf2, Trn1, Trn2

nki.language

Memory operations

load Load a tensor from device memory (HBM) into on-chip memory (SBUF).
store Store into a tensor on device memory (HBM) from on-chip memory (SBUF).
load_transpose2d Load a tensor from device memory (HBM) and 2D-transpose the data before storing into on-chip memory (SBUF).
atomic_rmw Perform an atomic read-modify-write operation on HBM data: dst = op(dst, value).
copy Create a copy of the src tile.
broadcast_to Broadcast the src tile to a new shape following NumPy broadcasting rules.
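broadcast_to follows NumPy broadcasting rules. As a plain-Python sketch of those rules (not NKI code), the hypothetical helpers below check whether a source shape can be broadcast to a target shape and then expand a 2D nested list to that shape. NKI may place further restrictions on which tile dimensions can be broadcast, so treat this as an illustration of the general rule only.

```python
def broadcast_shape(src_shape, target_shape):
    """Return target_shape if src_shape broadcasts to it under NumPy rules.

    Hypothetical helper for illustration; raises ValueError otherwise.
    """
    src_shape, target_shape = tuple(src_shape), tuple(target_shape)
    if len(src_shape) > len(target_shape):
        raise ValueError(f"cannot broadcast {src_shape} to {target_shape}")
    # Align dimensions from the right by prepending size-1 dimensions.
    padded = (1,) * (len(target_shape) - len(src_shape)) + src_shape
    for s, t in zip(padded, target_shape):
        # Each source dimension must equal the target dimension or be 1.
        if s != t and s != 1:
            raise ValueError(f"cannot broadcast {src_shape} to {target_shape}")
    return target_shape


def broadcast_to_2d(tile, shape):
    """Expand a 2D nested list to `shape` (hypothetical helper)."""
    n_rows, n_cols = len(tile), len(tile[0])
    rows, cols = broadcast_shape((n_rows, n_cols), shape)
    # A size-1 dimension is read at index 0 for every output position.
    return [
        [tile[r if n_rows > 1 else 0][c if n_cols > 1 else 0] for c in range(cols)]
        for r in range(rows)
    ]


print(broadcast_to_2d([[1, 2, 3]], (2, 3)))  # [[1, 2, 3], [1, 2, 3]]
print(broadcast_to_2d([[5], [6]], (2, 3)))   # [[5, 5, 5], [6, 6, 6]]
```

A shape such as (2,) cannot be broadcast to (2, 3), because after right-alignment its only dimension (2) is compared against 3.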

Creation operations

ndarray Create a new tensor of given shape and dtype on the specified buffer.
empty_like Create a new tensor with the same shape and type as a given tensor.
zeros Create a new tensor of given shape and dtype on the specified buffer, filled with zeros.
zeros_like Create a new tensor of zeros with the same shape and type as a given tensor.
ones Create a new tensor of given shape and dtype on the specified buffer, filled with ones.
full Create a new tensor of given shape and dtype on the specified buffer, filled with the given initial value.
rand Generate a tile of given shape and dtype, filled with random values that are sampled from a uniform distribution between 0 and 1.
random_seed Set the seed of the hardware random number generator to a user-specified value.
shared_constant Create a new tensor filled with the data specified by the data array.
shared_identity_matrix Create a new identity tensor with the specified data type.

Math operations

add Add the inputs, element-wise.
subtract Subtract the inputs, element-wise.
multiply Multiply the inputs, element-wise.
divide Divide the inputs, element-wise.
power Elements of x raised to powers of y, element-wise.
maximum Maximum of the inputs, element-wise.
minimum Minimum of the inputs, element-wise.
max Maximum of elements along the specified axis (or axes) of the input.
min Minimum of elements along the specified axis (or axes) of the input.
mean Arithmetic mean along the specified axis (or axes) of the input.
var Variance along the specified axis (or axes) of the input.
sum Sum of elements along the specified axis (or axes) of the input.
prod Product of elements along the specified axis (or axes) of the input.
all Whether all elements along the specified axis (or axes) evaluate to True.
abs Absolute value of the input, element-wise.
negative Numerical negative of the input, element-wise.
sign Sign of the input, element-wise.
trunc Truncated value of the input, element-wise.
floor Floor of the input, element-wise.
ceil Ceiling of the input, element-wise.
mod Integer mod of x / y, element-wise.
fmod Floor-mod of x / y, element-wise.
exp Exponential of the input, element-wise.
log Natural logarithm of the input, element-wise.
cos Cosine of the input, element-wise.
sin Sine of the input, element-wise.
tan Tangent of the input, element-wise.
tanh Hyperbolic tangent of the input, element-wise.
arctan Inverse tangent of the input, element-wise.
sqrt Non-negative square-root of the input, element-wise.
rsqrt Reciprocal of the square-root of the input, element-wise.
sigmoid Logistic sigmoid activation function on the input, element-wise.
relu Rectified Linear Unit activation function on the input, element-wise.
gelu Gaussian Error Linear Unit activation function on the input, element-wise.
gelu_dx Derivative of Gaussian Error Linear Unit (gelu) on the input, element-wise.
gelu_apprx_tanh Gaussian Error Linear Unit activation function on the input, element-wise, with tanh approximation.
silu Sigmoid Linear Unit activation function on the input, element-wise.
silu_dx Derivative of Sigmoid Linear Unit activation function on the input, element-wise.
erf Error function of the input, element-wise.
erf_dx Derivative of the Error function (erf) on the input, element-wise.
softplus Softplus activation function on the input, element-wise.
mish Mish activation function on the input, element-wise.
square Square of the input, element-wise.
softmax Softmax activation function on the input, computed along the specified axis.
rms_norm Apply Root Mean Square Layer Normalization.
dropout Randomly zeroes some of the elements of the input tile given a probability rate.
matmul Matrix multiplication of x and y (x @ y).
transpose Transpose a 2D tile, swapping its partition and free dimensions.
reciprocal Reciprocal of the input, element-wise.
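Several of the activation functions above have standard closed-form definitions. The pure-Python sketch below writes out the textbook formulas for gelu, gelu_apprx_tanh, sigmoid, silu, softplus, mish, softmax and rms_norm on Python floats and lists, as a way to sanity-check expected values. It is not NKI code: the hardware implementations may use different approximations or precisions, and the rms_norm `weight` and `eps` parameters here are illustrative assumptions, not the NKI signature.

```python
import math


def gelu(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_apprx_tanh(x):
    # Widely used tanh approximation of GELU.
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


def sigmoid(x):
    # Numerically stable logistic sigmoid (avoids exp overflow for large |x|).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def silu(x):
    # SiLU: x * sigmoid(x).
    return x * sigmoid(x)


def softplus(x):
    # log(1 + exp(x)), rewritten so it stays finite for large positive x.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mish(x):
    # Mish: x * tanh(softplus(x)).
    return x * math.tanh(softplus(x))


def softmax(xs):
    # Normalizes over the whole list (a single axis); subtracting the max avoids overflow.
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]
    total = sum(exps)
    return [e / total for e in exps]


def rms_norm(xs, weight, eps=1e-6):
    # Divide by the root mean square of xs, then scale element-wise by weight.
    # `weight` and `eps` are assumptions for illustration.
    rms = math.sqrt(sum(v * v for v in xs) / len(xs) + eps)
    return [v / rms * w for v, w in zip(xs, weight)]
```

For example, gelu(1.0) and gelu_apprx_tanh(1.0) agree to within 1e-3, and softmax([0.0, 0.0]) returns [0.5, 0.5].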

Bitwise operations

bitwise_and Bitwise AND of the two inputs, element-wise.
bitwise_or Bitwise OR of the two inputs, element-wise.
bitwise_xor Bitwise XOR of the two inputs, element-wise.
invert Bitwise NOT of the input, element-wise.
left_shift Bitwise left-shift x by y, element-wise.
right_shift Bitwise right-shift x by y, element-wise.

Logical operations

equal Element-wise boolean result of x == y.
not_equal Element-wise boolean result of x != y.
greater Element-wise boolean result of x > y.
greater_equal Element-wise boolean result of x >= y.
less Element-wise boolean result of x < y.
less_equal Element-wise boolean result of x <= y.
logical_and Element-wise boolean result of x AND y.
logical_or Element-wise boolean result of x OR y.
logical_xor Element-wise boolean result of x XOR y.
logical_not Element-wise boolean result of NOT x.

Tensor manipulation operations

ds Construct a dynamic slice for simple tensor indexing.
arange Return contiguous values within a given interval, used for indexing a tensor to define a tile.
mgrid Same as NumPy mgrid: "An instance which returns a dense (or fleshed out) mesh-grid when indexed, so that each returned argument has the same shape."
expand_dims Expand the shape of a tile.
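mgrid mirrors numpy.mgrid. For a 2D grid it yields two index arrays of the same shape: one holding the index along the first dimension and one holding the index along the second dimension at every position. The pure-Python sketch below reproduces numpy.mgrid[0:n_rows, 0:n_cols] and then adds offsets to the grids to read a tile out of a larger matrix, mirroring how index grids are used to define a tile. The helper names mgrid_2d and gather_tile are hypothetical.

```python
def mgrid_2d(n_rows, n_cols):
    """Pure-Python equivalent of numpy.mgrid[0:n_rows, 0:n_cols] (hypothetical helper)."""
    # Grid 0: the row index, constant along each row.
    i_row = [[r for _ in range(n_cols)] for r in range(n_rows)]
    # Grid 1: the column index, constant down each column.
    i_col = [[c for c in range(n_cols)] for _ in range(n_rows)]
    return i_row, i_col


def gather_tile(matrix, row0, col0, n_rows, n_cols):
    """Read an n_rows x n_cols tile starting at (row0, col0) using index grids."""
    i_row, i_col = mgrid_2d(n_rows, n_cols)
    return [
        [matrix[row0 + i_row[r][c]][col0 + i_col[r][c]] for c in range(n_cols)]
        for r in range(n_rows)
    ]


matrix = [[10 * r + c for c in range(6)] for r in range(4)]
print(mgrid_2d(2, 3))                   # ([[0, 0, 0], [1, 1, 1]], [[0, 1, 2], [0, 1, 2]])
print(gather_tile(matrix, 2, 3, 2, 3))  # [[23, 24, 25], [33, 34, 35]]
```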

Sorting/Searching operations

where Return elements chosen from x or y depending on condition.
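where selects element-wise between two inputs based on a boolean condition, in the manner of numpy.where. The pure-Python sketch below is a hypothetical stand-in operating on flat lists (not the NKI function); it shows the selection rule and combines it with an element-wise greater-than comparison to reproduce relu.

```python
def where(condition, x, y):
    """Pick x[i] where condition[i] is true, else y[i] (hypothetical list version)."""
    if not (len(condition) == len(x) == len(y)):
        raise ValueError("condition, x and y must have the same length")
    return [xi if c else yi for c, xi, yi in zip(condition, x, y)]


xs = [-2.0, -0.5, 0.0, 1.5, 3.0]
# Element-wise x > 0, the kind of boolean result `greater` produces.
mask = [v > 0 for v in xs]
print(where(mask, xs, [0.0] * len(xs)))  # [0.0, 0.0, 0.0, 1.5, 3.0]
```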

Collective communication operations

all_reduce Apply a reduce operation across multiple SPMD programs.
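Conceptually, an all-reduce combines one input from every participating program with a reduction operator and returns the same combined result to every program. The pure-Python simulation below uses a hypothetical helper, simulate_all_reduce; it models only this data flow, not NKI's communication API or its arguments, and illustrates the contract with sum and max.

```python
import operator
from functools import reduce


def simulate_all_reduce(per_program_inputs, op=operator.add):
    """Return one result list per simulated program; all of them are identical."""
    # Combine the inputs of all programs element-wise with `op`.
    combined = reduce(
        lambda acc, values: [op(a, b) for a, b in zip(acc, values)],
        per_program_inputs,
    )
    # Every program receives its own copy of the same reduced result.
    return [list(combined) for _ in per_program_inputs]


inputs = [[1, 2], [3, 4], [5, 6]]  # one list per simulated program
print(simulate_all_reduce(inputs))       # [[9, 12], [9, 12], [9, 12]]
print(simulate_all_reduce(inputs, max))  # [[5, 6], [5, 6], [5, 6]]
```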

Iterators

static_range Create a sequence of numbers for use as loop iterators in NKI, resulting in a fully unrolled loop.
affine_range Create a sequence of numbers for use as parallel loop iterators in NKI.
sequential_range Create a sequence of numbers for use as sequential loop iterators in NKI.

Memory Hierarchy

par_dim Mark a dimension explicitly as a partition dimension.
psum PSUM: visible only to each individual kernel instance in the SPMD grid. Alias of nki.compiler.psum.auto_alloc().
sbuf State Buffer (SBUF): visible only to each individual kernel instance in the SPMD grid. Alias of nki.compiler.sbuf.auto_alloc().
hbm HBM: alias of private_hbm.
private_hbm HBM visible only to each individual kernel instance in the SPMD grid.
shared_hbm Shared HBM visible to all kernel instances in the SPMD grid.

Others

program_id Index of the current SPMD program along the given axis in the launch grid.
num_programs Number of SPMD programs along the given axes in the launch grid.
program_ndim Number of dimensions in the SPMD launch grid.
spmd_dim Create a dimension in the SPMD launch grid of a NKI kernel with sub-dimension tiling.
nc Create a logical neuron core dimension in the launch grid.
device_print Print a message with a string prefix followed by the value of a tile x.
loop_reduce Apply a reduce operation over a loop.
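A common SPMD pattern is for each program to combine its own index (program_id) with the total number of programs (num_programs) to select the slice of work it owns. The pure-Python sketch below simulates a 1D launch grid with a hypothetical helper, program_rows, that splits total_rows into contiguous blocks. In a real kernel the two values would come from program_id and num_programs; the contiguous-block split is an illustrative assumption, not the only possible scheme.

```python
def program_rows(program_id, num_programs, total_rows):
    """Contiguous block of row indices owned by one program (hypothetical helper)."""
    # Ceiling division so that every row is assigned to some program.
    rows_per_program = -(-total_rows // num_programs)
    start = min(program_id * rows_per_program, total_rows)
    end = min(start + rows_per_program, total_rows)
    return range(start, end)


# Simulate a launch grid of 4 programs sharing 10 rows.
for pid in range(4):
    print(pid, list(program_rows(pid, 4, 10)))
# 0 [0, 1, 2]
# 1 [3, 4, 5]
# 2 [6, 7, 8]
# 3 [9]
```

Trailing programs may receive an empty range when total_rows is small, so a kernel following this pattern would need to handle that case.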

Data Types

tfloat32 32-bit floating-point number (1S,8E,10M)
bfloat16 16-bit floating-point number (1S,8E,7M)
float8_e4m3 8-bit floating-point number (1S,4E,3M)
float8_e5m2 8-bit floating-point number (1S,5E,2M)
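The bit layouts above can be checked in plain Python. bfloat16 keeps the sign bit and all 8 exponent bits of an IEEE-754 float32 but only the top 7 of its 23 mantissa bits, so a bfloat16 bit pattern is the upper 16 bits of the corresponding float32 pattern. The sketch below uses the standard struct module to show this, plus a generic field splitter that also applies to the float8 layouts. The conversion here truncates (rounds toward zero); hardware and frameworks commonly round to nearest even instead, so treat it as an illustration of the layout only. All helper names are hypothetical.

```python
import struct


def float32_bits(x):
    """Raw 32-bit IEEE-754 pattern of x after rounding it to float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def to_bfloat16_bits(x):
    """bfloat16 pattern by truncation: keep the upper 16 bits of the float32 pattern."""
    return float32_bits(x) >> 16


def bfloat16_bits_to_float(bits):
    """Widen a 16-bit bfloat16 pattern back to a Python float (exact)."""
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]


def split_fields(bits, exp_bits, man_bits):
    """Split a raw pattern into (sign, biased exponent, mantissa) fields."""
    mantissa = bits & ((1 << man_bits) - 1)
    exponent = (bits >> man_bits) & ((1 << exp_bits) - 1)
    sign = bits >> (exp_bits + man_bits)
    return sign, exponent, mantissa


print(hex(to_bfloat16_bits(1.0)))                         # 0x3f80
print(split_fields(to_bfloat16_bits(-2.0), 8, 7))         # (1, 128, 0)
print(bfloat16_bits_to_float(to_bfloat16_bits(3.14159)))  # 3.140625
print(split_fields(0b0_0111_000, 4, 3))                   # float8_e4m3 fields: (0, 7, 0)
```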

Constants