[torch-xla 2.1 - 2.4] When functionalization is on, there is no aliasing for gradients when using gradient accumulation · Issue #7174 · pytorch/xla

🐛 Bug

When functionalization is on (XLA_DISABLE_FUNCTIONALIZATION=0), I see fewer aliased tensors than with it off. Jack has a patch, e3fc033, that increases the number of aliased tensors. However, even with this change, aliasing still appears to be missing for gradients when gradient accumulation is used.
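
Input-output aliasing means a compiled graph writes one of its outputs into the buffer of one of its inputs (the input is "donated") instead of allocating a new buffer. The pure-Python model below is a hypothetical illustration of that idea, not torch_xla's implementation; Buffer and run_graph are invented names. It shows why an accumulated gradient, which enters the graph as a parameter and leaves as the updated output, is a natural aliasing candidate: without aliasing, both the old and new gradient buffers are alive at once.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Buffer:
    """Hypothetical stand-in for a device buffer holding a flat list of floats."""
    data: List[float]


def run_graph(inputs: Dict[str, Buffer], updates: Dict[str, List[float]],
              alias: bool) -> Dict[str, Buffer]:
    """Element-wise add `updates` to the named inputs, like grad += new_grad.

    With alias=True the result is written into the input's own buffer
    (the input is donated); otherwise a fresh output buffer is allocated
    and the input buffer is left untouched.
    """
    outputs: Dict[str, Buffer] = {}
    for name, delta in updates.items():
        src = inputs[name]
        result = [a + b for a, b in zip(src.data, delta)]
        if alias:
            src.data[:] = result             # reuse the donated input buffer
            outputs[name] = src
        else:
            outputs[name] = Buffer(result)   # new allocation; input still alive
    return outputs
```

For example, run_graph({"w.grad": g}, {"w.grad": [1.0, 2.0]}, alias=True) returns g itself as the "w.grad" output, which is the behavior the issue expects for every gradient.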

Starting from test_train_mp_mnist.py, apply the modifications below. I added an xm.mark_step() call to isolate the gradient accumulation loop.

@@ -158,16 +163,19 @@ def train_mnist(flags, **kwargs):
       output = model(data)
       loss = loss_fn(output, target)
       loss.backward()
-      if flags.ddp:
-        optimizer.step()
-      else:
-        xm.optimizer_step(optimizer)
-      tracker.add(flags.batch_size)
-      if step % flags.log_steps == 0:
-        xm.add_step_closure(
-            _train_update,
-            args=(device, step, loss, tracker, epoch, writer),
-            run_async=flags.async_closures)
+
+      if step % 4 == 0:
+          xm.mark_step()
+          if flags.ddp:
+            optimizer.step()
+          else:
+            xm.optimizer_step(optimizer)
+          tracker.add(flags.batch_size)
+          if step % flags.log_steps == 0:
+            xm.add_step_closure(
+                _train_update,
+                args=(device, step, loss, tracker, epoch, writer),
+                run_async=flags.async_closures)
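
The accumulation pattern in the diff can be sketched without XLA as a scalar least-squares fit. This is a hypothetical toy (accumulate_and_step is not part of the test script): grad_acc stands in for the .grad tensors that carry over from one graph execution to the next, which are the tensors expected to alias with outputs. Unlike the diff's step % 4 == 0 check, which also fires at step 0, this sketch steps after every fourth micro-batch.

```python
from typing import List, Tuple


def accumulate_and_step(samples: List[Tuple[float, float]], w: float,
                        lr: float, accum_steps: int) -> Tuple[float, int]:
    """Fit y ~ w * x with SGD, summing gradients over `accum_steps` micro-batches.

    Returns the final weight and the number of optimizer steps taken.
    Gradients from a trailing partial group are discarded.
    """
    grad_acc = 0.0   # plays the role of param.grad across micro-batches
    n_updates = 0
    for step, (x, y) in enumerate(samples):
        # d/dw of (w*x - y)^2, summed like repeated loss.backward() calls
        grad_acc += 2.0 * (w * x - y) * x
        if (step + 1) % accum_steps == 0:
            w -= lr * grad_acc   # optimizer.step()
            grad_acc = 0.0       # optimizer.zero_grad()
            n_updates += 1
    return w, n_updates
```

With eight copies of the sample (1.0, 2.0), w=0.0, lr=0.01 and accum_steps=4, the weight is updated twice: first to 0.16, then to 0.3072.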

I see only 2 aliases, both s64[] scalars, even though all of the gradient tensors are expected to be aliased:

2024-06-03 21:15:37.676472: I torch_xla/csrc/xla_graph_executor.cpp:1462] Parameter sequence graph hash b8e15ed0391b82171706a34d84ca8ea0
2024-06-03 21:15:37.678822: I torch_xla/csrc/xla_graph_executor.cpp:1299] Aliased paramter 13 with output 4: s64[]
2024-06-03 21:15:37.678862: I torch_xla/csrc/xla_graph_executor.cpp:1299] Aliased paramter 14 with output 5: s64[]
2024-06-03 21:15:37.679222: I torch_xla/csrc/xla_graph_executor.cpp:1397] Compiling IR graph hash b8e15ed0391b82171706a34d84ca8ea0 on device CPU:0 ...
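
To compare alias counts between runs, the aliasing lines can be pulled out of the saved log with a small script. This is a hypothetical helper that assumes the exact message format shown above (including the "paramter" spelling as it appears in the log); it only reads text and does not call any torch_xla API.

```python
import re
from typing import List, Tuple

# Matches lines such as
#   "... xla_graph_executor.cpp:1299] Aliased paramter 13 with output 4: s64[]"
# The spelling "paramter" is reproduced exactly as it appears in the log.
ALIAS_RE = re.compile(r"Aliased paramter (\d+) with output (\d+): (\S+)")


def parse_aliases(log_text: str) -> List[Tuple[int, int, str]]:
    """Return (parameter_index, output_index, shape) for every aliasing line."""
    return [(int(p), int(o), shape) for p, o, shape in ALIAS_RE.findall(log_text)]
```

Applied to the three log lines above, it returns [(13, 4, "s64[]"), (14, 5, "s64[]")]: two scalar aliases and no gradient-shaped ones.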

To Reproduce

Steps to reproduce the behavior:

  1. Check out the r2.1_aws_neuron branch.
  2. Apply Jack's patch e3fc033.
  3. Build and install as described in the CONTRIBUTING doc.
  4. Change into xla/test.
  5. Edit test_train_mp_mnist.py to add the gradient accumulation loop shown above.
  6. Run with TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE="xla_graph_executor=6,pjrt_computation_client=5" to see the aliasing debug logs:
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 XLA_SAVE_TENSORS_FMT="hlo" XLA_SAVE_TENSORS_FILE="/tmp/save1.hlo"   TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE="xla_graph_executor=6,pjrt_computation_client=5" python test_train_mp_mnist.py |& tee log
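
When iterating on the repro, the same environment can be set from Python. This is a hypothetical convenience wrapper around the command above that uses only the standard library; build_repro_env and run_repro are invented names, and the script is assumed to be run from xla/test.

```python
import os
import subprocess
import sys
from typing import Dict


def build_repro_env(base: Dict[str, str]) -> Dict[str, str]:
    """Return a copy of `base` with the debugging variables from step 6 set."""
    env = dict(base)
    env.update({
        "XLA_IR_DEBUG": "1",
        "XLA_HLO_DEBUG": "1",
        "XLA_SAVE_TENSORS_FMT": "hlo",
        "XLA_SAVE_TENSORS_FILE": "/tmp/save1.hlo",
        "TF_CPP_MIN_LOG_LEVEL": "0",
        "TF_CPP_VMODULE": "xla_graph_executor=6,pjrt_computation_client=5",
    })
    return env


def run_repro(script: str = "test_train_mp_mnist.py", log_path: str = "log") -> int:
    """Run `script` with stdout and stderr merged into `log_path`.

    Similar to `|& tee log`, except nothing is echoed to the terminal.
    """
    with open(log_path, "w") as log:
        proc = subprocess.run(
            [sys.executable, script],
            env=build_repro_env(dict(os.environ)),
            stdout=log,
            stderr=subprocess.STDOUT,
        )
    return proc.returncode
```

Building the environment as a copy keeps the caller's os.environ unchanged between runs.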

Expected behavior

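The gradient tensors should be aliased: each gradient passed into the graph as a parameter should share its buffer with the corresponding updated output.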
