nki.isa.affine_select — AWS Neuron Documentation

This document is relevant for: Inf2, Trn1, Trn2

nki.isa.affine_select

nki.isa.affine_select(pred, on_true_tile, on_false_value, *, mask=None, dtype=None, **kwargs)

Select elements between an input tile on_true_tile and a scalar value on_false_value according to a boolean predicate tile, using the GpSimd Engine. The predicate tile is calculated on-the-fly in the engine by evaluating an affine expression element-by-element, as indicated in pred.

pred must meet the following requirements:

For a complex predicate that doesn’t meet the above requirements, consider using nl.where.

The input tile on_true_tile, the calculated boolean predicate tile expressed by pred, and the returned output tile of this instruction must have the same shape. If the predicate value of a given position is True, the corresponding output element will take the element from on_true_tile in the same position. If the predicate value of a given position is False, the corresponding output element will take the value of on_false_value.
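The selection rule described above can be sketched on the host with NumPy. This is only an emulation of the semantics, not the NKI implementation: the helper name affine_select_reference is hypothetical, and the 4x4 tile and the fill value -1.0 are chosen purely for illustration.

```python
import numpy as np

def affine_select_reference(pred, on_true_tile, on_false_value):
    """Host-side emulation of the selection rule (hypothetical helper)."""
    pred = np.asarray(pred, dtype=bool)
    on_true_tile = np.asarray(on_true_tile)
    # The predicate tile and the input tile must have the same shape.
    if pred.shape != on_true_tile.shape:
        raise ValueError("pred and on_true_tile must have the same shape")
    # on_false_value is a scalar; the documentation requires float32.
    fill = np.float32(on_false_value)
    # True -> take the element from on_true_tile, False -> take the scalar.
    return np.where(pred, on_true_tile, fill).astype(on_true_tile.dtype)

# An affine predicate over a 4x4 tile: column index < row index.
row, col = np.mgrid[0:4, 0:4]
tile = np.arange(16, dtype=np.float32).reshape(4, 4)
out = affine_select_reference(col < row, tile, -1.0)
print(out)
```

Positions strictly below the diagonal keep their value from the tile, and every other position, including the diagonal, receives the scalar.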

A common use case for affine_select is to apply a causal mask on the attention scores for transformer decoder models.
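As a rough illustration of why such a mask is useful, the NumPy sketch below applies a causal predicate to a small score matrix and then takes a row-wise softmax. It runs on the host and does not call NKI. The sequence length, the random scores, and the choice of k_idx <= q_idx (each query may attend to itself and to earlier keys) are assumptions made for the example; the fill value -9984.0 is borrowed from the example on this page.

```python
import numpy as np

seq_len = 4
rng = np.random.default_rng(0)
# Hypothetical attention scores with shape [queries, keys].
scores = rng.standard_normal((seq_len, seq_len)).astype(np.float32)

# Causal predicate: an affine comparison of the two tile indices.
q_idx, k_idx = np.mgrid[0:seq_len, 0:seq_len]
causal = k_idx <= q_idx

# Keep allowed scores, replace future positions with a large negative value.
masked = np.where(causal, scores, np.float32(-9984.0))

# Row-wise softmax; masked positions underflow to a weight of exactly zero.
shifted = masked - masked.max(axis=1, keepdims=True)
weights = np.exp(shifted)
weights /= weights.sum(axis=1, keepdims=True)
print(weights)
```

Because the fill value is far below any realistic score, the masked positions contribute nothing after the softmax, which is the effect a causal mask is meant to have.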

This instruction allows any float or 8-bit/16-bit integer data types for both the input data tile and the output tile (see Supported Data Types for more information). The output tile data type is specified using the dtype field. If dtype is not specified, the output data type will be the same as the data type of the input tile on_true_tile. However, the data type of on_false_value must be float32, regardless of the input/output tile data types.
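The data type rules above can be mimicked on the host as follows. This is a NumPy sketch under stated assumptions, not a description of how the engine converts values: the helper select_with_dtype is hypothetical, and casting the result with astype is simply the most direct way to model "output takes the input type unless dtype is given".

```python
import numpy as np

def select_with_dtype(pred, on_true_tile, on_false_value, dtype=None):
    """Hypothetical host-side model of the output data type rule."""
    # Default: the output type follows the input tile's type.
    out_dtype = on_true_tile.dtype if dtype is None else np.dtype(dtype)
    # The scalar is always float32, whatever the tile types are.
    fill = np.float32(on_false_value)
    return np.where(pred, on_true_tile, fill).astype(out_dtype)

tile = np.ones((2, 2), dtype=np.float16)
pred = np.array([[True, False], [True, True]])

same = select_with_dtype(pred, tile, -9984.0)                    # float16 out
wide = select_with_dtype(pred, tile, -9984.0, dtype=np.float32)  # float32 out
print(same.dtype, wide.dtype)
```

The value -9984.0 is exactly representable in float16, so the fill survives the narrowing cast unchanged in this example.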

Estimated instruction cost:

GPSIMD_START + N GpSimd Engine cycles, where N is the number of elements per partition in on_true_tile and GPSIMD_START is the instruction startup overhead on GpSimdE, roughly 150 engine cycles.
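The estimate above is simple enough to compute by hand, but a small helper makes the arithmetic explicit. The function name and the constant below are hypothetical; the value 150 is the approximate figure quoted above, so the result is a rough estimate rather than a measured cycle count. The usage line assumes that the first dimension of a [128, 128] tile is the partition dimension, leaving 128 elements per partition.

```python
# Approximate startup overhead on GpSimdE, per the text ("roughly 150").
GPSIMD_START_CYCLES = 150

def estimate_affine_select_cycles(elements_per_partition,
                                  startup=GPSIMD_START_CYCLES):
    """Rough cost estimate: GPSIMD_START + N engine cycles (hypothetical helper)."""
    if elements_per_partition < 0:
        raise ValueError("elements_per_partition must be non-negative")
    return startup + elements_per_partition

# A [128, 128] tile, assuming 128 elements per partition.
cycles = estimate_affine_select_cycles(128)
print(cycles)
```

For small tiles the startup term dominates, so under this model the cost per element falls as the number of elements per partition grows.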

Parameters:

Returns:

an output tile with values selected from either on_true_tile or on_false_value according to the following equation: output[x] = (pred[x] > 0) ? on_true_tile[x] : on_false_value

Example:

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl

##################################################################
# Example 1: Take tile a of shape [128, 128] and replace its
# upper triangle with -9984.0;
##################################################################
ix, iy = nl.mgrid[0:128, 0:128]
a = nl.load(a_tensor[ix, iy])

b = nisa.affine_select(pred=(iy < ix), on_true_tile=a[ix, iy], on_false_value=-9984.0)

nl.store(b_tensor[ix, iy], b)
