nki.isa.nc_find_index8 - AWS Neuron Documentation


This document is relevant for: Inf2, Trn1, Trn2

nki.isa.nc_find_index8

nki.isa.nc_find_index8(*, data, vals, mask=None, dtype=None, **kwargs)

Find indices of the 8 given vals in each partition of the data tensor.

For each partition, this instruction first loads the 8 values from vals, then scans the data tensor and outputs the zero-based index of the first occurrence of each value in that partition's data.
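The first-occurrence rule can be illustrated with a small host-side emulation in plain Python. This is only a sketch of the described semantics, not the hardware implementation: the helper name find_index8_reference is hypothetical, and each partition is assumed to be already flattened into one row.

```python
def find_index8_reference(data, vals):
    """Emulate the described result on the host (hypothetical helper).

    data: list of rows, one row of elements per partition (assumed flattened).
    vals: list of rows, each holding exactly 8 values to look up.
    """
    result = []
    for row, wanted in zip(data, vals):
        if len(wanted) != 8:
            raise ValueError("each partition of vals must hold exactly 8 values")
        # list.index returns the zero-based position of the first occurrence.
        # It raises ValueError for a missing value, whereas the instruction's
        # behavior in that case is documented as undefined.
        result.append([row.index(v) for v in wanted])
    return result


data = [
    [5, 1, 9, 9, 3, 7, 2, 8, 6, 4],   # 9 appears twice; the first hit is index 2
    [0, 0, 1, 2, 3, 4, 5, 6, 7, 8],   # 0 appears twice; the first hit is index 0
]
vals = [
    [9, 8, 7, 6, 5, 4, 3, 2],
    [0, 1, 2, 3, 4, 5, 6, 7],
]
indices = find_index8_reference(data, vals)
print(indices)  # [[2, 7, 5, 8, 0, 9, 4, 6], [0, 2, 3, 4, 5, 6, 7, 8]]
```

Duplicated values in a row resolve to the lowest index, which matches the "first occurrence" wording above.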

The data tensor can be up to 5-dimensional, and the vals tensor can be up to 3-dimensional. The data tensor must have between 8 and 16,384 elements per partition. The vals tensor must have exactly 8 elements per partition. The output contains exactly 8 elements per partition and has type uint16 or uint32; the default output type is uint32.
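These limits can be checked on shapes before a kernel is written. The following is a minimal sketch in plain Python; the helper name check_find_index8_shapes is hypothetical, and it assumes the first dimension of each shape is the partition dimension, so the per-partition element count is the product of the remaining dimensions.

```python
from math import prod

MIN_DATA_ELEMS = 8       # from the text: at least 8 elements per partition
MAX_DATA_ELEMS = 16384   # from the text: at most 16,384 elements per partition
NUM_VALS = 8             # from the text: exactly 8 values per partition


def check_find_index8_shapes(data_shape, vals_shape):
    """Return a list of constraint violations (empty if none are found).

    Assumption: dimension 0 is the partition dimension.
    """
    problems = []
    if len(data_shape) > 5:
        problems.append("data has more than 5 dimensions")
    if len(vals_shape) > 3:
        problems.append("vals has more than 3 dimensions")
    data_elems = prod(data_shape[1:])
    if not MIN_DATA_ELEMS <= data_elems <= MAX_DATA_ELEMS:
        problems.append(f"data has {data_elems} elements per partition")
    vals_elems = prod(vals_shape[1:])
    if vals_elems != NUM_VALS:
        problems.append(f"vals has {vals_elems} elements per partition")
    return problems


print(check_find_index8_shapes((32, 128), (32, 8)))   # []
print(check_find_index8_shapes((32, 4), (32, 8)))     # one problem: too few data elements

# The largest possible index is 16,383, which fits in uint16 (max 65,535),
# so either output type can represent every valid index.
assert MAX_DATA_ELEMS - 1 <= 2**16 - 1
```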

Behavior is undefined if the vals tensor contains values that are not in the data tensor.
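When vals does not come directly from the data (as it does in the max8 example further down, where every value is present by construction), a host-side test can confirm the precondition on sample inputs. This is a sketch in plain Python with a hypothetical helper name, missing_values, operating on per-partition rows rather than on device tensors.

```python
def missing_values(data_rows, vals_rows):
    """Map partition index -> values requested in vals but absent from data.

    Hypothetical host-side check; an empty dict means the precondition holds
    for every partition.
    """
    missing = {}
    for p, (row, wanted) in enumerate(zip(data_rows, vals_rows)):
        present = set(row)
        absent = [v for v in wanted if v not in present]
        if absent:
            missing[p] = absent
    return missing


data_rows = [
    [1, 2, 3, 4, 5, 6, 7, 8],
    [1, 2, 3, 4, 5, 6, 7, 8],
]
vals_rows = [
    [8, 7, 6, 5, 4, 3, 2, 1],    # all present
    [8, 7, 6, 5, 4, 3, 2, 99],   # 99 is not in the data
]
report = missing_values(data_rows, vals_rows)
print(report)  # {1: [99]}
```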

If provided, a mask is applied only to the data tensor.

Estimated instruction cost:

N engine cycles, where:

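Parameters:

data: the data tensor to search

vals: tensor containing the 8 values per partition whose indices are to be found

mask: optional mask; if provided, it is applied only to the data tensor

dtype: output type, uint16 or uint32 (default: uint32)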

Returns:

A 2D tile of shape [par_dim, 8] containing the indices (uint16 or uint32) of the 8 values in each partition.

Example:

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
from neuronxcc.nki.typing import tensor

##################################################################
# Example 1: Generate a tile of 32 * 128 random floating point values,
# find the 8 largest values in each row, then find their indices:
##################################################################

# Generate random data
data = nl.rand((32, 128))

# Find max 8 values per row
max_vals = nisa.max8(src=data)

# Create output tensor for indices
indices_tensor = nl.ndarray([32, 8], dtype=nl.uint32, buffer=nl.shared_hbm)

# Find indices of max values
indices = nisa.nc_find_index8(data=data, vals=max_vals)

# Store results
nl.store(indices_tensor, value=indices)
