[torch-xla 2.6] Training performance regression in torch-xla 2.6 for medium/small models

🐛 Bug

This issue documents a fix that is already in top-of-tree torch-xla and back-ported to 2.7.

We are seeing roughly a 5% performance reduction for Llama 8B training and a 10% reduction for BERT training simply by moving to PT 2.6. The generated HLOs are identical to those from PT 2.5, so the overhead is suspected to come from tracing. The problem was narrowed down to commit 5ce8609 and fixed in #8976.
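For reference, a minimal sketch of how the "same HLO" observation can be checked; it uses a toy linear layer as a stand-in for the real workload and the private `torch_xla._XLAC._get_xla_tensors_hlo` helper to dump the pending graph's HLO text, which can then be diffed between the 2.5 and 2.6 installs (for a full run, the `XLA_SAVE_TENSORS_FILE` / `XLA_SAVE_TENSORS_FMT=hlo` environment variables serve the same purpose):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(256, 256).to(device)  # placeholder for the real model

x = torch.randn(8, 256, device=device)
loss = model(x).sum()
loss.backward()

# HLO text of the pending lazy graph for this step; write one file per
# torch-xla version and diff the two files.
hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([loss])
with open("step_hlo.txt", "w") as f:
    f.write(hlo_text)

xm.mark_step()  # dispatch the traced graph
```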

Since there is a new CVE affecting torch 2.5, it would be best to patch torch-xla 2.6 so that customers get the same performance as 2.5 without the CVE. Otherwise, customers would have to move to 2.7 for best performance, and Neuron has not yet completed its testing of 2.7.

To Reproduce

Steps to reproduce the behavior:

  1. Install the 2.5 and 2.6 software stacks (Neuron torch-neuronx + torch-xla + torch).
  2. Run https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/bert.html#hf-bert-pretraining-tutorial
  3. Compare performance (a minimal timing sketch follows these steps).
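As a quicker host-side comparison than the full BERT tutorial, here is a minimal timing sketch, assuming an XLA/Neuron device is available; the linear layer is only a placeholder for the real model. It measures steps per second and prints the torch-xla metrics report, which can be diffed between the two installs:

```python
import time

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
model = torch.nn.Linear(1024, 1024).to(device)  # stand-in for BERT/Llama
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
x = torch.randn(64, 1024, device=device)

def step():
    optimizer.zero_grad()
    model(x).sum().backward()
    optimizer.step()
    xm.mark_step()  # cut the lazy trace and dispatch execution

# Warm-up steps so compilation is excluded from the measurement.
for _ in range(3):
    step()
xm.wait_device_ops()

start = time.perf_counter()
num_steps = 50
for _ in range(num_steps):
    step()
xm.wait_device_ops()  # drain async execution before stopping the clock
print(f"{num_steps / (time.perf_counter() - start):.2f} steps/s")

# Since the HLOs match across versions, differences between the two runs'
# metrics reports should show up on the host/tracing side.
print(met.metrics_report())
```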

Expected behavior

Performance on par with torch-xla 2.5 for small and medium models.

Environment

Additional context