Llama 2 model divergence with FSDP · Issue #28826 · huggingface/transformers
System Info
- transformers version: 4.37.1
- Platform: Linux-5.10.199-190.747.amzn2.x86_64-x86_64-with-glibc2.31
- Python version: 3.10.8
- Huggingface_hub version: 0.20.2
- Safetensors version: 0.3.3
- Accelerate version: 0.26.1
- Accelerate config: not found
- PyTorch version (GPU?): 2.1.2 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: yes
Who can help?
When fine-tuning a Llama 2 model with HF Transformers 4.37 and PyTorch FSDP, we observed model divergence compared to HF 4.31. Fine-tuning with 4.31 works fine, but with 4.37 the loss consistently rises instead of stabilizing when attn_implementation="flash_attention_2" is set, while attn_implementation="sdpa" works fine.
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
The model is initialized as `model = AutoModelForCausalLM.from_pretrained(pretrained_model_weights, attn_implementation="flash_attention_2")`.
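For reference, a minimal sketch of the two configurations being compared. The checkpoint name and dtype below are placeholders (the report only gives `pretrained_model_weights`); running this requires a GPU with the flash-attn package installed:

```python
import torch
from transformers import AutoModelForCausalLM

# Configuration that diverges under FSDP with Transformers 4.37 (per this report):
model_fa2 = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",               # placeholder checkpoint
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,               # flash-attn requires fp16/bf16 weights
)

# Configuration that trains normally (per this report):
model_sdpa = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",               # placeholder checkpoint
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
)
```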
Expected behavior
The loss should not rise as training progresses.