nki.language.affine_range - AWS Neuron Documentation

This document is relevant for: Inf2, Trn1, Trn2

nki.language.affine_range

nki.language.affine_range(*args, **kwargs)

Create a sequence of numbers for use as parallel loop iterators in NKI. affine_range should be the default loop iterator choice when there is no loop-carried dependency. Note that associative reductions are not considered loop-carried dependencies in this context. A concrete example of an associative reduction is multiple nl.matmul or nisa.nc_matmul calls accumulating into the same output buffer defined outside of this loop level (see code example #2 below).
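What makes a loop a good fit for affine_range is that its iterations can run in any order without changing the result. The pure-Python sketch below is not NKI code. The tile width, the data, and the helper names `square_tiles` and `accumulate_tiles` are hypothetical and chosen only for illustration. It models the two patterns allowed by the paragraph above: independent per-tile work, and an associative accumulation into a buffer defined outside the loop. It runs each loop in every possible iteration order and collects the distinct results.

```python
from itertools import permutations

TILE = 2  # hypothetical tile width, standing in for 512 in Example 1


def square_tiles(data, order):
    """Independent per-tile work: each iteration reads and writes only its own tile."""
    out = [None] * len(data)
    for i in order:
        offset = i * TILE
        for j in range(offset, offset + TILE):
            out[j] = data[j] * data[j]
    return out


def accumulate_tiles(data, order):
    """Associative reduction: every iteration adds a partial sum into one shared accumulator."""
    acc = 0  # defined outside the loop, like result_psum in Example 2
    for i in order:
        offset = i * TILE
        acc += sum(data[offset:offset + TILE])
    return acc


data = [3, 1, 4, 1, 5, 9, 2, 6]
n_tiles = len(data) // TILE
orders = list(permutations(range(n_tiles)))

# Collect the distinct results over every possible iteration order.
squares = {tuple(square_tiles(data, order)) for order in orders}
totals = {accumulate_tiles(data, order) for order in orders}
print(squares)  # a single tuple: the output does not depend on iteration order
print(totals)   # a single total for the same reason
```

The sketch uses integers so the comparison is exact. With floating-point data, an accumulated sum can differ in its last bits between iteration orders.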

When there is a loop-carried dependency other than an associative reduction, we recommend using sequential_range instead.
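By contrast, the pure-Python sketch below models a loop-carried dependency, the case that calls for sequential_range. It is a running (prefix) sum over tiles. The helper name `running_sum_tiles`, the tile width, and the data are hypothetical. Each iteration writes output values that depend on state left behind by earlier iterations, so changing the iteration order changes the result.

```python
from itertools import permutations

TILE = 2  # hypothetical tile width


def running_sum_tiles(data, order):
    """Loop-carried dependency: each tile's output needs the running total from earlier tiles."""
    out = [None] * len(data)
    carry = 0  # state written by one iteration and read by the next
    for i in order:
        offset = i * TILE
        for j in range(offset, offset + TILE):
            carry += data[j]
            out[j] = carry
    return out


data = [3, 1, 4, 1, 5, 9, 2, 6]
n_tiles = len(data) // TILE

in_order = running_sum_tiles(data, range(n_tiles))
results = {tuple(running_sum_tiles(data, order)) for order in permutations(range(n_tiles))}

print(in_order)      # the correct prefix sum, produced only by the in-order schedule
print(len(results))  # more than one distinct result: iteration order matters
```

The final value of `carry` is the same for every order, because the total is an associative reduction. The per-tile outputs are not, and this is what rules out affine_range here.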

Notes:

```python
import neuronxcc.nki.language as nl

# These fragments are meant to run inside an NKI kernel body, where
# `input`, `output`, `xT`, and `y` are the kernel's HBM tensor arguments.

#######################################################################
# Example 1: No loop carried dependency
# Input/Output tensor shape: [128, 2048]
# Load one tile ([128, 512]) at a time, square the tensor element-wise,
# and store it into output tile
#######################################################################

# Every loop instance works on an independent input/output tile.
# No data dependency between loop instances.
for i_input in nl.affine_range(input.shape[1] // 512):
    offset = i_input * 512
    input_sb = nl.load(input[0:input.shape[0], offset:offset+512])
    result = nl.multiply(input_sb, input_sb)
    nl.store(output[0:input.shape[0], offset:offset+512], result)

#######################################################################
# Example 2: Matmul output buffer accumulation, a type of associative reduction
# Input tensor shapes for nl.matmul: xT[K=2048, M=128] and y[K=2048, N=128]
# Load one tile ([128, 128]) from both xT and y at a time, matmul and
# accumulate into the same output buffer
#######################################################################

# Accumulation buffer in PSUM, defined outside the loop.
result_psum = nl.zeros((128, 128), dtype=nl.float32, buffer=nl.psum)
for i_K in nl.affine_range(xT.shape[0] // 128):
    offset = i_K * 128
    xT_sbuf = nl.load(xT[offset:offset+128, 0:xT.shape[1]])
    y_sbuf = nl.load(y[offset:offset+128, 0:y.shape[1]])

    # Partial products from every K tile are summed into result_psum.
    result_psum += nl.matmul(xT_sbuf, y_sbuf, transpose_x=True)
```
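Example 2 relies on an identity: splitting the contraction dimension K into tiles and summing the per-tile products gives the full product xT^T · y. The pure-Python sketch below checks that identity. Plain nested lists stand in for tensors. The sizes are shrunk to K=6, M=2, N=3 with a tile of 2 rows, and the helper names `matmul_T` and `tiled_matmul_T` are hypothetical. A host-side reference of this kind could be one way to sanity-check a kernel's output.

```python
def matmul_T(xT, y):
    """Compute xT^T @ y for xT of shape [K, M] and y of shape [K, N] (nested lists)."""
    K, M, N = len(xT), len(xT[0]), len(y[0])
    return [[sum(xT[k][m] * y[k][n] for k in range(K)) for n in range(N)] for m in range(M)]


def tiled_matmul_T(xT, y, tile_k):
    """Accumulate per-tile products over K, mirroring `result_psum += nl.matmul(...)`."""
    M, N = len(xT[0]), len(y[0])
    acc = [[0] * N for _ in range(M)]  # like result_psum: zero-initialized outside the loop
    for i_K in range(len(xT) // tile_k):
        offset = i_K * tile_k
        partial = matmul_T(xT[offset:offset + tile_k], y[offset:offset + tile_k])
        for m in range(M):
            for n in range(N):
                acc[m][n] += partial[m][n]
    return acc


# Hypothetical small sizes: K=6, M=2, N=3 (Example 2 uses K=2048, M=N=128).
xT = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]
y = [[1, 0, 2], [0, 1, 1], [2, 1, 0], [1, 1, 1], [0, 2, 1], [1, 0, 0]]

full = matmul_T(xT, y)
tiled = tiled_matmul_T(xT, y, tile_k=2)
print(full == tiled)  # the tiled accumulation matches the full product
```

Addition is associative, so the K tiles can be summed in any order. This is why the accumulation in Example 2 does not count as a loop-carried dependency.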
