nki.isa.affine_select - AWS Neuron Documentation
This document is relevant for: Inf2, Trn1, Trn2
nki.isa.affine_select
nki.isa.affine_select(pred, on_true_tile, on_false_value, *, mask=None, dtype=None, **kwargs)
Select elements between an input tile on_true_tile and a scalar value on_false_value according to a boolean predicate tile, using the GpSimd Engine. The predicate tile is calculated on the fly in the engine by evaluating the affine expression given in pred element by element.
pred must meet the following requirements:
- It must not depend on any runtime variables that can’t be resolved at compile time.
- It can’t be multiple masks combined using logical operators such as & and |.
For a complex predicate that doesn’t meet the above requirements, consider using nl.where.
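To make "affine expression" concrete, the pure-Python sketch below models a predicate as a linear function of the first (partition-dimension) index ix and the second (free-dimension) index iy, plus a constant, evaluated at every position of the tile and treated as True wherever the value is greater than zero (the pred[x] > 0 rule from the Returns section). This is a conceptual model only, not how NKI represents pred internally; the helper affine_predicate_tile and its coefficient arguments are hypothetical. Under this model, the comparison iy < ix from the example on this page is the single expression ix - iy tested against zero, which also suggests why two masks joined with & fall outside what one affine expression can describe.

```python
def affine_predicate_tile(num_rows, num_cols, coef_ix, coef_iy, const):
    """Evaluate coef_ix*ix + coef_iy*iy + const > 0 at every (ix, iy).

    Hypothetical helper: a conceptual model of an affine predicate,
    not the NKI implementation.
    """
    return [
        [coef_ix * ix + coef_iy * iy + const > 0 for iy in range(num_cols)]
        for ix in range(num_rows)
    ]


# "iy < ix" written as the affine expression ix - iy, tested against zero.
pred = affine_predicate_tile(4, 4, coef_ix=1, coef_iy=-1, const=0)
for row in pred:
    print(["T" if v else "." for v in row])
```

Each printed row ix marks True only for columns iy < ix, so the True region is the strictly lower triangle.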
The input tile on_true_tile, the calculated boolean predicate tile expressed by pred, and the returned output tile of this instruction must have the same shape. If the predicate value at a given position is True, the corresponding output element takes the element of on_true_tile at the same position. If the predicate value at a given position is False, the corresponding output element takes the value of on_false_value.
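The selection rule can be expressed as a small host-side reference model, which can be handy for checking expected results. This is a sketch assuming tiles are represented as nested lists of rows and the predicate as a Python function of (ix, iy); affine_select_reference is a hypothetical helper, not part of NKI, and it ignores the mask and dtype parameters.

```python
def affine_select_reference(pred_fn, on_true_tile, on_false_value):
    """Host-side model of the affine_select selection rule (hypothetical helper).

    pred_fn(ix, iy) returns the predicate for position (ix, iy). The output
    has the same shape as on_true_tile and takes on_true_tile[ix][iy] where
    the predicate is True and on_false_value everywhere else.
    """
    return [
        [on_true_tile[ix][iy] if pred_fn(ix, iy) else on_false_value
         for iy in range(len(row))]
        for ix, row in enumerate(on_true_tile)
    ]


a = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]]
b = affine_select_reference(lambda ix, iy: iy < ix, a, -9984.0)
print(b)  # [[-9984.0, -9984.0, -9984.0], [4.0, -9984.0, -9984.0], [7.0, 8.0, -9984.0]]
```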
A common use case for affine_select is to apply a causal mask to the attention scores of transformer decoder models.
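As an illustration of the causal-mask use case, the pure-Python sketch below masks a small matrix of attention scores so that query position i (the row) only keeps key positions j <= i (the columns), then applies a row-wise softmax on the host. It is a model of the math, not NKI code; the helper names, the example scores, and the reuse of -9984.0 from the example on this page as the fill value are illustrative assumptions. The keep condition corresponds to the single comparison iy <= ix on the index grids.

```python
import math

MASK_VALUE = -9984.0  # fill value borrowed from the example on this page


def causal_mask(scores):
    """Keep scores[i][j] where j <= i; replace future positions with MASK_VALUE."""
    return [
        [s if j <= i else MASK_VALUE for j, s in enumerate(row)]
        for i, row in enumerate(scores)
    ]


def softmax_rows(matrix):
    """Numerically stable softmax over each row (subtract the row max first)."""
    out = []
    for row in matrix:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


scores = [[0.5, 1.0, 2.0],
          [0.1, 0.2, 0.3],
          [1.0, 1.0, 1.0]]
probs = softmax_rows(causal_mask(scores))
for row in probs:
    print([round(p, 3) for p in row])
```

Because exp(-9984.0 - max) underflows to 0.0 in double precision, masked positions receive zero attention weight, while each row still sums to 1.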
This instruction allows any float or 8-bit/16-bit integer data type for both the input data tile and the output tile (see Supported Data Types for more information). The output tile data type is specified using the dtype field. If dtype is not specified, the output data type is the same as the data type of on_true_tile. However, the data type of on_false_value must be float32, regardless of the input/output tile data types.
Estimated instruction cost:
GPSIMD_START + N GpSimd Engine cycles, where N is the number of elements per partition in on_true_tile and GPSIMD_START is the instruction startup overhead on GpSimdE, roughly 150 engine cycles.
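The cost formula is easy to evaluate by hand. The sketch below assumes GPSIMD_START is exactly 150 cycles (the documentation only says "roughly") and takes the tile shape as (partitions, *free_dims), so N is the product of the free dimensions; estimate_affine_select_cycles is a hypothetical helper, and its result is only as precise as the documented estimate.

```python
from math import prod

GPSIMD_START = 150  # assumed value; documented as "roughly 150 engine cycles"


def estimate_affine_select_cycles(tile_shape):
    """Estimate GpSimd Engine cycles as GPSIMD_START + N.

    tile_shape is (partitions, *free_dims); N is the number of elements per
    partition, i.e. the product of the free dimensions. The partition count
    itself does not appear in the formula.
    """
    _partitions, *free_dims = tile_shape
    n = prod(free_dims)
    return GPSIMD_START + n


print(estimate_affine_select_cycles((128, 128)))  # 150 + 128 = 278
print(estimate_affine_select_cycles((128, 512)))  # 150 + 512 = 662
```

For the [128, 128] tile in Example 1 this gives roughly 278 cycles; the estimate grows with the free size of the tile, not with its partition count.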
Parameters:
- pred – an affine expression that defines the boolean predicate
- on_true_tile – an input tile for selection with a True predicate value
- on_false_value – a scalar value for selection with a False predicate value
- mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed (see NKI API Masking for details)
- dtype – (optional) data type to cast the output type to (see Supported Data Types for more information); if not specified, it defaults to the data type of the input tiles, or whichever input type has the highest precision (see NKI Type Promotion for more information)
Returns:
an output tile with values selected from either on_true_tile or on_false_value according to the following equation: output[x] = (pred[x] > 0) ? on_true_tile[x] : on_false_value
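Example:
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl

##################################################################
# Example 1: Take tile a of shape [128, 128] and replace its
# upper triangle (including the diagonal) with -9984.0;
# a_tensor and b_tensor are assumed to be [128, 128] tensors
# available inside the enclosing kernel
##################################################################
ix, iy = nl.mgrid[0:128, 0:128]
a = nl.load(a_tensor[ix, iy])

# Keep a[ix, iy] where iy < ix (strictly below the diagonal);
# every other position receives on_false_value
b = nisa.affine_select(pred=(iy < ix), on_true_tile=a[ix, iy], on_false_value=-9984.0)

nl.store(b_tensor[ix, iy], b)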