Transpose2D — AWS Neuron Documentation
This document is relevant for: Inf2, Trn1, Trn2
Transpose2D#
In this tutorial, we transpose a tensor along two of its axes using NKI. In doing so, we learn about:
- The NKI syntax and programming model.
- Multi-dimensional memory address patterns in NKI.
As background, there are two main types of transposition in NKI:
- Transposition between the partition-dimension axis and one of the free-dimension axes, which is achieved via the nki.isa.nc_transpose instruction.
- Transposition between two axes in the free dimension, which is achieved via a nki.language.copy instruction, with index manipulation in the free axes to re-arrange the data.
In this example, we’ll focus on the second case: consider a three-dimensional input tensor [P, F1, F2], where the P axis is mapped to the different SBUF partitions, and the F1 and F2 axes are flattened and placed in each partition, with F1 being the major dimension. Our goal in this example is to transpose the F1 and F2 axes, with P as a parallel dimension, to re-arrange the data within each partition. The figure below illustrates the input and output tensor layouts.
Fig. 78 Tensor F1:F2 Transpose#
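To make the free-axis index arithmetic concrete before looking at the kernel, here is a small plain-NumPy sketch (not NKI code) of the same re-arrangement: per partition, the input is read at the row-major offset f1 * sz_f2 + f2 and written at the row-major offset f2 * sz_f1 + f1. The names (sz_p, sz_f1, sz_f2) mirror the kernel below.

```python
import numpy as np

# Plain-NumPy illustration of the per-partition index arithmetic used by the kernel
sz_p, sz_f1, sz_f2 = 2, 3, 4
in_tensor = np.arange(sz_p * sz_f1 * sz_f2).reshape(sz_p, sz_f1 * sz_f2)

out_tensor = np.empty_like(in_tensor)
for p in range(sz_p):                      # parallel (partition) axis
    for f1 in range(sz_f1):                # major free axis of the input
        for f2 in range(sz_f2):            # minor free axis of the input
            # input is row-major [F1, F2]; output is row-major [F2, F1]
            out_tensor[p, f2 * sz_f1 + f1] = in_tensor[p, f1 * sz_f2 + f2]

# Reference: reshape, swap the two free axes, flatten again
ref = in_tensor.reshape(sz_p, sz_f1, sz_f2).transpose(0, 2, 1).reshape(sz_p, -1)
assert np.array_equal(out_tensor, ref)
```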
PyTorch#
Compute kernel#
```python
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl


@nki.jit
def tensor_transpose2D_kernel_(in_tensor, shape2D):
  """
  NKI kernel to reorder the elements on axis[1] of the input tensor.

  Every row of the input tensor is a flattened row-major 2D matrix.
  The shape2D argument defines the dimensions of the flattened matrices (#rows, #cols).
  Our goal in this kernel is to transpose these flattened 2D matrices, i.e. make them (#cols, #rows).

  Example:
    in_tensor = [a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3]
    shape2D = (3,4)
    this means that in_tensor has 3 rows and 4 columns, i.e. can be represented as:
      [a0,a1,a2,a3]
      [b0,b1,b2,b3]
      [c0,c1,c2,c3]
    after transpose, we expect to get:
      [a0,b0,c0]
      [a1,b1,c1]
      [a2,b2,c2]
      [a3,b3,c3]
    Thus, out_tensor is expected to be [a0,b0,c0,a1,b1,c1,a2,b2,c2,a3,b3,c3]

  Args:
    in_tensor: an input tensor
    shape2D: tuple representing the dimensions to be transposed: (#rows, #cols)
  Returns:
    out_tensor: the output (transposed) tensor
  """
  out_tensor = nl.ndarray(in_tensor.shape, dtype=in_tensor.dtype,
                          buffer=nl.shared_hbm)

  # Gather input shapes
  sz_p, _ = in_tensor.shape

  # Load input data from external memory to on-chip memory
  in_tile = nl.load(in_tensor)

  # Performing f1/f2 transpose
  # ==========================
  # The desired transpose pattern is provided as an input:
  sz_f1, sz_f2 = shape2D

  # We're going to need 3 indices to perform the f1:f2 transpose:
  # - i_p0 is the parallel index
  # - i_f1 and i_f2 are both free-dim indices, and will be used to transpose between the f1/f2 axes
  i_p0 = nl.arange(sz_p)[:, None, None]
  i_f1 = nl.arange(sz_f1)[None, :, None]
  i_f2 = nl.arange(sz_f2)[None, None, :]

  # Perform the transposition via an SBUF-to-SBUF copy, with access-pattern manipulation.
  # Note that we have 2D tensors and 3 indices, since we need to represent a 2D access pattern per partition:
  # - the RHS traverses an F1 x F2 matrix in a row-major manner
  # - the LHS traverses an F2 x F1 (new) matrix in a row-major manner
  out_tile = nl.ndarray(shape=(sz_p, sz_f2 * sz_f1), dtype=out_tensor.dtype)
  out_tile[i_p0, i_f2 * sz_f1 + i_f1] = nl.copy(in_tile[i_p0, i_f1 * sz_f2 + i_f2])

  # Finally, we store out_tile to external memory
  nl.store(out_tensor, value=out_tile)

  return out_tensor
```
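Before wiring the kernel into a framework, you can sanity-check it on NumPy inputs using NKI's simulation mode. The following is a minimal sketch, assuming nki.simulate_kernel is available in your neuronxcc release and that the kernel above is in scope; shapes and dtypes are illustrative.

```python
import numpy as np
import neuronxcc.nki as nki

if __name__ == "__main__":
    P, X, Y = 5, 3, 4
    a = np.arange(P * X * Y, dtype=np.int8).reshape(P, X * Y)

    # Run the @nki.jit kernel in simulation (no Neuron device required)
    a_t_nki = nki.simulate_kernel(tensor_transpose2D_kernel_, a, (X, Y))

    # NumPy reference: swap the two free axes within each partition
    a_t_ref = a.reshape(P, X, Y).transpose(0, 2, 1).reshape(P, X * Y)
    assert np.array_equal(a_t_nki, a_t_ref)
    print("NKI simulation matches NumPy reference")
```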
Launching kernel and testing correctness#
To execute the kernel, we prepare tensor a and call tensor_transpose2D_kernel_:
```python
import torch
from torch_xla.core import xla_model as xm

if __name__ == "__main__":
  device = xm.xla_device()

  P, X, Y = 5, 3, 4
  a = torch.arange(P * X * Y, dtype=torch.int8).reshape((P, X * Y)).to(device=device)

  a_t_nki = tensor_transpose2D_kernel_(a, (X, Y))

  a_t_torch = torch.transpose(a.reshape(P, X, Y), 1, 2).reshape(P, X * Y)

  print(a, a_t_nki, a_t_torch)

  allclose = torch.allclose(a_t_torch, a_t_nki)
  if allclose:
    print("NKI and PyTorch match")
  else:
    print("NKI and PyTorch differ")

  assert allclose
```
JAX#
Compute kernel#
We can reuse the same NKI compute kernel defined for PyTorch above.
Launching kernel and testing correctness#
To execute the kernel, we prepare array a and call tensor_transpose2D_kernel_:
```python
import jax
import jax.numpy as jnp

if __name__ == "__main__":
  P, X, Y = 5, 37, 44
  a = jax.random.uniform(jax.random.PRNGKey(42), (P, X * Y))
  a_t_nki = tensor_transpose2D_kernel_(a, shape2D=(X, Y))

  a_t_jax = jnp.transpose(a.reshape(P, X, Y), axes=(0, 2, 1)).reshape(P, X * Y)
  print(a, a_t_nki, a_t_jax)

  allclose = jnp.allclose(a_t_jax, a_t_nki)
  if allclose:
    print("NKI and JAX match")
  else:
    print("NKI and JAX differ")

  assert allclose
```
Note
We pass shape2D as a keyword argument so that the shape is passed to the kernel function as a compile-time constant.
Download All Source Code#
Click the links to download source code of the kernels and the testing code discussed in this tutorial.
- NKI baremetal implementation: transpose2d_nki_kernels.py
- PyTorch implementation: transpose2d_torch.py
- You must also download transpose2d_nki_kernels.py into the same folder to run this PyTorch script.
- JAX implementation: transpose2d_jax.py
- You must also download transpose2d_nki_kernels.py into the same folder to run this JAX script.
You can also view the source code in the GitHub repository nki_samples.
Example usage of the scripts:#
Run NKI baremetal implementation:
python3 transpose2d_nki_kernels.py
Run PyTorch implementation:
python3 transpose2d_torch.py
Run JAX implementation:
python3 transpose2d_jax.py