[torch-xla 2.1 - 2.4] when functionalization is on, there are no aliasing for gradients when using gradient accumulation · Issue #7174 · pytorch/xla
🐛 Bug
When functionalization is on (XLA_DISABLE_FUNCTIONALIZATION=0), I see fewer aliased tensors than with functionalization off. Jack has a patch to increase the number of aliased tensors e3fc033. However, even though this change helps increase the number of aliased tensors, aliasing still appears to be missing for gradients when gradient accumulation is used.
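For reference, a minimal sketch (mine, not from the script) of checking the flag from Python; the variable name is the one mentioned above, and I am assuming "on" corresponds to the value 0 or the variable being unset:

```python
import os

# Sketch only: functionalization is controlled by the XLA_DISABLE_FUNCTIONALIZATION
# environment variable; "0" (or unset) leaves it on, "1" disables it.
functionalization_on = os.environ.get("XLA_DISABLE_FUNCTIONALIZATION", "0") != "1"
print(f"functionalization enabled: {functionalization_on}")
```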
Using test_train_mp_mnist.py, make the modifications below; I added a mark_step() call to isolate the gradient accumulation loop (a standalone sketch of the same pattern follows the diff).
@@ -158,16 +163,19 @@ def train_mnist(flags, **kwargs):
output = model(data)
loss = loss_fn(output, target)
loss.backward()
- if flags.ddp:
- optimizer.step()
- else:
- xm.optimizer_step(optimizer)
- tracker.add(flags.batch_size)
- if step % flags.log_steps == 0:
- xm.add_step_closure(
- _train_update,
- args=(device, step, loss, tracker, epoch, writer),
- run_async=flags.async_closures)
+
+ if step % 4 == 0:
+ xm.mark_step()
+ if flags.ddp:
+ optimizer.step()
+ else:
+ xm.optimizer_step(optimizer)
+ tracker.add(flags.batch_size)
+ if step % flags.log_steps == 0:
+ xm.add_step_closure(
+ _train_update,
+ args=(device, step, loss, tracker, epoch, writer),
+ run_async=flags.async_closures)
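For clarity, here is a standalone sketch of the same accumulation pattern outside the diff. The model, loss, and data below are placeholders I made up; only the mark_step()/optimizer_step() placement mirrors the modification above:

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(784, 10).to(device)      # placeholder model, not the MNIST net
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# synthetic data standing in for the MNIST loader
loader = [(torch.randn(8, 784), torch.randint(0, 10, (8,))) for _ in range(16)]

accum_steps = 4
for step, (data, target) in enumerate(loader):
    data, target = data.to(device), target.to(device)
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()                        # gradients accumulate across micro-batches

    if step % accum_steps == 0:
        xm.mark_step()                     # cut the graph at the accumulation boundary;
                                           # the gradient buffers stay live across this step
                                           # and are the tensors we expect to be aliased
        xm.optimizer_step(optimizer)
        optimizer.zero_grad()
```

The expectation is that each parameter's .grad tensor survives the mark_step() and therefore shows up as an aliased input/output pair in the compiled graph.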
I only see 2 aliases even though we expect all the gradient tensors to be aliased:
2024-06-03 21:15:37.676472: I torch_xla/csrc/xla_graph_executor.cpp:1462] Parameter sequence graph hash b8e15ed0391b82171706a34d84ca8ea0
2024-06-03 21:15:37.678822: I torch_xla/csrc/xla_graph_executor.cpp:1299] Aliased paramter 13 with output 4: s64[]
2024-06-03 21:15:37.678862: I torch_xla/csrc/xla_graph_executor.cpp:1299] Aliased paramter 14 with output 5: s64[]
2024-06-03 21:15:37.679222: I torch_xla/csrc/xla_graph_executor.cpp:1397] Compiling IR graph hash b8e15ed0391b82171706a34d84ca8ea0 on device CPU:0 ...
To Reproduce
Steps to reproduce the behavior:
- Check out r2.1_aws_neuron branch
- Apply a patch from Jack e3fc033
- Build/install as described in the CONTRIBUTING doc
- Go into xla/test
- Edit test_train_mp_mnist.py and add the gradient accumulation loop as shown above.
- Run with TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE="xla_graph_executor=6,pjrt_computation_client=5" to see the aliasing debug logs (a small helper to count the aliased parameters follows the command):
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="hlo" XLA_SAVE_TENSORS_FILE="/tmp/save1.hlo" TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE="xla_graph_executor=6,pjrt_computation_client=5" python test_train_mp_mnist.py |& tee log
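To make counting easier, a small hedged helper that scans the captured log (named log, as in the command above) for the aliasing lines quoted earlier:

```python
# Counts the "Aliased paramter ..." lines emitted by xla_graph_executor
# (spelling as it appears in the log), assuming output was teed to ./log.
with open("log") as f:
    alias_lines = [line.rstrip() for line in f if "Aliased paramter" in line]

print(f"{len(alias_lines)} aliased parameters")
for line in alias_lines:
    print(line)
```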
Expected behavior
Expect gradients to be aliased
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: CPU
- torch_xla version: 2.1