nki.language.spmd_dim - AWS Neuron Documentation

This document is relevant for: Inf2, Trn1, Trn2

nki.language.spmd_dim

nki.language.spmd_dim = Ellipsis

Create a dimension in the SPMD launch grid of an NKI kernel with sub-dimension tiling.

A key use case for spmd_dim is to shard an existing NKI kernel over multiple NeuronCores without modifying the internal kernel implementation. Suppose we have a kernel, nki_spmd_kernel, which is launched with a 2D SPMD grid, (4, 2). We can shard the first dimension of the launch grid (size 4) over two physical NeuronCores by directly manipulating the launch grid as follows:

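from neuronxcc import nki
import neuronxcc.nki.language as nl

@nki.jit
def nki_spmd_kernel(a):
    # Allocate the output tensor in shared HBM with the input's shape and dtype
    b = nl.ndarray(a.shape, dtype=a.dtype, buffer=nl.shared_hbm)
    # Coordinates of this kernel instance in the 2D launch grid
    i = nl.program_id(0)
    j = nl.program_id(1)

    # Copy element [i, j] from the input to the output
    a_tile = nl.load(a[i, j])
    nl.store(b[i, j], a_tile)

    return b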

############################################################################
# Example 1: Let the compiler decide how to distribute the instances of the
# SPMD kernel
############################################################################
dst = nki_spmd_kernel[4, 2](src)  # src: placeholder name for the input tensor

############################################################################
# Example 2: Distribute SPMD kernel instances to physical NeuronCores with
# explicit annotations. Expected physical NeuronCore assignments:
#   Physical NC [0]: kernel[0, 0], kernel[0, 1], kernel[1, 0], kernel[1, 1]
#   Physical NC [1]: kernel[2, 0], kernel[2, 1], kernel[3, 0], kernel[3, 1]
############################################################################
dst = nki_spmd_kernel[nl.spmd_dim(nl.nc(2), 2), 2](src)
dst = nki_spmd_kernel[nl.nc(2) * 2, 2](src)  # syntactic sugar

############################################################################
# Example 3: Distribute SPMD kernel instances to physical NeuronCores with
# explicit annotations. Expected physical NeuronCore assignments:
#   Physical NC [0]: kernel[0, 0], kernel[0, 1], kernel[2, 0], kernel[2, 1]
#   Physical NC [1]: kernel[1, 0], kernel[1, 1], kernel[3, 0], kernel[3, 1]
############################################################################
dst = nki_spmd_kernel[nl.spmd_dim(2, nl.nc(2)), 2](src)
dst = nki_spmd_kernel[2 * nl.nc(2), 2](src)  # syntactic sugar
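The difference between Examples 2 and 3 comes down to where the NeuronCore sub-dimension sits inside the tiled dimension. The sketch below is a plain-Python model of that index arithmetic. It does not call NKI, and the helper names core_of and assignments are hypothetical. The mapping it uses (integer division when the core sub-dimension is outermost, modulo when it is innermost) is an assumption inferred from the expected assignments listed in the examples, not a description of how the compiler implements sharding.

```python
from collections import defaultdict
from itertools import product


def core_of(i, num_cores, tile, cores_outermost):
    """Return the modeled physical NeuronCore index for program_id(0) == i.

    cores_outermost=True models spmd_dim(nc(num_cores), tile):
        i is split as (core, t), so i = core * tile + t.
    cores_outermost=False models spmd_dim(tile, nc(num_cores)):
        i is split as (t, core), so i = t * num_cores + core.
    """
    if cores_outermost:
        return i // tile
    return i % num_cores


def assignments(num_cores, tile, second_dim, cores_outermost):
    """Group every (i, j) kernel instance by its modeled NeuronCore."""
    by_core = defaultdict(list)
    for i, j in product(range(num_cores * tile), range(second_dim)):
        by_core[core_of(i, num_cores, tile, cores_outermost)].append((i, j))
    return dict(by_core)


# Example 2 layout: core sub-dimension outermost, contiguous blocks per core
example_2 = assignments(num_cores=2, tile=2, second_dim=2, cores_outermost=True)
# Example 3 layout: core sub-dimension innermost, instances interleaved across cores
example_3 = assignments(num_cores=2, tile=2, second_dim=2, cores_outermost=False)

for name, layout in (("Example 2", example_2), ("Example 3", example_3)):
    for core, instances in sorted(layout.items()):
        print(f"{name} - Physical NC [{core}]: {instances}")
```

Under this model, Example 2 gives each NeuronCore a contiguous block of the first grid dimension, while Example 3 interleaves instances across the cores. Which layout is preferable depends on the kernel's data access pattern.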
