Subtle memory leak in _ReversibleModuleFunction

Hi, first of all: very nice work, and congratulations on your MICCAI paper!

I would like to point out a subtle memory leak in _ReversibleModuleFunction, caused by storing x directly on ctx instead of via ctx.save_for_backward(). The leak occurs under rare conditions: if a network output is not consumed by the loss term, it is never backpropagated through, so _ReversibleModuleFunction.backward() is never called for that output, and the `del ctx.y` in backward() never runs (at least, this is my guess at the source of the leak).
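To make the pattern I mean concrete, here is a toy autograd Function (hypothetical code, not the library's actual implementation): a tensor stashed as a plain ctx attribute is released only by the explicit `del` in backward(), so if backward() never runs for that output, freeing the tensor depends entirely on the graph being garbage-collected.

```python
import torch

class StashOnCtx(torch.autograd.Function):
    """Toy Function mirroring the pattern in question (hypothetical code)."""

    @staticmethod
    def forward(ctx, x):
        y = x * 2
        ctx.y = y  # plain attribute: only freed by the `del` in backward()
        return y

    @staticmethod
    def backward(ctx, grad_out):
        y = ctx.y  # the real code would use y here for reconstruction
        del ctx.y  # never reached if this output is not backpropagated through
        return grad_out * 2
```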

Consider the following minimal example:

```python
import torch
import torch.nn as nn
# import path for the reversible modules (adjust to your package layout)
from revtorch import ReversibleBlock, ReversibleSequence


class Net(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, kernel_size=1)
        self.rev1 = ReversibleSequence(nn.ModuleList([
            ReversibleBlock(nn.Conv2d(1, 1, kernel_size=1),
                            nn.Conv2d(1, 1, kernel_size=1))]))
        self.rev2 = ReversibleSequence(nn.ModuleList([
            ReversibleBlock(nn.Conv2d(1, 1, kernel_size=1),
                            nn.Conv2d(1, 1, kernel_size=1))]))

    def forward(self, x):
        x = nn.functional.relu(self.conv(x))
        y1 = nn.functional.relu(self.rev1(x))
        y2 = nn.functional.relu(self.rev2(x))
        return y1, y2


def my_loss(output, unused):
    result = output.sum()  # + 0 * unused.sum()
    return result


if __name__ == "__main__":
    model = Net().to(torch.device("cpu"))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

    model.train()
    for i in range(10000):
        data = torch.randn(64, 2, 32, 32, dtype=torch.float32)
        optimizer.zero_grad()  # reset gradients accumulated in the previous iteration
        y1, unused = model(data)
        loss = my_loss(y1, unused)
        loss.backward()
        optimizer.step()
```

As you can see, the second network output, y2, is never used in the loss calculation, yet memory consumption builds up steadily over the iterations. I found two ways to fix the leak:

  1. Make sure that all network outputs are consumed by the loss term (even if only by multiplying them by zero and adding the result, see the code comment above).
  2. Use ctx.save_for_backward() and ctx.saved_tensors to store and retrieve x, respectively, in _ReversibleModuleFunction.
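Fix 2 would look roughly like this (a hypothetical sketch; the actual _ReversibleModuleFunction stores more state than shown here): the tensor goes through ctx.save_for_backward(), so its lifetime is managed by autograd's own bookkeeping instead of an explicit `del` in backward().

```python
import torch

class SaveForBackward(torch.autograd.Function):
    """Toy sketch of fix 2 (hypothetical code, not the library's actual class)."""

    @staticmethod
    def forward(ctx, x):
        y = x * 2
        ctx.save_for_backward(y)  # lifetime handled by autograd
        return y

    @staticmethod
    def backward(ctx, grad_out):
        (y,) = ctx.saved_tensors  # retrieve; no manual `del` needed
        x = y / 2                 # reversible-style: reconstruct the input from the output
        return grad_out * 2
```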

Maybe you want to try to reproduce the memory leak, as I am not sure whether it depends on the PyTorch version and/or operating system (my setup is PyTorch 1.2.0 on Windows 10). You can then decide whether to change the implementation of _ReversibleModuleFunction or to point out to users the need to "consume" all network outputs, as described above.
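If it helps with reproducing, here is a small diagnostic helper I used (my own sketch, not part of the library) that sums the storage bytes of all live tensors the garbage collector can see; printing its value every few hundred iterations of the loop above should make any build-up visible.

```python
import gc
import torch

def live_tensor_bytes():
    """Rough diagnostic: total bytes held by live tensors visible to the gc."""
    total = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                total += obj.element_size() * obj.nelement()
        except Exception:
            pass  # some gc-tracked objects dislike introspection
    return total
```

A value that grows monotonically across iterations points at tensors being kept alive from one iteration to the next.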