nki.isa.max8 - AWS Neuron Documentation

This document is relevant for: Inf2, Trn1, Trn2

nki.isa.max8

nki.isa.max8(*, src, mask=None, dtype=None, **kwargs)

Find the 8 largest values in each partition of the source tile.

This instruction reads the input elements, converts them to fp32 internally, and outputs the 8 largest values of each partition in descending order. By default, the output tile has the same dtype as the input tensor.

The source tile can have up to 5 dimensions, while the output tile is always 2-dimensional. Each partition of the source must contain between 8 and 16,384 elements inclusive, and the output always contains exactly 8 elements per partition. The source and output must have the same partition dimension size (source: [par_dim, ...], output: [par_dim, 8]).
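As a rough mental model of the semantics described above, the following pure-Python sketch mimics max8 on the host. Each partition's free dimensions are flattened, every element is rounded to fp32 with `struct` (standing in for the internal fp32 conversion), the 8 to 16,384 element limit is checked, and the 8 largest values are returned in descending order. This is an illustrative reference under stated assumptions, not the Neuron implementation: the helper name `max8_reference` and the nested-list tile representation are hypothetical, and the sketch returns fp32-rounded Python floats rather than casting back to the input dtype.

```python
import struct
from typing import List, Sequence, Union

# Hypothetical host-side model: a tile is a list of partitions, and each
# partition is a (possibly nested) list covering its free dimensions.
Nested = Union[float, int, Sequence["Nested"]]

MIN_ELEMS, MAX_ELEMS = 8, 16384  # per-partition limits stated in the docs


def _to_fp32(x: float) -> float:
    """Round a Python float to the nearest fp32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def _flatten(values: Nested) -> List[float]:
    """Flatten one partition's free dimensions, converting each element to fp32."""
    if isinstance(values, (int, float)):
        return [_to_fp32(float(values))]
    out: List[float] = []
    for v in values:
        out.extend(_flatten(v))
    return out


def max8_reference(src: Sequence[Nested]) -> List[List[float]]:
    """Return a [par_dim, 8] list with each partition's 8 largest values, descending."""
    result: List[List[float]] = []
    for p, partition in enumerate(src):
        flat = _flatten(partition)
        if not MIN_ELEMS <= len(flat) <= MAX_ELEMS:
            raise ValueError(
                f"partition {p} has {len(flat)} elements; expected 8 to 16384"
            )
        # Sort descending and keep the top 8 values of this partition.
        result.append(sorted(flat, reverse=True)[:8])
    return result


if __name__ == "__main__":
    tile = [
        [[5, 1, 9], [3, 7, 2], [8, 6, 4]],  # partition 0: 3 x 3 free dims
        [list(range(10, 20))],              # partition 1: 1 x 10 free dims
    ]
    print(max8_reference(tile)[0])  # [9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0]
```

Flattening first reflects the point above that the source may be multi-dimensional while the result is always [par_dim, 8]; only the per-partition element count matters for the limit check.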

Estimated instruction cost:

N engine cycles, where:

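Parameters:

src – the source tile in which to find the 8 largest values of each partition
mask – (optional) a compile-time constant predicate that controls whether/how this instruction is executed
dtype – (optional) data type of the output tile; if not specified, it defaults to the data type of the input tile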

Returns:

a 2D tile containing the 8 largest values per partition in descending order with shape [par_dim, 8]

Example:

import neuronxcc.nki as nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl


@nki.jit
def max8_example_kernel():  # illustrative kernel name
    ##################################################################
    # Example 1: Generate tile a of 32 * 128 random floating point values
    # and get the 8 largest values in each row:
    ##################################################################
    expr_a = nl.rand((32, 128))
    a = nisa.max8(src=expr_a)  # shape [32, 8], descending within each row

    # Allocate an output tensor in shared HBM and store the result into it
    a_tensor = nl.ndarray([32, 8], dtype=nl.float32, buffer=nl.shared_hbm)
    nl.store(a_tensor, value=a)
    return a_tensor
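Because Example 1 generates its input on the device with nl.rand, its exact values cannot be reproduced on the host. When the same [32, 128] data is available on the host (for example, if the kernel were changed to take its input as an argument, which is an assumption and not part of the example above), a plain-Python reference like the following sketch can check the stored [32, 8] result. The helpers `expected_max8` and `check_max8` and the 1e-6 tolerance are hypothetical choices, and `random.Random(0)` only stands in for real input data.

```python
import heapq
import random
from typing import List, Sequence

PAR_DIM, FREE = 32, 128  # shapes used in Example 1


def expected_max8(rows: Sequence[Sequence[float]]) -> List[List[float]]:
    """Per-row 8 largest values in descending order (host-side reference)."""
    return [heapq.nlargest(8, row) for row in rows]


def check_max8(
    rows: Sequence[Sequence[float]],
    out: Sequence[Sequence[float]],
    tol: float = 1e-6,
) -> bool:
    """Compare a [par_dim, 8] result against the reference within a tolerance."""
    ref = expected_max8(rows)
    if len(out) != len(ref):
        return False
    for r_row, o_row in zip(ref, out):
        if len(o_row) != 8:
            return False
        if any(abs(r - o) > tol for r, o in zip(r_row, o_row)):
            return False
    return True


if __name__ == "__main__":
    rng = random.Random(0)
    data = [[rng.random() for _ in range(FREE)] for _ in range(PAR_DIM)]
    ref = expected_max8(data)
    print(len(ref), len(ref[0]))  # 32 8
    print(check_max8(data, ref))  # True
```

`heapq.nlargest` already returns its results in descending order, which matches the ordering max8 produces, so the comparison can be done position by position without re-sorting the kernel output.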
