Issue with cudnnMultiHeadAttnForward() – CUDNN_STATUS_BAD_PARAM
Hello,
I am currently implementing a wrapper function to call the cudnnMultiHeadAttnForward() API provided by cuDNN. However, after extensive testing, I consistently encounter a parameter error:
cudnnMultiHeadAttnForward failed: CUDNN_STATUS_BAD_PARAM
I am unsure whether the issue comes from my API usage or from how I am passing the parameters.
Below is the implementation of my wrapper function:
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCudnnMultiHeadAttention(
int num_heads,
int embed_dim,
int max_seq_len_q,
int max_seq_len_kv,
int batch_size,
int curr_idx,
const int* lo_win_idx,
const int* hi_win_idx,
const void* dev_seq_lengths_qo,
const void* dev_seq_lengths_kv,
const void* queries,
const void* residuals,
const void* keys,
const void* values,
void* output,
size_t weight_size_bytes,
const void* weights,
size_t workspace_size_bytes,
CUstream stream
) {
mgpuEnsureContext();
StreamHandles handles;
if (!getHandlesForStream(stream, handles)) {
fprintf(stderr, "[MHA] ERROR: Failed to get handles for stream %p\n", stream);
return;
}
cudnnHandle_t handle = handles.cudnn_handle;
if (num_heads <= 0 || embed_dim <= 0 || max_seq_len_q <= 0 || max_seq_len_kv <= 0 || batch_size <= 0) {
fprintf(stderr, "[MHA] ERROR: Invalid dimensions: heads=%d, embed=%d, seq_q=%d, seq_kv=%d, batch=%d\n",
num_heads, embed_dim, max_seq_len_q, max_seq_len_kv, batch_size);
return;
}
if (!queries || !keys || !values || !output) {
fprintf(stderr, "[MHA] ERROR: One or more data pointers are NULL\n");
return;
}
if (!weights && weight_size_bytes > 0) {
fprintf(stderr, "[MHA] ERROR: Weights pointer is NULL but weight size > 0\n");
return;
}
if (embed_dim % num_heads != 0) {
fprintf(stderr, "[MHA] ERROR: embed_dim (%d) must be divisible by num_heads (%d)\n",
embed_dim, num_heads);
return;
}
cudnnAttnDescriptor_t attn_desc = nullptr;
cudnnStatus_t status = cudnnCreateAttnDescriptor(&attn_desc);
if (status != CUDNN_STATUS_SUCCESS) {
fprintf(stderr, "[MHA] ERROR: Failed to create attention descriptor: %s\n",
cudnnGetErrorString(status));
return;
}
int head_dim = embed_dim / num_heads;
double sm_scaler = 1.0f / sqrtf((double)head_dim);
unsigned attnMode = 0;
attnMode |= CUDNN_ATTN_DISABLE_PROJ_BIASES;
attnMode |= CUDNN_ATTN_QUERYMAP_ALL_TO_ONE;
// attnMode |= CUDNN_ATTN_QUERYMAP_ONE_TO_ONE;
status = cudnnSetAttnDescriptor(
attn_desc,
attnMode, // attnMode
num_heads, // nHeads
sm_scaler, // smScaler
CUDNN_DATA_FLOAT, // dataType
CUDNN_DATA_FLOAT, // computePrec
CUDNN_DEFAULT_MATH, // mathType
nullptr, // attnDropoutDesc
nullptr, // postDropoutDesc
embed_dim, // qSize
embed_dim, // kSize
embed_dim, // vSize
0, // qProjSize
0, // kProjSize
0, // vProjSize
0, // oProjSize
max_seq_len_q, // qoMaxSeqLength
max_seq_len_kv, // kvMaxSeqLength
batch_size, // maxBatchSize
1
);
if (status != CUDNN_STATUS_SUCCESS) {
fprintf(stderr, "[MHA] ERROR: Failed to set attention descriptor: %s\n",
cudnnGetErrorString(status));
cudnnDestroyAttnDescriptor(attn_desc);
return;
}
cudnnSeqDataDescriptor_t q_desc = nullptr, k_desc = nullptr;
cudnnSeqDataDescriptor_t v_desc = nullptr, o_desc = nullptr;
status = cudnnCreateSeqDataDescriptor(&q_desc);
if (status != CUDNN_STATUS_SUCCESS) {
fprintf(stderr, "[MHA] ERROR: Failed to create Q descriptor: %s\n", cudnnGetErrorString(status));
cudnnDestroyAttnDescriptor(attn_desc);
return;
}
int q_dims[CUDNN_SEQDATA_DIM_COUNT];
q_dims[CUDNN_SEQDATA_BATCH_DIM] = batch_size;
q_dims[CUDNN_SEQDATA_TIME_DIM] = max_seq_len_q;
q_dims[CUDNN_SEQDATA_BEAM_DIM] = 1;
q_dims[CUDNN_SEQDATA_VECT_DIM] = embed_dim;
std::vector<int> q_seq_lengths(batch_size, max_seq_len_q);
cudnnSeqDataAxis_t q_axes[CUDNN_SEQDATA_DIM_COUNT];
q_axes[0] = CUDNN_SEQDATA_BATCH_DIM;
q_axes[1] = CUDNN_SEQDATA_BEAM_DIM;
q_axes[2] = CUDNN_SEQDATA_TIME_DIM;
q_axes[3] = CUDNN_SEQDATA_VECT_DIM;
status = cudnnSetSeqDataDescriptor(
q_desc,
CUDNN_DATA_FLOAT,
4,
q_dims,
q_axes,
batch_size, // seqLengthArraySize = batch_size * beam_size
q_seq_lengths.data(), // seqLengthArray
nullptr // paddingFill
);
if (status != CUDNN_STATUS_SUCCESS) {
fprintf(stderr, "[MHA] ERROR: Failed to set Q descriptor: %s\n", cudnnGetErrorString(status));
cudnnDestroySeqDataDescriptor(q_desc);
cudnnDestroyAttnDescriptor(attn_desc);
return;
}
status = cudnnCreateSeqDataDescriptor(&k_desc);
if (status != CUDNN_STATUS_SUCCESS) {
fprintf(stderr, "[MHA] ERROR: Failed to create K descriptor: %s\n", cudnnGetErrorString(status));
cudnnDestroySeqDataDescriptor(q_desc);
cudnnDestroyAttnDescriptor(attn_desc);
return;
}
int k_dims[CUDNN_SEQDATA_DIM_COUNT];
k_dims[CUDNN_SEQDATA_BATCH_DIM] = batch_size;
k_dims[CUDNN_SEQDATA_TIME_DIM] = max_seq_len_kv;
k_dims[CUDNN_SEQDATA_BEAM_DIM] = 1;
k_dims[CUDNN_SEQDATA_VECT_DIM] = embed_dim;
std::vector<int> kv_seq_lengths(batch_size, max_seq_len_kv);
cudnnSeqDataAxis_t kv_axes[CUDNN_SEQDATA_DIM_COUNT];
kv_axes[0] = CUDNN_SEQDATA_BATCH_DIM;
kv_axes[1] = CUDNN_SEQDATA_BEAM_DIM;
kv_axes[2] = CUDNN_SEQDATA_TIME_DIM;
kv_axes[3] = CUDNN_SEQDATA_VECT_DIM;
status = cudnnSetSeqDataDescriptor(
k_desc,
CUDNN_DATA_FLOAT,
4,
k_dims,
kv_axes,
batch_size, // seqLengthArraySize = batch_size * beam_size
kv_seq_lengths.data(), // seqLengthArray
nullptr // paddingFill
);
if (status != CUDNN_STATUS_SUCCESS) {
fprintf(stderr, "[MHA] ERROR: Failed to set K descriptor: %s\n", cudnnGetErrorString(status));
cudnnDestroySeqDataDescriptor(k_desc);
cudnnDestroySeqDataDescriptor(q_desc);
cudnnDestroyAttnDescriptor(attn_desc);
return;
}
status = cudnnCreateSeqDataDescriptor(&v_desc);
if (status != CUDNN_STATUS_SUCCESS) {
fprintf(stderr, "[MHA] ERROR: Failed to create V descriptor: %s\n", cudnnGetErrorString(status));
cudnnDestroySeqDataDescriptor(k_desc);
cudnnDestroySeqDataDescriptor(q_desc);
cudnnDestroyAttnDescriptor(attn_desc);
return;
}
status = cudnnSetSeqDataDescriptor(
v_desc,
CUDNN_DATA_FLOAT,
4,
k_dims,
kv_axes,
batch_size,
kv_seq_lengths.data(),
nullptr // paddingFill
);
if (status != CUDNN_STATUS_SUCCESS) {
fprintf(stderr, "[MHA] ERROR: Failed to set V descriptor: %s\n", cudnnGetErrorString(status));
cudnnDestroySeqDataDescriptor(v_desc);
cudnnDestroySeqDataDescriptor(k_desc);
cudnnDestroySeqDataDescriptor(q_desc);
cudnnDestroyAttnDescriptor(attn_desc);
return;
}
status = cudnnCreateSeqDataDescriptor(&o_desc);
if (status != CUDNN_STATUS_SUCCESS) {
fprintf(stderr, "[MHA] ERROR: Failed to create O descriptor: %s\n", cudnnGetErrorString(status));
cudnnDestroySeqDataDescriptor(v_desc);
cudnnDestroySeqDataDescriptor(k_desc);
cudnnDestroySeqDataDescriptor(q_desc);
cudnnDestroyAttnDescriptor(attn_desc);
return;
}
status = cudnnSetSeqDataDescriptor(
o_desc,
CUDNN_DATA_FLOAT,
4,
q_dims,
q_axes,
batch_size,
q_seq_lengths.data(),
nullptr
);
if (status != CUDNN_STATUS_SUCCESS) {
fprintf(stderr, "[MHA] ERROR: Failed to set O descriptor: %s\n", cudnnGetErrorString(status));
cudnnDestroySeqDataDescriptor(o_desc);
cudnnDestroySeqDataDescriptor(v_desc);
cudnnDestroySeqDataDescriptor(k_desc);
cudnnDestroySeqDataDescriptor(q_desc);
cudnnDestroyAttnDescriptor(attn_desc);
return;
}
void* workspace = nullptr;
bool using_workspace_pool = false;
bool workspace_allocated = false;
size_t actual_workspace_size = workspace_size_bytes;
if (workspace_size_bytes == 0) {
size_t estimated_size = (size_t)batch_size * max_seq_len_q * max_seq_len_kv * num_heads * sizeof(float);
actual_workspace_size = std::max(estimated_size * 2, (size_t)(32 * 1024 * 1024));
// fprintf(stderr, "[MHA] Using estimated workspace size: %.2f MB\n",
// actual_workspace_size / (1024.0 * 1024.0));
}
if (actual_workspace_size > 0) {
workspace = acquirePooledWorkspace(actual_workspace_size, stream, TENSOR_CORE_ALIGNMENT);
if (workspace != nullptr) {
using_workspace_pool = true;
} else {
CUdeviceptr workspace_ptr = allocateAlignedMemory(actual_workspace_size, TENSOR_CORE_ALIGNMENT);
if (workspace_ptr != 0) {
workspace = reinterpret_cast<void*>(workspace_ptr);
workspace_allocated = true;
fprintf(stderr, "[MHA] Using dynamic workspace (%.2f MB)\n",
actual_workspace_size / (1024.0 * 1024.0));
} else {
fprintf(stderr, "[MHA] ERROR: Failed to allocate workspace of %.2f MB\n",
actual_workspace_size / (1024.0 * 1024.0));
cudnnDestroySeqDataDescriptor(o_desc);
cudnnDestroySeqDataDescriptor(v_desc);
cudnnDestroySeqDataDescriptor(k_desc);
cudnnDestroySeqDataDescriptor(q_desc);
cudnnDestroyAttnDescriptor(attn_desc);
return;
}
}
}
const int* dev_seq_qo = static_cast<const int*>(dev_seq_lengths_qo);
const int* dev_seq_kv = static_cast<const int*>(dev_seq_lengths_kv);
status = cudnnMultiHeadAttnForward(
handle,
attn_desc,
curr_idx,
lo_win_idx,
hi_win_idx,
dev_seq_qo,
dev_seq_kv,
q_desc,
queries,
residuals,
k_desc,
keys,
v_desc,
values,
o_desc,
output,
weight_size_bytes,
weights,
actual_workspace_size,
workspace,
0,
nullptr
);
if (status != CUDNN_STATUS_SUCCESS) {
fprintf(stderr, "[MHA] ERROR: cudnnMultiHeadAttnForward failed: %s\n",
cudnnGetErrorString(status));
}
if (o_desc) cudnnDestroySeqDataDescriptor(o_desc);
if (v_desc) cudnnDestroySeqDataDescriptor(v_desc);
if (k_desc) cudnnDestroySeqDataDescriptor(k_desc);
if (q_desc) cudnnDestroySeqDataDescriptor(q_desc);
if (attn_desc) cudnnDestroyAttnDescriptor(attn_desc);
if (workspace_allocated && workspace) {
CUresult result = cuMemFree(reinterpret_cast<CUdeviceptr>(workspace));
if (result != CUDA_SUCCESS) {
fprintf(stderr, "[MHA] WARNING: Failed to free workspace\n");
}
}
}
During actual testing, the debug information I receive is as follows:
I! CuDNN (v90501 17) function cudnnMultiHeadAttnForward() called:
i! handle: type=cudnnHandle_t; streamId=0x17609670;
i! attnDesc: type=cudnnAttnDescriptor_t:
i! attnMode: type=unsigned; val=CUDNN_ATTN_QUERYMAP_ONE_TO_ONE|CUDNN_ATTN_DISABLE_PROJ_BIASES (0x1);
i! nHeads: type=int; val=8;
i! smScaler: type=double; val=0.176777;
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! mathPrec: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i! attnDropoutDesc: type=cudnnDropoutDescriptor_t; val=NULL_PTR;
i! postDropoutDesc: type=cudnnDropoutDescriptor_t; val=NULL_PTR;
i! qSize: type=int; val=256;
i! kSize: type=int; val=256;
i! vSize: type=int; val=256;
i! qProjSize: type=int; val=0;
i! kProjSize: type=int; val=0;
i! vProjSize: type=int; val=0;
i! oProjSize: type=int; val=0;
i! qoMaxSeqLength: type=int; val=256;
i! kvMaxSeqLength: type=int; val=32;
i! maxBatchSize: type=int; val=64;
i! maxBeamSize: type=int; val=1;
i! currIdx: type=int; val=-1;
i! loWinIdx: location=host; addr=0x7ffd4a348ad8;
i! hiWinIdx: location=host; addr=0x7ffd4a348ed8;
i! devSeqLengthsQO: location=dev; addr=0x7734c7000400;
i! devSeqLengthsKV: location=dev; addr=0x7734c7000600;
i! qDesc: type=cudnnSeqDataDescriptor_t:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! dimA.time: type=int; val=256;
i! dimA.batch: type=int; val=64;
i! dimA.beam: type=int; val=1;
i! dimA.vect: type=int; val=256;
i! axes: type=cudnnSeqDataAxis_t; val=[CUDNN_SEQDATA_BATCH_DIM,CUDNN_SEQDATA_BEAM_DIM,CUDNN_SEQDATA_TIME_DIM,CUDNN_SEQDATA_VECT_DIM];
i! seqLengthArraySize: type=size_t; val=64;
i! seqLengthArray: type=int; val=[256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256];
i! paddingFill: type=double; val=0;
i! queries: location=dev; addr=0x7734dd200000;
i! residuals: location=dev; addr=NULL_PTR;
i! kDesc: type=cudnnSeqDataDescriptor_t:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! dimA.time: type=int; val=32;
i! dimA.batch: type=int; val=64;
i! dimA.beam: type=int; val=1;
i! dimA.vect: type=int; val=256;
i! axes: type=cudnnSeqDataAxis_t; val=[CUDNN_SEQDATA_BATCH_DIM,CUDNN_SEQDATA_BEAM_DIM,CUDNN_SEQDATA_TIME_DIM,CUDNN_SEQDATA_VECT_DIM];
i! seqLengthArraySize: type=size_t; val=64;
i! seqLengthArray: type=int; val=[32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32];
i! paddingFill: type=double; val=0;
i! keys: location=dev; addr=0x7734dd400000;
i! vDesc: type=cudnnSeqDataDescriptor_t:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! dimA.time: type=int; val=32;
i! dimA.batch: type=int; val=64;
i! dimA.beam: type=int; val=1;
i! dimA.vect: type=int; val=256;
i! axes: type=cudnnSeqDataAxis_t; val=[CUDNN_SEQDATA_BATCH_DIM,CUDNN_SEQDATA_BEAM_DIM,CUDNN_SEQDATA_TIME_DIM,CUDNN_SEQDATA_VECT_DIM];
i! seqLengthArraySize: type=size_t; val=64;
i! seqLengthArray: type=int; val=[32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32];
i! paddingFill: type=double; val=0;
i! values: location=dev; addr=0x7734dd600000;
i! oDesc: type=cudnnSeqDataDescriptor_t:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! dimA.time: type=int; val=256;
i! dimA.batch: type=int; val=64;
i! dimA.beam: type=int; val=1;
i! dimA.vect: type=int; val=256;
i! axes: type=cudnnSeqDataAxis_t; val=[CUDNN_SEQDATA_BATCH_DIM,CUDNN_SEQDATA_BEAM_DIM,CUDNN_SEQDATA_TIME_DIM,CUDNN_SEQDATA_VECT_DIM];
i! seqLengthArraySize: type=size_t; val=64;
i! seqLengthArray: type=int; val=[256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256];
i! paddingFill: type=double; val=0;
i! out: location=dev; addr=0x7734bc400000;
i! weightSizeInBytes: type=size_t; val=0;
i! weights: location=dev; addr=NULL_PTR;
i! workSpaceSizeInBytes: type=size_t; val=33554432;
i! workSpace: location=dev; addr=0x7734a2000000;
i! reserveSpaceSizeInBytes: type=size_t; val=0;
i! reserveSpace: location=dev; addr=NULL_PTR;
i! Time: 2025-09-28T21:37:40.822081 (0d+0h+0m+2s since start)
i! Process=4007425; Thread=4007425; GPU=0; Handle=0x258f98a0; StreamId=0x17609670.
I! CuDNN (v90501 17) function cudnnSetAttnDescriptor() called:
i! attnMode: type=unsigned; val=CUDNN_ATTN_QUERYMAP_ONE_TO_ONE|CUDNN_ATTN_DISABLE_PROJ_BIASES (0x1);
i! nHeads: type=int; val=8;
i! smScaler: type=double; val=0.176777;
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! computePrec: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i! attnDropoutDesc: type=cudnnDropoutDescriptor_t; val=NULL_PTR;
i! postDropoutDesc: type=cudnnDropoutDescriptor_t; val=NULL_PTR;
i! qSize: type=int; val=256;
i! kSize: type=int; val=256;
i! vSize: type=int; val=256;
i! qProjSize: type=int; val=0;
i! kProjSize: type=int; val=0;
i! vProjSize: type=int; val=0;
i! oProjSize: type=int; val=0;
i! qoMaxSeqLength: type=int; val=256;
i! kvMaxSeqLength: type=int; val=32;
i! maxBatchSize: type=int; val=64;
i! maxBeamSize: type=int; val=1;
i! Time: 2025-09-28T21:37:40.821349 (0d+0h+0m+2s since start)
i! Process=4007425; Thread=4007425; GPU=NULL; Handle=NULL; StreamId=NULL.
I! CuDNN (v90501 17) function cudnnSetSeqDataDescriptor() called:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! nbDims: type=int; val=4;
i! dimA: type=int; val=[256,64,1,256];
i! : type=int; val=[1,2,0,3];
i! seqLengthArray: type=int; val=[256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256];
i! paddingFill: type=CUDNN_DATA_FLOAT; val=NULL_PTR;
i! Time: 2025-09-28T21:37:40.821496 (0d+0h+0m+2s since start)
i! Process=4007425; Thread=4007425; GPU=NULL; Handle=NULL; StreamId=NULL.I! CuDNN (v90501 17) function cudnnSetSeqDataDescriptor() called:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! nbDims: type=int; val=4;
i! dimA: type=int; val=[32,64,1,256];
i! : type=int; val=[1,2,0,3];
i! seqLengthArray: type=int; val=[32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32];
i! paddingFill: type=CUDNN_DATA_FLOAT; val=NULL_PTR;
i! Time: 2025-09-28T21:37:40.821593 (0d+0h+0m+2s since start)
i! Process=4007425; Thread=4007425; GPU=NULL; Handle=NULL; StreamId=NULL.
I would appreciate any guidance on which part of my API usage or parameter setup might be incorrect.
Thank you in advance for your help!