Subtle memory leak in _ReversibleModuleFunction
Hi, first of all: very nice work, and congrats on your MICCAI paper!
I would like to point out a subtle memory leak in `_ReversibleModuleFunction`, caused by storing `x` as a plain attribute on the context object instead of via `ctx.save_for_backward()`. The leak occurs under rare conditions: if a network output is not consumed by the loss term, it is never backpropagated through, so `_ReversibleModuleFunction.backward()` is never called for that output, and the `del ctx.y` inside it never runs (at least, this is my uneducated guess at the source of the leak).
Consider the following minimal example:
```python
import torch
import torch.nn as nn
from revtorch import ReversibleBlock, ReversibleSequence


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, kernel_size=1)
        self.rev1 = ReversibleSequence(nn.ModuleList([
            ReversibleBlock(nn.Conv2d(1, 1, kernel_size=1),
                            nn.Conv2d(1, 1, kernel_size=1))]))
        self.rev2 = ReversibleSequence(nn.ModuleList([
            ReversibleBlock(nn.Conv2d(1, 1, kernel_size=1),
                            nn.Conv2d(1, 1, kernel_size=1))]))

    def forward(self, x):
        x = nn.functional.relu(self.conv(x))
        y1 = nn.functional.relu(self.rev1(x))
        y2 = nn.functional.relu(self.rev2(x))
        return y1, y2


def my_loss(output, unused):
    result = output.sum()  # + 0 * unused.sum()
    return result


if __name__ == "__main__":
    model = Net().to(torch.device("cpu"))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
    model.train()
    for i in range(10000):
        data = torch.randn(64, 2, 32, 32, dtype=torch.float32)
        y1, unused = model(data)
        loss = my_loss(y1, unused)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

As you can see, the second network output, `y2`, is not used anywhere in the loss calculation, and memory consumption builds up over the iterations. I found two ways to fix the leak:
- Make sure that all network outputs are consumed by the loss term (even if only by multiplying them by zero and adding them, see the code comment above).
- Use `ctx.save_for_backward()` and `ctx.saved_tensors` for storing and retrieving `x`, respectively, in `_ReversibleModuleFunction`.
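To illustrate the second fix: here is a toy `torch.autograd.Function` contrasting the two storage patterns. The class names and the squaring operation are just placeholders for illustration, not the actual `_ReversibleModuleFunction` code:

```python
import torch

# Pattern that can leak: a tensor stashed as a plain attribute on ctx is only
# released if backward() actually runs (and deletes it by hand).
class SquareLeaky(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.x = x  # plain attribute, lifetime managed by hand
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.x
        del ctx.x  # never reached if backward() is never called
        return 2 * x * grad_output

# Safer pattern: save_for_backward() hands the tensor to autograd, which
# releases it together with the graph even when backward() never runs.
class SquareSafe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.tensor([3.0], requires_grad=True)
SquareSafe.apply(x).backward()  # d(x^2)/dx at x=3 is 6
```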
Maybe you want to try to reproduce the memory leak, as I am not sure whether it depends on the PyTorch version and/or the operating system (my setup is PyTorch 1.2.0 on Windows 10). You may then want to decide whether to change the implementation of `_ReversibleModuleFunction`, or to point out to users the need to "consume" all network outputs, as described above.
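For reproducing it, a helper like the following (a common PyTorch debugging trick, not part of your library) can make the growth visible by counting live tensors, e.g. every 100 iterations of the training loop above:

```python
import gc
import torch

def count_live_tensors():
    # Count tensor objects currently tracked by Python's garbage collector.
    # A count that grows without bound across training iterations is a strong
    # hint that autograd graph buffers are being retained.
    n = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                n += 1
        except Exception:
            pass  # some tracked objects raise on inspection; skip them
    return n
```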