Matrix transpose performance profile explanation
Hey, I'm trying to profile some community matrix transpose implementations to better understand the metrics in NCU. Both the input and output matrices are stored in row-major format (and bound to PyTorch tensors).
Part I
Impl A
The first implementation is:
__global__ void mat_transpose_f32_col2row_kernel(
float *x, float *y, const int row, const int col) {
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int global_row = global_idx / col;
const int global_col = global_idx % col;
  const int in_idx = global_idx;                     // coalesced read: consecutive threads read consecutive floats
  const int out_idx = global_col * row + global_row; // strided write: consecutive threads are `row` floats apart
if (global_idx < row * col) {
y[out_idx] = x[in_idx];
}
}
This kernel does coalesced reads from the input matrix but strided writes to the output matrix.
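To make that concrete, here is what a single warp does (assuming col is large enough that the 32 threads of a warp fall within one input row):
// One warp in Impl A, threads t = 0..31 mapped to the same input row:
//   in_idx  = base + t            // 32 consecutive floats = 128 contiguous
//                                 //   bytes = 4 sectors (coalesced)
//   out_idx = (c0 + t) * row + r  // consecutive threads are `row` floats
//                                 //   (4 * row bytes) apart, so each store
//                                 //   touches its own 32-byte sector (strided)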
Impl B
__global__ void mat_transpose_f32_row2col_kernel(
float *x, float *y, const int row, const int col) {
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int global_col = global_idx / row;
const int global_row = global_idx % row;
  const int in_idx = global_row * col + global_col;  // strided read: consecutive threads are `col` floats apart
  const int out_idx = global_idx;                    // coalesced write: consecutive threads write consecutive floats
if (global_idx < row * col) {
y[out_idx] = x[in_idx];
}
}
The second does the opposite: strided reads from the input matrix and coalesced writes to the output matrix.
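For context, both 1D kernels are launched over a flat grid; a minimal launcher sketch looks like this (the 256-thread block size is only an example, the exact launch code I profiled is in the attached zip):
// Illustrative launcher for Impl A; Impl B is launched the same way.
void mat_transpose_f32_col2row(torch::Tensor x, torch::Tensor y) {
  const int row = x.size(0);
  const int col = x.size(1);
  const int n = row * col;
  constexpr int block = 256;  // example block size
  const int grid = (n + block - 1) / block;
  mat_transpose_f32_col2row_kernel<<<grid, block>>>(
      x.data_ptr<float>(), y.data_ptr<float>(), row, col);
  CUDA_CHECK(cudaGetLastError());
}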
Description
I expected them to have similar performance (although both are suboptimal); however, Impl B is notably faster than Impl A (about 2x).
DRAM
Looking at the DRAM profile, I saw something weird. Both kernels read 4 MiB from DRAM, but Impl B (strided reads) achieved higher DRAM throughput than Impl A (in the attached profile screenshots, Impl A is the Baseline and Impl B is the Current).
So my Q1 is: why did strided reads yield higher DRAM throughput here than coalesced reads?
I also tried to interpret the L2 and L1 cache statistics.
L2
As for the L2 cache, it's understandable that strided reads/writes lead to more sector accesses and lower per-request efficiency. One metric I cannot understand is L2 Fabric Total: why would Impl A lead to cache misses in the L2 partitions, and how does this affect overall performance?
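For concreteness, the "more sector accesses" part is just this back-of-the-envelope math (the square 1024 x 1024 fp32 shape is my assumption, inferred from the 4 MiB of DRAM reads):
// Sector-count estimate, assuming a 1024 x 1024 fp32 matrix (4 MiB).
constexpr long long ROWS = 1024, COLS = 1024;
constexpr long long BYTES = ROWS * COLS * sizeof(float); // 4 MiB
constexpr long long SECTOR_BYTES = 32;
// Coalesced side: a warp touches 128 contiguous bytes = 4 sectors,
// so the whole matrix needs BYTES / 32 sectors.
constexpr long long coalesced_sectors = BYTES / SECTOR_BYTES; // 131,072
// Strided side: each 4-byte element lands in its own 32-byte sector,
// i.e. one sector per element -- 8x more sector traffic.
constexpr long long strided_sectors = ROWS * COLS; // 1,048,576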
L1
Comparing the L1 and L2 cache statistics, I expected global_load_sectors * (1 - hit_rate) == l1_load_sectors to hold (check the red and blue boxes), but there is still a gap between the two. Meanwhile, Impl B's L1 cache hit rate drops dramatically compared with Impl A. Why does it behave like this?
Part II
Impl C
I also tried implementing Impl A using 2D indices (Impl C), still doing coalesced row reads. I expected Impl A and Impl C to have the same performance, or even the same SASS code, since the mapping between threads and elements is the same. However, I see a similar issue as in Part I. Could you please provide some insights?
__global__ void mat_transpose_f32_col2row2d_kernel(
float *x, float *y, const int row, const int col) {
const int global_x = blockIdx.x * blockDim.x + threadIdx.x;
const int global_y = blockIdx.y * blockDim.y + threadIdx.y;
  int in_idx = global_y * col + global_x;  // coalesced read along a row of x
  int out_idx = global_x * row + global_y; // strided write (stride = row floats)
if (global_x < col && global_y < row) {
y[out_idx] = x[in_idx];
}
}
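The corresponding 2D launch is roughly the following (the 32 x 8 block shape is only an example; the configuration I actually profiled is in the attached zip):
// Illustrative 2D launcher for Impl C; threadIdx.x walks along a row of x,
// so reads stay coalesced while writes are strided, just as in Impl A.
void mat_transpose_f32_col2row2d(torch::Tensor x, torch::Tensor y) {
  const int row = x.size(0);
  const int col = x.size(1);
  dim3 block(32, 8); // example block shape
  dim3 grid((col + block.x - 1) / block.x, (row + block.y - 1) / block.y);
  mat_transpose_f32_col2row2d_kernel<<<grid, block>>>(
      x.data_ptr<float>(), y.data_ptr<float>(), row, col);
  CUDA_CHECK(cudaGetLastError());
}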
Impl D
Just out of curiosity, I also implemented the above using CuTe (without any advanced features), and none of the above issues appear; the kernel is even faster than the plain CUDA C versions. How is this possible, and does NVCC provide special optimizations for CuTe?
template <typename T, int BLK_M, int BLK_N, typename ThreadLayoutA,
typename ThreadLayoutB>
__global__ void mat_transpose_cute_reg_kernel(const T *pA, T *pB, int M, int N,
ThreadLayoutA tA,
ThreadLayoutB tB) {
int tx = threadIdx.x;
int bx = blockIdx.x, by = blockIdx.y;
auto mA =
make_tensor(make_gmem_ptr(pA),
make_layout(make_shape(M, N), GenRowMajor{})); // (M, N)
auto mB =
make_tensor(make_gmem_ptr(pB),
make_layout(make_shape(N, M), GenRowMajor{})); // (N, M)
auto gA = local_tile(mA, make_shape(Int<BLK_M>{}, Int<BLK_N>{}),
make_coord(bx, by)); // (BM, BN)
auto gB = local_tile(mB, make_shape(Int<BLK_N>{}, Int<BLK_M>{}),
make_coord(by, bx)); // (BN, BM)
auto cA = local_tile(make_identity_tensor(mA.shape()),
make_shape(Int<BLK_M>{}, Int<BLK_N>{}),
make_coord(bx, by)); // (BM, BN)
Tensor tAgA = local_partition(gA, tA, tx);
Tensor tBgB = local_partition(gB, tB, tx);
Tensor tAcA = local_partition(cA, tA, tx);
Tensor tApA = make_tensor<bool>(tAcA.shape(), tAcA.stride());
CUTE_UNROLL
for (int i = 0; i < size<0>(tApA); i++) {
CUTE_UNROLL
for (int j = 0; j < size<1>(tApA); j++) {
tApA(i, j) = get<0>(tAcA(i, j)) < M && get<1>(tAcA(i, j)) < N;
}
}
copy_if(tApA, tAgA, tBgB);
}
void mat_transpose_cute_row2col_reg(torch::Tensor x, torch::Tensor y) {
const int BM = UNIT_BLK_SIZE;
const int BN = UNIT_BLK_SIZE;
const int M = x.size(0);
const int N = x.size(1);
auto tA = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenColMajor{});
auto tB = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenRowMajor{});
static_assert(size(tA) == size(tB));
dim3 block(size(tA));
dim3 grid((M + BM - 1) / BM, (N + BN - 1) / BN);
mat_transpose_cute_reg_kernel<float, BM, BN, decltype(tA), decltype(tB)>
<<<grid, block>>>(x.data_ptr<float>(), y.data_ptr<float>(), M, N, tA, tB);
CUDA_CHECK(cudaGetLastError());
}
void mat_transpose_cute_col2row_reg(torch::Tensor x, torch::Tensor y) {
const int BM = UNIT_BLK_SIZE;
const int BN = UNIT_BLK_SIZE;
const int M = x.size(0);
const int N = x.size(1);
auto tA = make_layout(make_shape(Int<BM>{}, Int<BN>{}), GenRowMajor{});
auto tB = make_layout(make_shape(Int<BN>{}, Int<BM>{}), GenColMajor{});
static_assert(size(tA) == size(tB));
dim3 block(size(tA));
dim3 grid((M + BM - 1) / BM, (N + BN - 1) / BN);
mat_transpose_cute_reg_kernel<float, BM, BN, decltype(tA), decltype(tB)>
<<<grid, block>>>(x.data_ptr<float>(), y.data_ptr<float>(), M, N, tA, tB);
CUDA_CHECK(cudaGetLastError());
}
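For completeness, this is roughly how I sanity-check the transposes against PyTorch before profiling (a sketch; the 1024 x 1024 shape is only an example):
// Quick correctness check for one of the launchers above (sketch).
void check_transpose() {
  auto x = torch::randn({1024, 1024},
                        torch::dtype(torch::kFloat32).device(torch::kCUDA));
  auto y = torch::empty({1024, 1024}, x.options());
  mat_transpose_cute_col2row_reg(x, y);
  cudaDeviceSynchronize();
  TORCH_CHECK(torch::allclose(y, x.t().contiguous()));
}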
Profile and code
Sorry if my questions seem dumb or lack detail. The original code and profiles are attached in case you need them:
code_and_profile.zip (4.2 MB)