Llama 2 model divergence with FSDP · Issue #28826 · huggingface/transformers

System Info

Who can help?

When fine-tuning a Llama 2 model with transformers 4.37 and PyTorch FSDP, we observed model divergence compared to transformers 4.31. Fine-tuning with 4.31 works fine, but with 4.37 the loss consistently rises instead of stabilizing when attn_implementation="flash_attention_2" is set, whereas attn_implementation="sdpa" works fine.

Information

Tasks

Reproduction

The model is initialized as:
model = AutoModelForCausalLM.from_pretrained(pretrained_model_weights, attn_implementation="flash_attention_2")
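For context, a minimal sketch of the kind of FSDP fine-tuning loop that exposes the difference is shown below. The model path, wrap policy, batch, and hyperparameters are placeholders I am assuming for illustration, not the reporter's actual setup; swapping attn_implementation between "flash_attention_2" and "sdpa" is the comparison described above.

```python
# Minimal FSDP fine-tuning sketch (assumed setup, not the reporter's script).
import functools

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

pretrained_model_weights = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint

# Swap "flash_attention_2" for "sdpa" to compare the two loss curves.
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_weights,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

# Shard each decoder layer, a common FSDP wrapping choice for Llama 2.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)
model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    device_id=torch.cuda.current_device(),
)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_weights)
tokenizer.pad_token = tokenizer.eos_token
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Toy training loop: per the report, with 4.37 + flash_attention_2 the loss
# rises over steps, while 4.31 or attn_implementation="sdpa" stabilizes.
batch = tokenizer(
    ["Hello world"] * 4, return_tensors="pt", padding=True
).to(torch.cuda.current_device())
for step in range(100):
    out = model(**batch, labels=batch["input_ids"])
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if dist.get_rank() == 0:
        print(step, out.loss.item())
```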

Expected behavior

The loss should not rise as training progresses.