Issue with cudnnMultiHeadAttnForward() – CUDNN_STATUS_BAD_PARAM (original) (raw)

Hello,

I am currently implementing a wrapper function to call the cudnnMultiHeadAttnForward() API provided by cuDNN. However, after extensive testing, I consistently encounter a parameter error:

cudnnMultiHeadAttnForward failed: CUDNN_STATUS_BAD_PARAM

I am unsure whether the issue comes from my API usage or from how I am passing the parameters.

Below is the implementation of my wrapper function:

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCudnnMultiHeadAttention(
    int num_heads,
    int embed_dim,
    int max_seq_len_q,
    int max_seq_len_kv,
    int batch_size,
    int curr_idx,
    const int* lo_win_idx,                 
    const int* hi_win_idx,                  
    const void* dev_seq_lengths_qo,        
    const void* dev_seq_lengths_kv,        
    const void* queries,
    const void* residuals,                 
    const void* keys,
    const void* values,
    void* output,
    size_t weight_size_bytes,
    const void* weights,
    size_t workspace_size_bytes,            
    CUstream stream
) {
    mgpuEnsureContext();
    
    StreamHandles handles;
    if (!getHandlesForStream(stream, handles)) {
        fprintf(stderr, "[MHA] ERROR: Failed to get handles for stream %p\n", stream);
        return;
    }
    cudnnHandle_t handle = handles.cudnn_handle;
    
    if (num_heads <= 0 || embed_dim <= 0 || max_seq_len_q <= 0 || max_seq_len_kv <= 0 || batch_size <= 0) {
        fprintf(stderr, "[MHA] ERROR: Invalid dimensions: heads=%d, embed=%d, seq_q=%d, seq_kv=%d, batch=%d\n", 
                num_heads, embed_dim, max_seq_len_q, max_seq_len_kv, batch_size);
        return;
    }
    
    if (!queries || !keys || !values || !output) {
        fprintf(stderr, "[MHA] ERROR: One or more data pointers are NULL\n");
        return;
    }
    
    if (!weights && weight_size_bytes > 0) {
        fprintf(stderr, "[MHA] ERROR: Weights pointer is NULL but weight size > 0\n");
        return;
    }
    
    if (embed_dim % num_heads != 0) {
        fprintf(stderr, "[MHA] ERROR: embed_dim (%d) must be divisible by num_heads (%d)\n", 
                embed_dim, num_heads);
        return;
    }
    
    cudnnAttnDescriptor_t attn_desc = nullptr;
    cudnnStatus_t status = cudnnCreateAttnDescriptor(&attn_desc);
    if (status != CUDNN_STATUS_SUCCESS) {
        fprintf(stderr, "[MHA] ERROR: Failed to create attention descriptor: %s\n", 
                cudnnGetErrorString(status));
        return;
    }
    
    int head_dim = embed_dim / num_heads;
    double sm_scaler = 1.0f / sqrtf((double)head_dim);

    unsigned attnMode = 0;
    attnMode |= CUDNN_ATTN_DISABLE_PROJ_BIASES;  
    attnMode |= CUDNN_ATTN_QUERYMAP_ALL_TO_ONE;  
    // attnMode |= CUDNN_ATTN_QUERYMAP_ONE_TO_ONE;
    
    status = cudnnSetAttnDescriptor(
        attn_desc,
        attnMode,        // attnMode
        num_heads,                   // nHeads
        sm_scaler,                   // smScaler
        CUDNN_DATA_FLOAT,           // dataType
        CUDNN_DATA_FLOAT,           // computePrec
        CUDNN_DEFAULT_MATH,         // mathType
        nullptr,                    // attnDropoutDesc
        nullptr,                    // postDropoutDesc
        embed_dim,                  // qSize
        embed_dim,                  // kSize  
        embed_dim,                  // vSize
        0,                  // qProjSize
        0,                  // kProjSize
        0,                  // vProjSize
        0,                  // oProjSize
        max_seq_len_q,              // qoMaxSeqLength
        max_seq_len_kv,             // kvMaxSeqLength
        batch_size,                 // maxBatchSize
        1                           
    );
    
    if (status != CUDNN_STATUS_SUCCESS) {
        fprintf(stderr, "[MHA] ERROR: Failed to set attention descriptor: %s\n", 
                cudnnGetErrorString(status));
        cudnnDestroyAttnDescriptor(attn_desc);
        return;
    }
    
    cudnnSeqDataDescriptor_t q_desc = nullptr, k_desc = nullptr;
    cudnnSeqDataDescriptor_t v_desc = nullptr, o_desc = nullptr;
    
    status = cudnnCreateSeqDataDescriptor(&q_desc);
    if (status != CUDNN_STATUS_SUCCESS) {
        fprintf(stderr, "[MHA] ERROR: Failed to create Q descriptor: %s\n", cudnnGetErrorString(status));
        cudnnDestroyAttnDescriptor(attn_desc);
        return;
    }
    
    int q_dims[CUDNN_SEQDATA_DIM_COUNT];
    q_dims[CUDNN_SEQDATA_BATCH_DIM] = batch_size;
    q_dims[CUDNN_SEQDATA_TIME_DIM] = max_seq_len_q;
    q_dims[CUDNN_SEQDATA_BEAM_DIM] = 1;              
    q_dims[CUDNN_SEQDATA_VECT_DIM] = embed_dim;

    std::vector<int> q_seq_lengths(batch_size, max_seq_len_q);
   
    cudnnSeqDataAxis_t q_axes[CUDNN_SEQDATA_DIM_COUNT];
    q_axes[0] = CUDNN_SEQDATA_BATCH_DIM;           
    q_axes[1] = CUDNN_SEQDATA_BEAM_DIM;             
    q_axes[2] = CUDNN_SEQDATA_TIME_DIM;            
    q_axes[3] = CUDNN_SEQDATA_VECT_DIM;             

    status = cudnnSetSeqDataDescriptor(
        q_desc, 
        CUDNN_DATA_FLOAT, 
        4,                                          
        q_dims, 
        q_axes,                                     
        batch_size,                                 // seqLengthArraySize = batch_size * beam_size
        q_seq_lengths.data(),                       // seqLengthArray
        nullptr                                     // paddingFill
    );
    if (status != CUDNN_STATUS_SUCCESS) {
        fprintf(stderr, "[MHA] ERROR: Failed to set Q descriptor: %s\n", cudnnGetErrorString(status));
        cudnnDestroySeqDataDescriptor(q_desc);
        cudnnDestroyAttnDescriptor(attn_desc);
        return;
    }
    
    status = cudnnCreateSeqDataDescriptor(&k_desc);
    if (status != CUDNN_STATUS_SUCCESS) {
        fprintf(stderr, "[MHA] ERROR: Failed to create K descriptor: %s\n", cudnnGetErrorString(status));
        cudnnDestroySeqDataDescriptor(q_desc);
        cudnnDestroyAttnDescriptor(attn_desc);
        return;
    }
    
    int k_dims[CUDNN_SEQDATA_DIM_COUNT];
    k_dims[CUDNN_SEQDATA_BATCH_DIM] = batch_size;
    k_dims[CUDNN_SEQDATA_TIME_DIM] = max_seq_len_kv;
    k_dims[CUDNN_SEQDATA_BEAM_DIM] = 1;             
    k_dims[CUDNN_SEQDATA_VECT_DIM] = embed_dim;

    std::vector<int> kv_seq_lengths(batch_size, max_seq_len_kv);
    
    cudnnSeqDataAxis_t kv_axes[CUDNN_SEQDATA_DIM_COUNT];
    kv_axes[0] = CUDNN_SEQDATA_BATCH_DIM;           
    kv_axes[1] = CUDNN_SEQDATA_BEAM_DIM;            
    kv_axes[2] = CUDNN_SEQDATA_TIME_DIM;           
    kv_axes[3] = CUDNN_SEQDATA_VECT_DIM;            

    status = cudnnSetSeqDataDescriptor(
        k_desc, 
        CUDNN_DATA_FLOAT, 
        4,                                          
        k_dims, 
        kv_axes,                                    
        batch_size,                                 // seqLengthArraySize = batch_size * beam_size
        kv_seq_lengths.data(),                      // seqLengthArray
        nullptr                                     // paddingFill
    );
    if (status != CUDNN_STATUS_SUCCESS) {
        fprintf(stderr, "[MHA] ERROR: Failed to set K descriptor: %s\n", cudnnGetErrorString(status));
        cudnnDestroySeqDataDescriptor(k_desc);
        cudnnDestroySeqDataDescriptor(q_desc);
        cudnnDestroyAttnDescriptor(attn_desc);
        return;
    }
    
    status = cudnnCreateSeqDataDescriptor(&v_desc);
    if (status != CUDNN_STATUS_SUCCESS) {
        fprintf(stderr, "[MHA] ERROR: Failed to create V descriptor: %s\n", cudnnGetErrorString(status));
        cudnnDestroySeqDataDescriptor(k_desc);
        cudnnDestroySeqDataDescriptor(q_desc);
        cudnnDestroyAttnDescriptor(attn_desc);
        return;
    }
    
    status = cudnnSetSeqDataDescriptor(
        v_desc, 
        CUDNN_DATA_FLOAT, 
        4,                                          
        k_dims,                                     
        kv_axes,                                    
        batch_size,                                 
        kv_seq_lengths.data(),                    
        nullptr                                     // paddingFill
    );
    if (status != CUDNN_STATUS_SUCCESS) {
        fprintf(stderr, "[MHA] ERROR: Failed to set V descriptor: %s\n", cudnnGetErrorString(status));
        cudnnDestroySeqDataDescriptor(v_desc);
        cudnnDestroySeqDataDescriptor(k_desc);
        cudnnDestroySeqDataDescriptor(q_desc);
        cudnnDestroyAttnDescriptor(attn_desc);
        return;
    }
    
    status = cudnnCreateSeqDataDescriptor(&o_desc);
    if (status != CUDNN_STATUS_SUCCESS) {
        fprintf(stderr, "[MHA] ERROR: Failed to create O descriptor: %s\n", cudnnGetErrorString(status));
        cudnnDestroySeqDataDescriptor(v_desc);
        cudnnDestroySeqDataDescriptor(k_desc);
        cudnnDestroySeqDataDescriptor(q_desc);
        cudnnDestroyAttnDescriptor(attn_desc);
        return;
    }
    
    status = cudnnSetSeqDataDescriptor(
        o_desc, 
        CUDNN_DATA_FLOAT, 
        4,                                         
        q_dims,                                    
        q_axes,                                    
        batch_size,                                 
        q_seq_lengths.data(),                       
        nullptr                                    
    );
    if (status != CUDNN_STATUS_SUCCESS) {
        fprintf(stderr, "[MHA] ERROR: Failed to set O descriptor: %s\n", cudnnGetErrorString(status));
        cudnnDestroySeqDataDescriptor(o_desc);
        cudnnDestroySeqDataDescriptor(v_desc);
        cudnnDestroySeqDataDescriptor(k_desc);
        cudnnDestroySeqDataDescriptor(q_desc);
        cudnnDestroyAttnDescriptor(attn_desc);
        return;
    }
    
    void* workspace = nullptr;
    bool using_workspace_pool = false;
    bool workspace_allocated = false;
    size_t actual_workspace_size = workspace_size_bytes;
    
    if (workspace_size_bytes == 0) {
        size_t estimated_size = (size_t)batch_size * max_seq_len_q * max_seq_len_kv * num_heads * sizeof(float);
        actual_workspace_size = std::max(estimated_size * 2, (size_t)(32 * 1024 * 1024));
        
        // fprintf(stderr, "[MHA] Using estimated workspace size: %.2f MB\n", 
        //         actual_workspace_size / (1024.0 * 1024.0));
    }
    
    if (actual_workspace_size > 0) {
        workspace = acquirePooledWorkspace(actual_workspace_size, stream, TENSOR_CORE_ALIGNMENT);
        
        if (workspace != nullptr) {
            using_workspace_pool = true;
        } else {
            CUdeviceptr workspace_ptr = allocateAlignedMemory(actual_workspace_size, TENSOR_CORE_ALIGNMENT);
            if (workspace_ptr != 0) {
                workspace = reinterpret_cast<void*>(workspace_ptr);
                workspace_allocated = true;
                fprintf(stderr, "[MHA] Using dynamic workspace (%.2f MB)\n", 
                        actual_workspace_size / (1024.0 * 1024.0));
            } else {
                fprintf(stderr, "[MHA] ERROR: Failed to allocate workspace of %.2f MB\n", 
                        actual_workspace_size / (1024.0 * 1024.0));
                cudnnDestroySeqDataDescriptor(o_desc);
                cudnnDestroySeqDataDescriptor(v_desc);
                cudnnDestroySeqDataDescriptor(k_desc);
                cudnnDestroySeqDataDescriptor(q_desc);
                cudnnDestroyAttnDescriptor(attn_desc);
                return;
            }
        }
    }
    
    const int* dev_seq_qo = static_cast<const int*>(dev_seq_lengths_qo);
    const int* dev_seq_kv = static_cast<const int*>(dev_seq_lengths_kv);
    
    status = cudnnMultiHeadAttnForward(
        handle,
        attn_desc,
        curr_idx,
        lo_win_idx,               
        hi_win_idx,             
        dev_seq_qo,               
        dev_seq_kv,              
        q_desc,
        queries,
        residuals,               
        k_desc,
        keys,
        v_desc,
        values,
        o_desc,
        output,
        weight_size_bytes,
        weights,
        actual_workspace_size,
        workspace,
        0,                      
        nullptr                   
    );
    
    if (status != CUDNN_STATUS_SUCCESS) {
        fprintf(stderr, "[MHA] ERROR: cudnnMultiHeadAttnForward failed: %s\n", 
                cudnnGetErrorString(status));
    }
    
    if (o_desc) cudnnDestroySeqDataDescriptor(o_desc);
    if (v_desc) cudnnDestroySeqDataDescriptor(v_desc);
    if (k_desc) cudnnDestroySeqDataDescriptor(k_desc);
    if (q_desc) cudnnDestroySeqDataDescriptor(q_desc);
    if (attn_desc) cudnnDestroyAttnDescriptor(attn_desc);
    
    if (workspace_allocated && workspace) {
        CUresult result = cuMemFree(reinterpret_cast<CUdeviceptr>(workspace));
        if (result != CUDA_SUCCESS) {
            fprintf(stderr, "[MHA] WARNING: Failed to free workspace\n");
        }
    }
}

