Host-side memory leak when moving an nn.Module object to TPU devices (and sharding it) · Issue #3545 · pytorch/xla
🐛 Bug
There seems to be a host-side memory leak when we move an `nn.Module` object to a TPU device via `module.to(device)`. Specifically, when constructing a module on the host CPU side and then moving it to a TPU device, one would expect that the module's parameters no longer occupy any host CPU memory after we call `module = module.to(device)` and `xm.mark_step()`, where `device = xm.xla_device()` is a TPU.
However, our analysis shows that this is not the case. After calling `module = module.to(device)` and `xm.mark_step()`, the module still takes up a large proportion of the CPU memory, which indicates a potential host-side memory leak after the CPU-to-TPU transfer.
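For reference, below is a minimal sketch (with a hypothetical placeholder module, not the actual repro) of the transfer pattern being discussed:
```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()              # a TPU core when running on a TPU VM
module = torch.nn.Linear(4096, 4096)  # parameters start out in host CPU memory
module = module.to(device)            # parameters become XLA (device) tensors
xm.mark_step()                        # flush the pending CPU-to-TPU transfer
# Expectation: the original CPU storage of the parameters is released here,
# so host memory usage should drop back to roughly its pre-construction level.
```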
Such a memory leak happens frequently in our real use cases when we build large models on the CPU side (e.g. to initialize their parameters) and then transfer them to TPU devices. It blocks the application of FSDP (#3431) to large models due to host-side OOM.
I suspect this happens because the XRT server does not release the CPU tensors in time after transferring them to TPUs. Could it be that the XRT server tries to cache all the recently compiled graphs and unfortunately also caches the data tensors in them?
To Reproduce
Below is a simplified example of our real use case that suffers from this host-side memory leak. It constructs 64 ViT blocks, moves them to TPU, and takes a shard of each.
If there were no memory leak, the script below should use a roughly constant amount of CPU memory after constructing each block. However, its CPU memory consumption keeps growing as we build new module blocks, even though the previous blocks have already been moved to TPU and we explicitly call `gc.collect()` after each step.
In our real use cases, this memory leak eventually leads to host-side OOM since we also need memory to store e.g. cached data.
To reproduce it:
- Allocate a v3-8 TPU VM from the `tpu-vm-pt-1.10` runtime and install the `20220419` nightly versions of `torch`, `torchvision`, and `torch_xla`, while keeping the `20220408` version of libtpu (due to "PyTorch XLA .data assignment fails when the new tensor is a different shape" #3502 (comment), I haven't tried newer nightly builds yet). In addition, install the `timm` package for vision transformers and the `psutil` package to show CPU memory usage, as follows (a quick version check is sketched right after these commands):
```bash
# torch, torchvision and torch_xla 20220419
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220419-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220419-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220419-cp38-cp38-linux_x86_64.whl
# libtpu 20220408
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220408-py3-none-any.whl
# dependencies for this example
sudo pip3 install timm==0.4.12
sudo pip3 install psutil
```
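An optional sanity check to confirm that the intended nightly builds are the ones actually being imported:
```python
# Optional version check before running the repro script.
import torch
import torch_xla
import torchvision

print("torch:", torch.__version__)
print("torch_xla:", torch_xla.__version__)
print("torchvision:", torchvision.__version__)
```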
- Save the following content to a Python file (e.g. `/home/ronghanghu/oom_debug.py` below).
```python
import argparse
import gc
import time

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# install with `sudo pip3 install psutil timm==0.4.12`
import psutil
from timm.models.vision_transformer import Block


def build_sharded_vit_blocks(send_to_tpu):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()
    device = xm.xla_device() if send_to_tpu else torch.device("cpu")
    num_blocks = 64
    all_blocks = []
    for b_idx in range(num_blocks):
        # Create and shard a ViT block
        block = Block(dim=5120, num_heads=32, mlp_ratio=4.0, qkv_bias=True)
        block = block.to(device)
        block = create_module_shard(block, world_size, rank, device)
        all_blocks.append(block)
        gc.collect()
        xm.mark_step()
        xm.wait_device_ops()
        xm.rendezvous(f"built block {b_idx}")
        # check CPU and TPU memory usage
        cpu_mem_info = psutil.virtual_memory()
        cpu_mem_used_mb = cpu_mem_info.used / 1024 ** 2
        tpu_mem_info = xm.get_memory_info(xm.xla_device())
        tpu_mem_used_mb = (tpu_mem_info["kb_total"] - tpu_mem_info["kb_free"]) // 1024
        xm.master_print(
            f"after building block {b_idx}"
            f"\n\tCPU memory used: {cpu_mem_used_mb:.2f} MB"
            f"\n\tTPU memory used: {tpu_mem_used_mb:.2f} MB"
        )
    return all_blocks


def create_module_shard(module, world_size, rank, device):
    """
    Flatten module parameters into a single vector and shard it (to a slice)
    """
    # flatten parameters into a concatenated vector and pad it to world size
    p_flat = torch.cat([p.view(-1) for p in module.parameters()], dim=0)
    if p_flat.numel() % world_size != 0:
        padding = p_flat.new_zeros(world_size - p_flat.numel() % world_size)
        p_flat = torch.cat([p_flat, padding], dim=0)
    # shard the flattened parameter (i.e. taking a slice of it)
    begin = (p_flat.numel() // world_size) * rank
    end = (p_flat.numel() // world_size) * (rank + 1)
    p_shard = p_flat[begin:end].clone().detach()
    # free all the original model parameters and append the sharded param
    for p in module.parameters():
        p.data = torch.zeros(1, device=device)
    module.p_shard = torch.nn.Parameter(p_shard, requires_grad=True)
    return module


def _mp_fn(index, send_to_tpu):
    all_blocks = build_sharded_vit_blocks(send_to_tpu)
    xm.master_print("all blocks constructed")
    time.sleep(1000000)  # pause here so that we can inspect CPU memory via other tools


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--send_to_tpu", type=int, required=True)
    args = parser.parse_args()
    xmp.spawn(_mp_fn, args=(args.send_to_tpu,), nprocs=8)
```
- Run this file on the v3-8 TPU VM:
```bash
python3 /home/ronghanghu/oom_debug.py --send_to_tpu 1
```
It prints
```
after building block 0
    CPU memory used: 23103.51 MB
    TPU memory used: 2951.00 MB
after building block 1
    CPU memory used: 26145.32 MB
    TPU memory used: 3101.00 MB
after building block 2
    CPU memory used: 29756.12 MB
    TPU memory used: 3251.00 MB
...
after building block 61
    CPU memory used: 144820.79 MB
    TPU memory used: 12115.00 MB
after building block 62
    CPU memory used: 146509.26 MB
    TPU memory used: 12265.00 MB
after building block 63
    CPU memory used: 148226.47 MB
    TPU memory used: 12415.00 MB
all blocks constructed
```
which shows that the CPU memory consumption keeps growing as we construct more ViT blocks, even though all the blocks have been moved to TPU and should no longer occupy any CPU memory. Below is a screenshot of `htop` taken while running it on my v3-8 TPU VM.
Also note that this memory leak is not due to the `timm` package itself. In particular, if we don't move the modules to TPU (i.e. we keep them on CPU), the script actually uses less CPU memory than when moving them to TPU, which indicates that the leak is not caused by any CPU-side computation and likely comes from PyTorch XLA graph tracing or the XRT server.
For example, if we run with `--send_to_tpu 0` to keep everything on CPU:
```bash
python3 /home/ronghanghu/oom_debug.py --send_to_tpu 0
```
It prints
```
...
after building block 61
    CPU memory used: 81666.55 MB
    TPU memory used: 0.00 MB
after building block 62
    CPU memory used: 82880.04 MB
    TPU memory used: 0.00 MB
after building block 63
    CPU memory used: 84107.28 MB
    TPU memory used: 0.00 MB
all blocks constructed
```
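One way to get more signal on where the extra host memory comes from (beyond the machine-wide numbers above) is to dump the PyTorch/XLA client-side metrics after each block and compare how the transfer and compile counters grow across iterations, e.g. by adding the following inside the loop:
```python
import torch_xla.debug.metrics as met

# Print the metrics/counters accumulated so far; comparing consecutive
# reports can show which operations keep growing block after block.
print(met.metrics_report())
```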
Expected behavior
After moving a module to TPU and calling `xm.mark_step`, it should no longer occupy any host-side memory.
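In other words, a per-process check along the lines of the sketch below (with a hypothetical module, measuring RSS instead of the machine-wide numbers above) should show the host memory being released after the transfer:
```python
import psutil
import torch
import torch_xla.core.xla_model as xm

proc = psutil.Process()
device = xm.xla_device()

rss_start = proc.memory_info().rss
module = torch.nn.Linear(8192, 8192)  # hypothetical module: ~256 MB of fp32 params
module = module.to(device)            # move the parameters to the TPU
xm.mark_step()
xm.wait_device_ops()
rss_end = proc.memory_info().rss

# Expected: the delta stays far below the ~256 MB parameter size, since the
# parameters now live on the TPU; in practice the host keeps holding on to it.
print(f"host RSS delta: {(rss_end - rss_start) / 1024 ** 2:.2f} MB")
```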
Environment
- Reproducible on XLA backend [CPU/TPU]: v3-8 TPU VM
- torch_xla version: 20220419 nightly from
tpu-vm-pt-1.10
(see Step 1 above)
Additional context
This memory leak is the issue (2) mentioned in #3431 (comment).
This issue persists (and seems worse) after upgrading the torch, torch_xla, and torchvision wheels to their 20220430 nightly versions and using them together with `libtpu-nightly==0.1.dev20220413` as follows:
```bash
# torch, torchvision and torch_xla 20220430
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220430-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220430-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220430-cp38-cp38-linux_x86_64.whl
# libtpu 20220413
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220413-py3-none-any.whl
# dependencies for this example
sudo pip3 install timm==0.4.12
sudo pip3 install psutil
```
Running
```bash
python3 /home/ronghanghu/oom_debug.py --send_to_tpu 1
```
prints
```
...
after building block 61
    CPU memory used: 153795.23 MB
    TPU memory used: 12115.00 MB
after building block 62
    CPU memory used: 155552.81 MB
    TPU memory used: 12265.00 MB
after building block 63
    CPU memory used: 157576.84 MB
    TPU memory used: 12415.00 MB
all blocks constructed
```
which shows even higher CPU memory consumption (148 GB => 157 GB) than with the 20220419 wheels.
cc: @JackCaoG