Fix compilation bug with CUDA 12.1 by Edenzzzz · Pull Request #949 · NVIDIA/TransformerEngine (original) (raw)

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

Changes

This has been mentioned in #560 but somehow someone just changed it back... Importing <cuda_fp8.h>, which imports <cuda_bf16.h>, after defining nv_bfloat16 triggers re-declaration error.

Checklist: