nki.language - AWS Neuron Documentation
Contents
- Memory operations
- Creation operations
- Math operations
- Bitwise operations
- Logical operations
- Tensor manipulation operations
- Sorting/Searching operations
- Collective communication operations
- Iterators
- Memory Hierarchy
- Others
- Data Types
- Constants
This document is relevant for: Inf2, Trn1, Trn2
nki.language
Memory operations
Name | Description |
---|---|
load | Load a tensor from device memory (HBM) into on-chip memory (SBUF). |
store | Store into a tensor on device memory (HBM) from on-chip memory (SBUF). |
load_transpose2d | Load a tensor from device memory (HBM) and 2D-transpose the data before storing into on-chip memory (SBUF). |
atomic_rmw | Perform an atomic read-modify-write operation on HBM data: dst = op(dst, value). |
copy | Create a copy of the src tile. |
broadcast_to | Broadcast the src tile to a new shape based on numpy broadcast rules. |
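NKI kernels need the Neuron compiler and a device (or its simulator), so the sketch below is only a host-side NumPy model of what four of these operations compute. It can serve as a reference when checking a kernel's output. The `*_ref` function names are hypothetical. HBM and SBUF are both modeled as ordinary NumPy arrays, and the atomicity of `atomic_rmw` is not modeled.

```python
import numpy as np

def load_transpose2d_ref(hbm_tensor: np.ndarray) -> np.ndarray:
    """Model of load_transpose2d: read a 2D tensor and swap its two axes."""
    if hbm_tensor.ndim != 2:
        raise ValueError("load_transpose2d operates on 2D data")
    # Copy so the result no longer aliases the "HBM" array.
    return np.ascontiguousarray(hbm_tensor.T)

def atomic_rmw_ref(dst: np.ndarray, value: np.ndarray, op) -> np.ndarray:
    """Model of atomic_rmw: dst = op(dst, value), written back in place."""
    dst[...] = op(dst, value)
    return dst

def broadcast_to_ref(src: np.ndarray, shape) -> np.ndarray:
    """Model of broadcast_to: NumPy broadcast rules, materialized as a copy."""
    return np.broadcast_to(src, shape).copy()

hbm = np.arange(6, dtype=np.float32).reshape(2, 3)
sbuf = load_transpose2d_ref(hbm)                       # shape (3, 2)

acc = np.zeros(4, dtype=np.int32)
atomic_rmw_ref(acc, np.array([1, 2, 3, 4], dtype=np.int32), np.add)

row = broadcast_to_ref(np.array([[1.0, 2.0, 3.0]]), (4, 3))
```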
Creation operations
Name | Description |
---|---|
ndarray | Create a new tensor of given shape and dtype on the specified buffer. |
empty_like | Create a new tensor with the same shape and type as a given tensor. |
zeros | Create a new tensor of given shape and dtype on the specified buffer, filled with zeros. |
zeros_like | Create a new tensor of zeros with the same shape and type as a given tensor. |
ones | Create a new tensor of given shape and dtype on the specified buffer, filled with ones. |
full | Create a new tensor of given shape and dtype on the specified buffer, filled with the given initial value. |
rand | Generate a tile of given shape and dtype, filled with random values that are sampled from a uniform distribution between 0 and 1. |
random_seed | Set a user-specified seed for the hardware random number generator. |
shared_constant | Create a new tensor filled with the data specified by the data array. |
shared_identity_matrix | Create a new identity tensor with specified data type. |
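The contract between `random_seed` and `rand` is that reseeding with the same value reproduces the same tile. The sketch below illustrates that contract with NumPy's `default_rng` as a stand-in. The hardware generator is not NumPy's, so actual values will differ. The half-open range [0, 1) is NumPy's convention and is only assumed for NKI. The helper `rand_ref` is hypothetical, and the last three lines show NumPy counterparts of `zeros`, `full`, and `shared_identity_matrix`.

```python
import numpy as np

def rand_ref(shape, seed, dtype=np.float32):
    """Stand-in for random_seed(seed) followed by rand(shape, dtype)."""
    rng = np.random.default_rng(seed)
    return rng.random(shape, dtype=dtype)   # uniform samples in [0, 1)

a = rand_ref((128, 4), seed=42)
b = rand_ref((128, 4), seed=42)             # same seed -> same tile
c = rand_ref((128, 4), seed=7)              # different seed -> different tile

# NumPy counterparts of the other creation operations
zeros = np.zeros((128, 4), dtype=np.float32)
filled = np.full((128, 4), 0.5, dtype=np.float32)
identity = np.eye(128, dtype=np.float32)
```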
Math operations
Name | Description |
---|---|
add | Add the inputs, element-wise. |
subtract | Subtract the inputs, element-wise. |
multiply | Multiply the inputs, element-wise. |
divide | Divide the inputs, element-wise. |
power | Elements of x raised to powers of y, element-wise. |
maximum | Maximum of the inputs, element-wise. |
minimum | Minimum of the inputs, element-wise. |
max | Maximum of elements along the specified axis (or axes) of the input. |
min | Minimum of elements along the specified axis (or axes) of the input. |
mean | Arithmetic mean along the specified axis (or axes) of the input. |
var | Variance along the specified axis (or axes) of the input. |
sum | Sum of elements along the specified axis (or axes) of the input. |
prod | Product of elements along the specified axis (or axes) of the input. |
all | Whether all elements along the specified axis (or axes) evaluate to True. |
abs | Absolute value of the input, element-wise. |
negative | Numerical negative of the input, element-wise. |
sign | Sign of the numbers of the input, element-wise. |
trunc | Truncated value of the input, element-wise. |
floor | Floor of the input, element-wise. |
ceil | Ceiling of the input, element-wise. |
mod | Integer modulo of x / y, element-wise. |
fmod | Floor-mod of x / y, element-wise. |
exp | Exponential of the input, element-wise. |
log | Natural logarithm of the input, element-wise. |
cos | Cosine of the input, element-wise. |
sin | Sine of the input, element-wise. |
tan | Tangent of the input, element-wise. |
tanh | Hyperbolic tangent of the input, element-wise. |
arctan | Inverse tangent of the input, element-wise. |
sqrt | Non-negative square-root of the input, element-wise. |
rsqrt | Reciprocal of the square-root of the input, element-wise. |
sigmoid | Logistic sigmoid activation function on the input, element-wise. |
relu | Rectified Linear Unit activation function on the input, element-wise. |
gelu | Gaussian Error Linear Unit activation function on the input, element-wise. |
gelu_dx | Derivative of Gaussian Error Linear Unit (gelu) on the input, element-wise. |
gelu_apprx_tanh | Gaussian Error Linear Unit activation function on the input, element-wise, with tanh approximation. |
silu | Sigmoid Linear Unit activation function on the input, element-wise. |
silu_dx | Derivative of Sigmoid Linear Unit activation function on the input, element-wise. |
erf | Error function of the input, element-wise. |
erf_dx | Derivative of the Error function (erf) on the input, element-wise. |
softplus | Softplus activation function on the input, element-wise. |
mish | Mish activation function on the input, element-wise. |
square | Square of the input, element-wise. |
softmax | Softmax activation function on the input, computed along the specified axis (or axes). |
rms_norm | Apply Root Mean Square Layer Normalization. |
dropout | Randomly zero some elements of the input tile with the given probability rate. |
matmul | Matrix multiplication of x and y (x @ y). |
transpose | Transposes a 2D tile between its partition and free dimension. |
reciprocal | Reciprocal of the input, element-wise. |
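Several of the activations above have short closed-form definitions. A NumPy reference for them is handy when validating kernel output. The sketch below uses the standard textbook formulas. NKI's hardware implementations may use different approximations, so compare with a tolerance. The exact `rms_norm` signature is not given in this table, so the `weight` and `eps` parameters of `rms_norm_ref` are assumptions.

```python
import math
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def silu(x):
    return x * sigmoid(x)

def softplus(x):
    # log(1 + e^x), written with logaddexp for numerical stability
    return np.logaddexp(0.0, x)

def mish(x):
    return x * np.tanh(softplus(x))

def gelu(x):
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF
    erf = np.vectorize(math.erf)
    return 0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))

def gelu_apprx_tanh(x):
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + np.tanh(c * (x + 0.044715 * x**3)))

def softmax(x, axis=-1):
    # Subtract the max before exponentiating to avoid overflow
    z = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return z / np.sum(z, axis=axis, keepdims=True)

def rms_norm_ref(x, weight, eps=1e-6, axis=-1):
    # x / sqrt(mean(x^2) + eps), scaled by an (assumed) weight
    rms = np.sqrt(np.mean(np.square(x), axis=axis, keepdims=True) + eps)
    return x / rms * weight

x = np.linspace(-3.0, 3.0, 13)
```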
Bitwise operations
Name | Description |
---|---|
bitwise_and | Bitwise AND of the two inputs, element-wise. |
bitwise_or | Bitwise OR of the two inputs, element-wise. |
bitwise_xor | Bitwise XOR of the two inputs, element-wise. |
invert | Bitwise NOT of the input, element-wise. |
left_shift | Bitwise left-shift x by y, element-wise. |
right_shift | Bitwise right-shift x by y, element-wise. |
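NumPy has element-wise functions with the same names as this table, which makes it a convenient host-side reference. This table does not say whether NKI's `right_shift` is arithmetic or logical for signed inputs, so the sketch uses unsigned 8-bit values to sidestep the question. The wrap-around on left shift shown here is NumPy's behavior and is only assumed to match.

```python
import numpy as np

x = np.array([0b1100_1010, 0b0000_1111], dtype=np.uint8)
y = np.array([0b1010_1010, 0b1111_0000], dtype=np.uint8)

and_ = np.bitwise_and(x, y)
or_ = np.bitwise_or(x, y)
xor_ = np.bitwise_xor(x, y)
not_ = np.invert(x)            # flips all 8 bits of each element
shl = np.left_shift(x, 1)      # bits shifted out of the 8-bit lane are dropped
shr = np.right_shift(x, 2)     # zeros shifted in from the left for unsigned types
```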
Logical operations
Name | Description |
---|---|
equal | Element-wise boolean result of x == y. |
not_equal | Element-wise boolean result of x != y. |
greater | Element-wise boolean result of x > y. |
greater_equal | Element-wise boolean result of x >= y. |
less | Element-wise boolean result of x < y. |
less_equal | Element-wise boolean result of x <= y. |
logical_and | Element-wise boolean result of x AND y. |
logical_or | Element-wise boolean result of x OR y. |
logical_xor | Element-wise boolean result of x XOR y. |
logical_not | Element-wise boolean result of NOT x. |
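Comparison results are typically combined into masks and then used to select values, for example with `where`. The sketch shows that pattern with NumPy functions of the same names. NumPy returns `bool` arrays. The dtype NKI uses for these results is not stated in this table.

```python
import numpy as np

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)

# Mask of elements in the half-open interval [-1, 1)
in_range = np.logical_and(np.greater_equal(x, -1.0), np.less(x, 1.0))
outside = np.logical_not(in_range)

# Nonzero XOR positive is true exactly for the negative elements
negative = np.logical_xor(np.not_equal(x, 0.0), np.greater(x, 0.0))

# Keep in-range values, replace the rest with zero
clipped = np.where(in_range, x, np.float32(0.0))
```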
Tensor manipulation operations
Name | Description |
---|---|
ds | Construct a dynamic slice for simple tensor indexing. |
arange | Return contiguous values within a given interval, used for indexing a tensor to define a tile. |
mgrid | Same as NumPy mgrid: "An instance which returns a dense (or fleshed out) mesh-grid when indexed, so that each returned argument has the same shape." |
expand_dims | Expand the shape of a tile. |
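The indexing helpers all describe the same thing: which elements of a tensor form a tile. In NKI examples, `arange` is commonly broadcast into a 2D index grid, with a column of partition indices against a row of free-axis indices. The NumPy sketch below shows that `arange`-style, `mgrid`-style, and slice-style indexing select the same 2x3 tile. It assumes `ds` takes `(start, size)` and models it as `slice(start, start + size)` through the hypothetical helper `ds_ref`.

```python
import numpy as np

tensor = np.arange(4 * 8).reshape(4, 8)

# arange-style: broadcast a column of partition indices against a row of free indices
i_p = np.arange(2)[:, None] + 1      # rows 1..2
i_f = np.arange(3)[None, :] + 4      # columns 4..6
tile_a = tensor[i_p, i_f]

# mgrid-style: both index arrays already have the full tile shape
g_p, g_f = np.mgrid[1:3, 4:7]
tile_b = tensor[g_p, g_f]

def ds_ref(start, size):
    """Hypothetical model of ds(start, size) as a plain Python slice."""
    return slice(start, start + size)

tile_c = tensor[ds_ref(1, 2), ds_ref(4, 3)]

# expand_dims: add a length-1 axis in front
expanded = np.expand_dims(tile_a, axis=0)
```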
Sorting/Searching operations
Name | Description |
---|---|
where | Return elements chosen from x or y depending on condition. |
Collective communication operations
Name | Description |
---|---|
all_reduce | Apply a reduce operation over multiple SPMD programs. |
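After an all-reduce, every participating program holds the same combined result, not just one of them. The pure-Python model below simulates four SPMD programs, each holding a local tile. It does not reflect the underlying communication pattern or NKI's actual argument list (for example, how the participating programs and the reduce op are specified). The helper `all_reduce_ref` is hypothetical.

```python
from functools import reduce
import operator

def all_reduce_ref(per_program_tiles, op=operator.add):
    """Model: every program ends up holding the op-reduced values across all programs."""
    combined = [reduce(op, values) for values in zip(*per_program_tiles)]
    # Each program receives its own copy of the reduced result
    return [list(combined) for _ in per_program_tiles]

local = [[1, 2, 3], [10, 20, 30], [100, 200, 300], [1000, 2000, 3000]]
after = all_reduce_ref(local)              # element-wise sum on every program
after_max = all_reduce_ref(local, op=max)  # element-wise max on every program
```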
Iterators
Name | Description |
---|---|
static_range | Create a sequence of numbers for use as loop iterators in NKI, resulting in a fully unrolled loop. |
affine_range | Create a sequence of numbers for use as parallel loop iterators in NKI. |
sequential_range | Create a sequence of numbers for use as sequential loop iterators in NKI. |
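The difference between parallel and sequential iterators comes down to loop-carried dependencies. The pure-Python illustration below is not how the NKI compiler analyzes loops. Running each loop body forward and then backward stands in for the reordering that a parallel iterator permits. A body whose iterations are independent gives the same result either way. A running scan, where `out[i]` records the partial value left by iteration `i - 1`, does not. `static_range` (full unrolling) is not modeled.

```python
def run(body, n, order):
    """Execute body(i, state) for each i produced by order(n)."""
    state = {"out": [0] * n, "carry": 0}
    for i in order(n):
        body(i, state)
    return state["out"]

def independent(i, state):
    # Each iteration writes only its own slot: no loop-carried dependency,
    # the kind of loop a parallel iterator such as affine_range targets.
    state["out"][i] = i * i

def carried(i, state):
    # Running scan: iteration i reads the partial value left by iteration i - 1,
    # a loop-carried dependency that calls for sequential_range.
    state["carry"] += i
    state["out"][i] = state["carry"]

def forward(n):
    return range(n)

def backward(n):
    return reversed(range(n))
```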
Memory Hierarchy
Name | Description |
---|---|
par_dim | Mark a dimension explicitly as a partition dimension. |
psum | PSUM - Only visible to each individual kernel instance in the SPMD grid, alias of nki.compiler.psum.auto_alloc() |
sbuf | State Buffer - Only visible to each individual kernel instance in the SPMD grid, alias of nki.compiler.sbuf.auto_alloc() |
hbm | HBM - Alias of private_hbm |
private_hbm | HBM - Only visible to each individual kernel instance in the SPMD grid |
shared_hbm | Shared HBM - Visible to all kernel instances in the SPMD grid |
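On-chip buffers such as SBUF are organized into partitions, so a tensor larger than the partition count must be processed in tiles along its partition dimension. The NumPy sketch below splits a tensor along its first axis into tiles of at most `PMAX` rows, where the last tile may be partial. `PMAX = 128` is an assumption here. Check the Constants section for the actual value on your target. `partition_tiles` is a hypothetical helper.

```python
import math
import numpy as np

PMAX = 128  # assumed partition count of SBUF; verify against the Constants section

def partition_tiles(x: np.ndarray, pmax: int = PMAX):
    """Yield (start, tile) pairs whose first (partition) axis is at most pmax long."""
    n_tiles = math.ceil(x.shape[0] / pmax)
    for t in range(n_tiles):
        start = t * pmax
        yield start, x[start:start + pmax]

x = np.ones((300, 64), dtype=np.float32)
tiles = list(partition_tiles(x))
rows = [tile.shape[0] for _, tile in tiles]   # full tiles, then one partial tile
```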
Others
Name | Description |
---|---|
program_id | Index of the current SPMD program along the given axis in the launch grid. |
num_programs | Number of SPMD programs along the given axes in the launch grid. |
program_ndim | Number of dimensions in the SPMD launch grid. |
spmd_dim | Create a dimension in the SPMD launch grid of an NKI kernel with sub-dimension tiling. |
nc | Create a logical neuron core dimension in the launch grid. |
device_print | Print a message with a string prefix followed by the value of a tile x. |
loop_reduce | Apply a reduce operation over a loop. |
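In the SPMD model, every program in the launch grid runs the same kernel body and uses its `program_id` and `num_programs` values to pick the data it works on. The host-side NumPy sketch below simulates a 1D grid. Both values are passed as plain arguments instead of being queried inside a kernel, and `kernel_body` and `launch_1d` are hypothetical names. On hardware the programs may run concurrently; here they run in a loop.

```python
import numpy as np

def kernel_body(pid, nprog, src, dst):
    """One SPMD program: process the contiguous chunk selected by its program id."""
    chunk = -(-src.shape[0] // nprog)        # ceiling division
    lo = pid * chunk
    hi = min(lo + chunk, src.shape[0])       # empty range if pid has no work
    dst[lo:hi] = src[lo:hi] * 2.0

def launch_1d(nprog, src):
    """Hypothetical host-side stand-in for a 1D launch grid of nprog programs."""
    dst = np.empty_like(src)
    for pid in range(nprog):
        kernel_body(pid, nprog, src, dst)
    return dst

src = np.arange(10, dtype=np.float32)
out = launch_1d(4, src)                      # chunks [0:3], [3:6], [6:9], [9:10]
```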
Data Types
Type | Description |
---|---|
tfloat32 | 32-bit floating-point number (1S,8E,10M) |
bfloat16 | 16-bit floating-point number (1S,8E,7M) |
float8_e4m3 | 8-bit floating-point number (1S,4E,3M) |
float8_e5m2 | 8-bit floating-point number (1S,5E,2M) |
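The (sign, exponent, mantissa) layouts in the table can be checked on the host. bfloat16 keeps float32's sign bit, its 8-bit exponent, and the top 7 of its 23 mantissa bits. The stdlib sketch below converts by truncation for simplicity. Hardware conversions commonly round to nearest even instead, so this is not necessarily how Neuron devices round. All helper names are hypothetical.

```python
import struct

def float32_bits(x: float) -> int:
    """Raw IEEE 754 single-precision bit pattern of x."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def to_bfloat16_bits(x: float) -> int:
    """Keep the sign, the 8 exponent bits and the top 7 mantissa bits (truncation)."""
    return float32_bits(x) >> 16

def from_bfloat16_bits(b: int) -> float:
    """Widen a bfloat16 bit pattern back to float32 by appending 16 zero bits."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]

def fields(bits: int, exp_bits: int, man_bits: int):
    """Split a raw bit pattern into (sign, exponent, mantissa) fields."""
    mantissa = bits & ((1 << man_bits) - 1)
    exponent = (bits >> man_bits) & ((1 << exp_bits) - 1)
    sign = bits >> (man_bits + exp_bits)
    return sign, exponent, mantissa

pi_bf16 = to_bfloat16_bits(3.14159265)
pi_back = from_bfloat16_bits(pi_bf16)   # precision lost to the 7-bit mantissa
```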
Constants