During actual testing, the debug information I receive is as follows:

I! CuDNN (v90501 17) function cudnnMultiHeadAttnForward() called:
i! handle: type=cudnnHandle_t; streamId=0x17609670;
i! attnDesc: type=cudnnAttnDescriptor_t:
i! attnMode: type=unsigned; val=CUDNN_ATTN_QUERYMAP_ONE_TO_ONE|CUDNN_ATTN_DISABLE_PROJ_BIASES (0x1);
i! nHeads: type=int; val=8;
i! smScaler: type=double; val=0.176777;
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! mathPrec: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i! attnDropoutDesc: type=cudnnDropoutDescriptor_t; val=NULL_PTR;
i! postDropoutDesc: type=cudnnDropoutDescriptor_t; val=NULL_PTR;
i! qSize: type=int; val=256;
i! kSize: type=int; val=256;
i! vSize: type=int; val=256;
i! qProjSize: type=int; val=0;
i! kProjSize: type=int; val=0;
i! vProjSize: type=int; val=0;
i! oProjSize: type=int; val=0;
i! qoMaxSeqLength: type=int; val=256;
i! kvMaxSeqLength: type=int; val=32;
i! maxBatchSize: type=int; val=64;
i! maxBeamSize: type=int; val=1;
i! currIdx: type=int; val=-1;
i! loWinIdx: location=host; addr=0x7ffd4a348ad8;
i! hiWinIdx: location=host; addr=0x7ffd4a348ed8;
i! devSeqLengthsQO: location=dev; addr=0x7734c7000400;
i! devSeqLengthsKV: location=dev; addr=0x7734c7000600;
i! qDesc: type=cudnnSeqDataDescriptor_t:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! dimA.time: type=int; val=256;
i! dimA.batch: type=int; val=64;
i! dimA.beam: type=int; val=1;
i! dimA.vect: type=int; val=256;
i! axes: type=cudnnSeqDataAxis_t; val=[CUDNN_SEQDATA_BATCH_DIM,CUDNN_SEQDATA_BEAM_DIM,CUDNN_SEQDATA_TIME_DIM,CUDNN_SEQDATA_VECT_DIM];
i! seqLengthArraySize: type=size_t; val=64;
i! seqLengthArray: type=int; val=[256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256];
i! paddingFill: type=double; val=0;
i! queries: location=dev; addr=0x7734dd200000;
i! residuals: location=dev; addr=NULL_PTR;
i! kDesc: type=cudnnSeqDataDescriptor_t:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! dimA.time: type=int; val=32;
i! dimA.batch: type=int; val=64;
i! dimA.beam: type=int; val=1;
i! dimA.vect: type=int; val=256;
i! axes: type=cudnnSeqDataAxis_t; val=[CUDNN_SEQDATA_BATCH_DIM,CUDNN_SEQDATA_BEAM_DIM,CUDNN_SEQDATA_TIME_DIM,CUDNN_SEQDATA_VECT_DIM];
i! seqLengthArraySize: type=size_t; val=64;
i! seqLengthArray: type=int; val=[32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32];
i! paddingFill: type=double; val=0;
i! keys: location=dev; addr=0x7734dd400000;
i! vDesc: type=cudnnSeqDataDescriptor_t:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! dimA.time: type=int; val=32;
i! dimA.batch: type=int; val=64;
i! dimA.beam: type=int; val=1;
i! dimA.vect: type=int; val=256;
i! axes: type=cudnnSeqDataAxis_t; val=[CUDNN_SEQDATA_BATCH_DIM,CUDNN_SEQDATA_BEAM_DIM,CUDNN_SEQDATA_TIME_DIM,CUDNN_SEQDATA_VECT_DIM];
i! seqLengthArraySize: type=size_t; val=64;
i! seqLengthArray: type=int; val=[32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32];
i! paddingFill: type=double; val=0;
i! values: location=dev; addr=0x7734dd600000;
i! oDesc: type=cudnnSeqDataDescriptor_t:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! dimA.time: type=int; val=256;
i! dimA.batch: type=int; val=64;
i! dimA.beam: type=int; val=1;
i! dimA.vect: type=int; val=256;
i! axes: type=cudnnSeqDataAxis_t; val=[CUDNN_SEQDATA_BATCH_DIM,CUDNN_SEQDATA_BEAM_DIM,CUDNN_SEQDATA_TIME_DIM,CUDNN_SEQDATA_VECT_DIM];
i! seqLengthArraySize: type=size_t; val=64;
i! seqLengthArray: type=int; val=[256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256];
i! paddingFill: type=double; val=0;
i! out: location=dev; addr=0x7734bc400000;
i! weightSizeInBytes: type=size_t; val=0;
i! weights: location=dev; addr=NULL_PTR;
i! workSpaceSizeInBytes: type=size_t; val=33554432;
i! workSpace: location=dev; addr=0x7734a2000000;
i! reserveSpaceSizeInBytes: type=size_t; val=0;
i! reserveSpace: location=dev; addr=NULL_PTR;
i! Time: 2025-09-28T21:37:40.822081 (0d+0h+0m+2s since start)
i! Process=4007425; Thread=4007425; GPU=0; Handle=0x258f98a0; StreamId=0x17609670.

