Host-side memory leak when moving an nn.Module object to TPU devices (and sharding it) · Issue #3545 · pytorch/xla

🐛 Bug

There seems to be a host-side memory leak when we move an nn.Module object to a TPU device via module.to(device). Specifically, when constructing a module on the host CPU side and then moving it to a TPU device, one would expect that all the parameters in module should no longer occupy any host CPU memory after we call module = module.to(device) and xm.mark_step(), where device = xm.xla_device() is a TPU.

However, our analysis shows that this is not the case. After calling module = module.to(device) and xm.mark_step(), the module still takes up a large proportion of the CPU memory, which indicates a potential host-side memory leak after the CPU-to-TPU transfer.

This memory leak occurs frequently in our real use cases, where we build large models on the CPU side (e.g. to initialize their parameters) and then transfer them to TPU devices. It blocks the application of FSDP (#3431) to large models due to host-side OOM.

I suspect this is because the XRT server does not release the CPU tensors in time after transferring them to TPUs. Could it be that the XRT server is trying to cache all the recently compiled graphs and unfortunately also happens to cache the data tensors in them?

To Reproduce

Below is a simplified example of our real use case that suffers from this host-side memory leak. It tries to construct 64 ViT blocks, move them to TPU, and take a shard of them.

If there were no memory leak, the script below should use a constant amount of CPU memory after constructing each block. However, we observe that its CPU memory consumption keeps growing as we build new module blocks, even though the previous blocks have already been moved to TPU and we explicitly call gc.collect() after each step.

In our real use cases, this memory leak eventually leads to host-side OOM since we also need memory to store e.g. cached data.
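For reference, the "take a shard" step can be illustrated without torch. The sketch below mirrors the flatten, pad, and slice arithmetic that the reproduction script applies to each block's parameters, using plain Python lists. The helper names `shard_bounds` and `take_shard` are hypothetical and exist only for this illustration; they are not part of torch_xla or the script.

```python
def shard_bounds(numel, world_size, rank):
    """Return (padded_numel, begin, end) for one rank's slice of a flat vector."""
    remainder = numel % world_size
    # Pad up to the next multiple of world_size so every rank gets an equal slice.
    padded = numel if remainder == 0 else numel + (world_size - remainder)
    per_rank = padded // world_size
    return padded, per_rank * rank, per_rank * (rank + 1)


def take_shard(flat, world_size, rank):
    """Zero-pad `flat` (a list) to a multiple of world_size and return one rank's slice."""
    padded, begin, end = shard_bounds(len(flat), world_size, rank)
    flat = list(flat) + [0] * (padded - len(flat))
    return flat[begin:end]


if __name__ == "__main__":
    flat = list(range(10))  # stand-in for the concatenated parameters
    for rank in range(4):
        print(rank, take_shard(flat, 4, rank))
```

Each rank keeps only `1 / world_size` of the flattened parameters, which is why the per-device memory should grow by only a fraction of a full block.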


To reproduce it:

  1. Allocate a v3-8 TPU VM with the tpu-vm-pt-1.10 runtime and install the 20220419 versions of torch, torchvision, and torch_xla, while keeping the 20220408 version of libtpu (because of the issue "PyTorch XLA .data assignment fails when the new tensor is a different shape", #3502 (comment); I haven't tried newer nightly builds yet). In addition, install the timm package for vision transformers and the psutil package to show CPU memory usage, as follows.
# torch, torchvision and torch_xla 20220419
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220419-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220419-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220419-cp38-cp38-linux_x86_64.whl

# libtpu 20220408
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220408-py3-none-any.whl

# dependencies for this example
sudo pip3 install timm==0.4.12
sudo pip3 install psutil
  1. Save the following content to a python file (e.g. /home/ronghanghu/oom_debug.py below).
import argparse
import gc
import time

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# install with `sudo pip3 install psutil timm==0.4.12`
import psutil
from timm.models.vision_transformer import Block


def build_sharded_vit_blocks(send_to_tpu):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()
    device = xm.xla_device() if send_to_tpu else torch.device("cpu")

    num_blocks = 64
    all_blocks = []
    for b_idx in range(num_blocks):
        # Create and shard a ViT block
        block = Block(dim=5120, num_heads=32, mlp_ratio=4.0, qkv_bias=True)
        block = block.to(device)
        block = create_module_shard(block, world_size, rank, device)
        all_blocks.append(block)

        gc.collect()
        xm.mark_step()
        xm.wait_device_ops()
        xm.rendezvous(f"built block {b_idx}")

        # check CPU and TPU memory usage
        cpu_mem_info = psutil.virtual_memory()
        cpu_mem_used_mb = cpu_mem_info.used / 1024 ** 2
        tpu_mem_info = xm.get_memory_info(xm.xla_device())
        tpu_mem_used_mb = (tpu_mem_info["kb_total"] - tpu_mem_info["kb_free"]) // 1024
        xm.master_print(
            f"after building block {b_idx}"
            f"\n\tCPU memory used: {cpu_mem_used_mb:.2f} MB"
            f"\n\tTPU memory used: {tpu_mem_used_mb:.2f} MB"
        )

    return all_blocks


def create_module_shard(module, world_size, rank, device):
    """
    Flatten module parameters into a single vector and shard it (to a slice)
    """
    # flatten parameters into a concatenated vector and pad it to world size
    p_flat = torch.cat([p.view(-1) for p in module.parameters()], dim=0)
    if p_flat.numel() % world_size != 0:
        padding = p_flat.new_zeros(world_size - p_flat.numel() % world_size)
        p_flat = torch.cat([p_flat, padding], dim=0)

    # shard the flattened parameter (i.e. taking a slice of it)
    begin = (p_flat.numel() // world_size) * rank
    end = (p_flat.numel() // world_size) * (rank + 1)
    p_shard = p_flat[begin:end].clone().detach()

    # free all the original model parameters and append the sharded param
    for p in module.parameters():
        p.data = torch.zeros(1, device=device)
    module.p_shard = torch.nn.Parameter(p_shard, requires_grad=True)
    return module


def _mp_fn(index, send_to_tpu):
    all_blocks = build_sharded_vit_blocks(send_to_tpu)
    xm.master_print("all blocks constructed")
    time.sleep(1000000)  # pause here so that we can inspect CPU memory via other tools


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--send_to_tpu", type=int, required=True)
    args = parser.parse_args()
    xmp.spawn(_mp_fn, args=(args.send_to_tpu,), nprocs=8)
  1. Run this file on the v3-8 TPU VM:
python3 /home/ronghanghu/oom_debug.py --send_to_tpu 1

It prints

after building block 0
        CPU memory used: 23103.51 MB
        TPU memory used: 2951.00 MB
after building block 1
        CPU memory used: 26145.32 MB
        TPU memory used: 3101.00 MB
after building block 2
        CPU memory used: 29756.12 MB
        TPU memory used: 3251.00 MB
...
after building block 61
        CPU memory used: 144820.79 MB
        TPU memory used: 12115.00 MB
after building block 62
        CPU memory used: 146509.26 MB
        TPU memory used: 12265.00 MB
after building block 63
        CPU memory used: 148226.47 MB
        TPU memory used: 12415.00 MB
all blocks constructed

which shows that the CPU memory consumption keeps growing as we construct more ViT blocks, even though all the blocks have been moved to TPU and shouldn't occupy any CPU memory. Below is a screenshot of htop while running it on my v3-8 TPU VM.

[Image: Screen Shot 2022-04-29 at 9 42 28 PM (2)]
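To quantify the growth rather than read it off the log by eye, the printed output can be parsed and turned into an average per-block increase. The following is a minimal sketch that assumes the exact log format printed by the script above; `parse_memory_log` and `growth_per_block` are hypothetical helper names, and the sample text is copied from the last three entries of the log.

```python
import re

# Matches one three-line record as printed by the reproduction script.
LOG_PATTERN = re.compile(
    r"after building block (\d+)\s+"
    r"CPU memory used: ([\d.]+) MB\s+"
    r"TPU memory used: ([\d.]+) MB"
)


def parse_memory_log(text):
    """Return a list of (block_index, cpu_mb, tpu_mb) tuples found in `text`."""
    return [(int(b), float(c), float(t)) for b, c, t in LOG_PATTERN.findall(text)]


def growth_per_block(records, column):
    """Average growth in MB per block between the first and last record.

    `column` is 1 for CPU memory and 2 for TPU memory.
    """
    first, last = records[0], records[-1]
    return (last[column] - first[column]) / (last[0] - first[0])


sample = """
after building block 61
        CPU memory used: 144820.79 MB
        TPU memory used: 12115.00 MB
after building block 62
        CPU memory used: 146509.26 MB
        TPU memory used: 12265.00 MB
after building block 63
        CPU memory used: 148226.47 MB
        TPU memory used: 12415.00 MB
"""

if __name__ == "__main__":
    records = parse_memory_log(sample)
    print(f"CPU growth per block: {growth_per_block(records, 1):.2f} MB")
    print(f"TPU growth per block: {growth_per_block(records, 2):.2f} MB")
```

Running the same helpers over the full log of a `--send_to_tpu 0` run gives a baseline to compare against.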


Also note that this memory leak is not due to the timm package itself. In particular, if we don't move the modules to TPU (i.e. we keep them on CPU), the script actually uses less CPU memory than when moving them to TPU, which indicates that the memory leak is beyond any CPU-side computation and likely comes from the PyTorch XLA graph tracing or the XRT server.

For example, if we run with --send_to_tpu 0 to keep everything on CPU:

python3 /home/ronghanghu/oom_debug.py --send_to_tpu 0

It prints

...
after building block 61
        CPU memory used: 81666.55 MB
        TPU memory used: 0.00 MB
after building block 62
        CPU memory used: 82880.04 MB
        TPU memory used: 0.00 MB
after building block 63
        CPU memory used: 84107.28 MB
        TPU memory used: 0.00 MB
all blocks constructed
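As a sanity check on these numbers, the expected size of one block and of one shard can be estimated by hand. The sketch below assumes that the timm 0.4.12 `Block` consists of two LayerNorms, a fused qkv linear layer, an output projection, and a two-layer MLP, and that all parameters are float32; the function names are hypothetical and the layer breakdown is an assumption, not something read from the timm source here.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameter count of a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


def vit_block_params(dim, mlp_ratio=4.0, qkv_bias=True):
    """Estimated parameter count of one ViT block (assumed layer layout)."""
    hidden = int(dim * mlp_ratio)
    norms = 2 * (2 * dim)  # two LayerNorms, each with weight and bias
    attn = linear_params(dim, 3 * dim, bias=qkv_bias) + linear_params(dim, dim)
    mlp = linear_params(dim, hidden) + linear_params(hidden, dim)
    return norms + attn + mlp


MIB = 1024 ** 2
WORLD_SIZE = 8       # nprocs passed to xmp.spawn in the script
BYTES_PER_PARAM = 4  # float32

params = vit_block_params(5120)
full_block_mib = params * BYTES_PER_PARAM / MIB
shard_mib = full_block_mib / WORLD_SIZE

if __name__ == "__main__":
    print(f"parameters per block: {params}")
    print(f"full block: {full_block_mib:.2f} MiB, one shard: {shard_mib:.2f} MiB")
```

Under these assumptions one shard is about 150 MiB, which matches the 150 MB step in the TPU memory column (12115 => 12265 => 12415). Eight shards, one per process, add up to about 1200 MiB per block, which is close to the roughly 1213 to 1227 MB per-block growth of the `--send_to_tpu 0` run. The `--send_to_tpu 1` run instead grows by roughly 1700 MB per block in its last steps, so several hundred MB per block stay on the host beyond what keeping the shards on CPU would need.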

Expected behavior

After moving a module to TPU and calling xm.mark_step(), it should no longer occupy any host-side memory.

Environment

Additional context

This memory leak is the issue (2) mentioned in #3431 (comment)

This issue persists (and seems worse) after upgrading torch, torch_xla, and torchvision wheels to their nightly 20220430 versions and using them together with libtpu-nightly==0.1.dev20220413 as follows:

# torch, torchvision and torch_xla 20220430
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220430-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220430-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220430-cp38-cp38-linux_x86_64.whl

# libtpu 20220413
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220413-py3-none-any.whl

# dependencies for this example
sudo pip3 install timm==0.4.12
sudo pip3 install psutil

python3 /home/ronghanghu/oom_debug.py --send_to_tpu 1

prints

...
after building block 61
        CPU memory used: 153795.23 MB
        TPU memory used: 12115.00 MB
after building block 62
        CPU memory used: 155552.81 MB
        TPU memory used: 12265.00 MB
after building block 63
        CPU memory used: 157576.84 MB
        TPU memory used: 12415.00 MB
all blocks constructed

which is an even higher CPU memory consumption than with the 20220419 wheels (148 GB => 157 GB).

[Image: Screen Shot 2022-04-30 at 3 11 55 PM (2)]

cc: @JackCaoG