nki.isa.nc_match_replace8 - AWS Neuron Documentation

This document is relevant for: Inf2, Trn1, Trn2

nki.isa.nc_match_replace8

nki.isa.nc_match_replace8(*, data, vals, imm, mask=None, dtype=None, **kwargs)

Replace the first occurrence of each value in vals with imm in data, using the Vector engine. This is an in-place modification of the data tensor.

This instruction reads the input data and replaces the first occurrence of each of the given values (from the vals tensor) with the specified immediate constant, imm. All other values are written out unchanged.
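To make the replacement semantics concrete, here is a minimal pure-Python reference model of one call, treating each row of a 2-D list as one partition. It is a sketch, not the hardware implementation: match_replace8_row and match_replace8 are hypothetical helpers, and two details are assumptions not stated in this page: vals are processed in order, and each entry skips positions that were already replaced, so a value listed twice in vals replaces two successive occurrences. Where the real instruction's behavior is undefined (a value missing from data), the model simply raises.

```python
from typing import List, Sequence

VALS_PER_PARTITION = 8


def match_replace8_row(row: List[float], vals: Sequence[float], imm: float) -> List[float]:
    """Model nc_match_replace8 on one partition, modifying `row` in place.

    Assumption: vals are handled in order and each entry replaces the first
    position that matches it and has not been replaced yet.
    """
    if len(vals) != VALS_PER_PARTITION:
        raise ValueError("vals must hold exactly 8 elements per partition")
    replaced = [False] * len(row)
    for v in vals:
        for i, x in enumerate(row):
            if not replaced[i] and x == v:
                row[i] = imm  # first unreplaced occurrence of v becomes imm
                replaced[i] = True
                break
        else:
            # Undefined on hardware; the model refuses instead of guessing.
            raise ValueError(f"value {v!r} does not occur in this partition")
    return row


def match_replace8(data: List[List[float]], vals: List[List[float]], imm: float) -> List[List[float]]:
    """Apply the per-partition model to every row (partition) of a 2-D list."""
    if len(data) != len(vals):
        raise ValueError("data and vals must have the same number of partitions")
    return [match_replace8_row(row, v, imm) for row, v in zip(data, vals)]


row_data = [[3.0, 7.0, 1.0, 7.0, 5.0, 2.0, 9.0, 4.0, 8.0, 6.0]]
row_vals = [[9.0, 8.0, 7.0, 7.0, 6.0, 5.0, 4.0, 3.0]]
print(match_replace8(row_data, row_vals, float("-inf")))
```

In this usage example only 1.0 and 2.0 survive; every other element becomes -inf, and both 7s are replaced because 7 appears twice in vals (under the duplicate-handling assumption above).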

The data tensor can be up to 5-dimensional, and the vals tensor can be up to 3-dimensional. The vals tensor must have exactly 8 elements per partition, and the data tensor must have no more than 16,384 elements per partition. data and vals must have the same number of partitions. The output has the same shape as the input data tensor. Both input tensors can come from SBUF or PSUM.
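The shape rules above can be checked before writing a kernel. The helper below is hypothetical (it is not part of NKI) and assumes the NKI tile convention that dimension 0 is the partition dimension, so "elements per partition" is the product of the remaining dimensions. Requiring at least one dimension is the sketch's own addition.

```python
from math import prod
from typing import Sequence

# Limits quoted from the nc_match_replace8 description.
MAX_DATA_ELEMS_PER_PARTITION = 16_384
VALS_ELEMS_PER_PARTITION = 8
MAX_DATA_RANK = 5
MAX_VALS_RANK = 3


def check_match_replace8_shapes(data_shape: Sequence[int], vals_shape: Sequence[int]) -> int:
    """Hypothetical pre-flight check of the documented shape rules.

    Assumes dimension 0 is the partition dimension. Returns the number of
    data elements per partition, or raises ValueError on a violation.
    """
    if not 1 <= len(data_shape) <= MAX_DATA_RANK:
        raise ValueError(f"data must have 1 to {MAX_DATA_RANK} dimensions, got {len(data_shape)}")
    if not 1 <= len(vals_shape) <= MAX_VALS_RANK:
        raise ValueError(f"vals must have 1 to {MAX_VALS_RANK} dimensions, got {len(vals_shape)}")
    if data_shape[0] != vals_shape[0]:
        raise ValueError("data and vals must have the same number of partitions")
    vals_per_partition = prod(vals_shape[1:])  # free-dimension size of vals
    if vals_per_partition != VALS_ELEMS_PER_PARTITION:
        raise ValueError(f"vals needs exactly 8 elements per partition, got {vals_per_partition}")
    data_per_partition = prod(data_shape[1:])  # free-dimension size of data
    if data_per_partition > MAX_DATA_ELEMS_PER_PARTITION:
        raise ValueError(f"data has {data_per_partition} elements per partition (max 16,384)")
    return data_per_partition


print(check_match_replace8_shapes((128, 64, 256), (128, 8)))  # 16384
```

For example, a (128, 64, 256) data tile has 64 × 256 = 16,384 elements per partition, exactly at the limit, while (128, 128, 256) has 32,768 and is rejected. As far as these documented rules go, vals shaped (N, 8) or (N, 2, 4) both satisfy the 8-elements-per-partition requirement.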

Behavior is undefined if the vals tensor contains values that are not present in the data tensor.

If provided, a mask is applied to the data tensor.

Estimated instruction cost:

N engine cycles, where:

Parameters:

Returns:

the modified data tensor

Example:

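```python
import neuronxcc.nki as nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
from neuronxcc.nki.typing import tensor

##################################################################
# Example 1: Generate a tile of random floating-point values,
# get the 8 largest values in each row, then replace their first
# occurrences with -inf.
##################################################################
@nki.jit
def match_replace8_example():  # illustrative kernel wrapper around the snippet
    N = 4
    M = 16
    data_tile = nl.rand((N, M))          # N partitions x M random values
    max_vals = nisa.max8(src=data_tile)  # the 8 largest values in each partition

    # Replace the first occurrence of each max value with -inf (in place on data_tile)
    result = nisa.nc_match_replace8(data=data_tile[:, :], vals=max_vals, imm=float('-inf'))

    # Copy the result to an HBM output tensor and return it from the kernel
    result_tensor = nl.ndarray([N, M], dtype=nl.float32, buffer=nl.shared_hbm)
    nl.store(result_tensor, value=result)
    return result_tensor
```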
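Replacing the 8 maxima with -inf is presumably what lets this example be chained: a second nisa.max8 call on result would then return the next 8 largest values in each row, giving a top-16. The sketch below models that pattern in plain Python for a single partition. max8_row, match_replace8_row, and top16_row are hypothetical stand-ins, not NKI functions, and the model assumes that max8 returns its values in descending order and that duplicate entries in vals replace successive occurrences.

```python
from typing import List, Sequence

NEG_INF = float("-inf")


def max8_row(row: Sequence[float]) -> List[float]:
    """Stand-in for max8 on one partition: the 8 largest values.

    Descending order is an assumption of this model.
    """
    if len(row) < 8:
        raise ValueError("this model needs at least 8 elements per row")
    return sorted(row, reverse=True)[:8]


def match_replace8_row(row: List[float], vals: Sequence[float], imm: float) -> List[float]:
    """Stand-in for nc_match_replace8 on one partition (modifies row in place).

    Assumes each entry of vals consumes the first not-yet-replaced match.
    """
    replaced = [False] * len(row)
    for v in vals:
        for i, x in enumerate(row):
            if not replaced[i] and x == v:
                row[i] = imm
                replaced[i] = True
                break
        else:
            raise ValueError(f"value {v!r} not found (undefined on hardware)")
    return row


def top16_row(row: Sequence[float]) -> List[float]:
    """Top-16 of one partition via max8, match_replace8 with -inf, then max8 again."""
    if len(row) < 16:
        raise ValueError("this sketch assumes at least 16 elements per row")
    work = list(row)                          # leave the caller's data untouched
    first = max8_row(work)                    # 8 largest values
    match_replace8_row(work, first, NEG_INF)  # knock them out with -inf
    second = max8_row(work)                   # next 8 largest values
    return first + second


row = [12.0, 3.0, 19.0, 7.0, 7.0, 15.0, 0.0, 11.0, 4.0, 18.0,
       9.0, 2.0, 16.0, 5.0, 13.0, 1.0, 17.0, 8.0, 6.0, 10.0]
print(top16_row(row))
```

For the 20-element row above, the result equals sorted(row, reverse=True)[:16]. Rows with repeated values also come out correctly, but only under the duplicate-handling assumption, so that case is the one to confirm against the real instruction.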