I! CuDNN (v90501 17) function cudnnSetAttnDescriptor() called:
i! attnMode: type=unsigned; val=CUDNN_ATTN_QUERYMAP_ONE_TO_ONE|CUDNN_ATTN_DISABLE_PROJ_BIASES (0x1);
i! nHeads: type=int; val=8;
i! smScaler: type=double; val=0.176777;
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! computePrec: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i! attnDropoutDesc: type=cudnnDropoutDescriptor_t; val=NULL_PTR;
i! postDropoutDesc: type=cudnnDropoutDescriptor_t; val=NULL_PTR;
i! qSize: type=int; val=256;
i! kSize: type=int; val=256;
i! vSize: type=int; val=256;
i! qProjSize: type=int; val=0;
i! kProjSize: type=int; val=0;
i! vProjSize: type=int; val=0;
i! oProjSize: type=int; val=0;
i! qoMaxSeqLength: type=int; val=256;
i! kvMaxSeqLength: type=int; val=32;
i! maxBatchSize: type=int; val=64;
i! maxBeamSize: type=int; val=1;
i! Time: 2025-09-28T21:37:40.821349 (0d+0h+0m+2s since start)
i! Process=4007425; Thread=4007425; GPU=NULL; Handle=NULL; StreamId=NULL.

I! CuDNN (v90501 17) function cudnnSetSeqDataDescriptor() called:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! nbDims: type=int; val=4;
i! dimA: type=int; val=[256,64,1,256];
i! : type=int; val=[1,2,0,3];
i! seqLengthArray: type=int; val=[256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256,256];
i! paddingFill: type=CUDNN_DATA_FLOAT; val=NULL_PTR;
i! Time: 2025-09-28T21:37:40.821496 (0d+0h+0m+2s since start)
i! Process=4007425; Thread=4007425; GPU=NULL; Handle=NULL; StreamId=NULL.

I! CuDNN (v90501 17) function cudnnSetSeqDataDescriptor() called:
i! dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i! nbDims: type=int; val=4;
i! dimA: type=int; val=[32,64,1,256];
i! : type=int; val=[1,2,0,3];
i! seqLengthArray: type=int; val=[32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32,32];
i! paddingFill: type=CUDNN_DATA_FLOAT; val=NULL_PTR;
i! Time: 2025-09-28T21:37:40.821593 (0d+0h+0m+2s since start)
i! Process=4007425; Thread=4007425; GPU=NULL; Handle=NULL; StreamId=NULL.

I would appreciate any guidance on which part of my API usage or parameter setup might be incorrect.

Thank you in advance for your help!