Can't finetune stable diffusion with --enable_xformers_memory_efficient_attention · Issue #2234 · huggingface/diffusers

Describe the bug

I'm trying to finetune stable diffusion, and I'm trying to reduce the memory footprint so I can train with a larger batch size (and thus fewer gradient accumulation steps, and thus faster).

Setting --enable_xformers_memory_efficient_attention seems to cause numeric instability of some kind: the safety_checker tripped (training on the Pokemon dataset, validation prompt "Yoda"). If I disable the safety_checker, I get black images anyway, along with this warning:

/home/lucas/.local/lib/python3.8/site-packages/diffusers/pipelines/pipeline_utils.py:813: RuntimeWarning: invalid value encountered in cast
  images = (images * 255).round().astype("uint8")
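
This warning is what NumPy emits when a float array containing NaN or Inf is cast to an integer type, so it points at non-finite values coming out of the model rather than at the conversion itself. Below is a minimal sketch of a guard around the same conversion; `to_uint8` is a hypothetical helper, not part of diffusers, and the sample arrays are made up for illustration. The result of casting NaN to uint8 is not guaranteed, but on many platforms it comes out as 0, which would be consistent with the black images.

```python
import numpy as np


def to_uint8(images: np.ndarray) -> np.ndarray:
    """Convert float images in [0, 1] to uint8, failing loudly on NaN/Inf.

    Hypothetical helper: it wraps the same expression as the line in
    pipeline_utils.py, but checks for non-finite values first.
    """
    finite = np.isfinite(images)
    if not finite.all():
        n_bad = int((~finite).sum())
        raise ValueError(f"{n_bad} non-finite values found before uint8 cast")
    return (images * 255).round().astype("uint8")


# Made-up sample data: one clean batch, one with a single NaN injected.
good = np.ones((1, 2, 2, 3), dtype=np.float32)
bad = good.copy()
bad[0, 0, 0, 0] = np.nan
```

Calling `to_uint8(good)` returns a uint8 array, while `to_uint8(bad)` raises instead of silently producing an undefined pixel value.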

If I instead set --enable_xformers_memory_efficient_attention, but disable --gradient_checkpointing, everything hums along nicely, but the model doesn't actually fine-tune.

I also tried forcing xformers to use Flash Attention (using the snippet in #2049), because #1997 suggested there were issues with the other xformers attention kernels. That gives me this error:

ValueError: Operator `memory_efficient_attention` does not support inputs:
     query       : shape=(8, 256, 1, 160) (torch.float16)
     key         : shape=(8, 256, 1, 160) (torch.float16)
     value       : shape=(8, 256, 1, 160) (torch.float16)
     attn_bias   : <class 'NoneType'>
     p           : 0.0
`flshattF` is not supported because:
    max(query.shape[-1] != value.shape[-1]) > 128
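
Reading the last dimension of each shape as the per-head embedding size, the message says the Flash Attention kernel was rejected because 160 exceeds 128. The sketch below restates that check in plain Python. The limit of 128 is taken only from the error text above and may differ between xformers versions, and `flash_head_dim_ok` is a hypothetical name, not an xformers function.

```python
# Limit copied from the error message above; an assumption, not a documented constant.
FLASH_MAX_HEAD_DIM = 128


def flash_head_dim_ok(query_shape, value_shape, limit=FLASH_MAX_HEAD_DIM):
    """Return True if the last dimension of query and value fits within the limit."""
    return max(query_shape[-1], value_shape[-1]) <= limit


# Shape reported in the error: the last dimension is 160.
reported = (8, 256, 1, 160)
print(flash_head_dim_ok(reported, reported))
```

With the reported shape the check fails, so forcing this kernel cannot work for that attention layer regardless of the other settings.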

Reproduction

Here's the command I ran with --enable_xformers_memory_efficient_attention, but not with --gradient_checkpointing:

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=8 \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-pokemon-model" \
  --validation_prompt=Yoda --num_validation_images=8 --validation_steps=1000 \
  --enable_xformers_memory_efficient_attention
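
For context on the batch-size trade-off behind this command: the effective batch per optimizer step is the per-device batch size times the accumulation steps times the number of processes. The sketch below works through the numbers from the command; `num_processes=1` is an assumption (a single-GPU run), since the process count is not stated in this report.

```python
def total_batch_size(train_batch_size, gradient_accumulation_steps, num_processes=1):
    """Effective number of samples contributing to each optimizer step."""
    return train_batch_size * gradient_accumulation_steps * num_processes


# Values from the command above (num_processes=1 is assumed).
current = total_batch_size(train_batch_size=1, gradient_accumulation_steps=8)

# Same effective batch if memory allowed all 8 samples in one forward/backward pass.
target = total_batch_size(train_batch_size=8, gradient_accumulation_steps=1)

print(current, target)
```

Both configurations give the same effective batch, but the second needs one forward/backward pass per optimizer step instead of eight, which is the speedup the memory savings are meant to unlock.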

I'm running with #2157, because that gives me images to see how training is progressing (which is how I noticed it wasn't finetuning), but I've observed the same problem at HEAD.

Logs

No response

System Info