High memory usage while training
NgPDat (Ng P Dat) January 24, 2017, 3:47am 1
Hello all,
I am training a simple RNN to predict a label at each input timestep, on a huge random dataset.
I record memory usage while training and notice that it increases linearly with dataset size:
[Plot omitted: VSIZE = virtual memory as reported by Ubuntu, %MEM = percentage of system RAM used, x-axis = time in seconds.]
My training script for reference:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

class testNet(nn.Module):
    def __init__(self):
        super(testNet, self).__init__()
        self.rnn = nn.RNN(input_size=200, hidden_size=1000, num_layers=1)
        self.linear = nn.Linear(1000, 100)

    def forward(self, x, init):
        x = self.rnn(x, init)[0]
        y = self.linear(x.view(x.size(0) * x.size(1), x.size(2)))
        return y.view(x.size(0), x.size(1), y.size(1))

net = testNet()
init = Variable(torch.zeros(1, 4, 1000))
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

total_loss = 0.0
for i in range(10000):  # 10000 mini-batches
    input = Variable(torch.randn(1000, 4, 200))   # seq_len = 1000, batch_size = 4, features = 200
    target = Variable(torch.LongTensor(4, 1000).zero_())
    optimizer.zero_grad()
    output = net(input, init)
    loss = criterion(output.view(-1, output.size(2)), target.view(-1))
    loss.backward()
    optimizer.step()
    total_loss += loss[0]
print(total_loss)
I expected memory usage to stay roughly constant from one mini-batch to the next. What might be the problem? (Correct me if my script is wrong.)
smth January 24, 2017, 5:08pm 2
hi @NgPDat
I’m trying to reproduce your results.
Can you tell me the units of VSIZE? Is it bytes?
And %MEM, is it a percentage of the system memory?
So far, my run is pretty stable at around 105 MB after 400 mini-batches; I will wait for some time.
apaszke (Adam Paszke) January 24, 2017, 7:53pm 3
I think I see the problem. You have to remember that loss is a Variable, and indexing a Variable always returns a Variable, even if it is one-dimensional! So when you do total_loss += loss[0], you are actually making total_loss a Variable and adding more and more subgraphs to its history, making it impossible to free them, because you are still holding a reference. Just replace total_loss += loss[0] with total_loss += loss.data[0] and it should be back to normal.
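For illustration, here is a minimal self-contained sketch of the two behaviors. It is written against current PyTorch, where the Variable machinery is built into tensors and loss.item() plays the role that loss.data[0] plays in this thread:

import torch

x = torch.randn(10, requires_grad=True)

# Accumulating the tensor itself chains every iteration's graph onto total,
# so none of those graphs can be freed while total is alive:
total = torch.tensor(0.0)
for _ in range(3):
    loss = (x * 2).sum()
    total = total + loss
print(total.grad_fn is not None)  # True: total drags the whole history along

# Accumulating a plain Python number detaches the running sum from autograd:
total = 0.0
for _ in range(3):
    loss = (x * 2).sum()
    total += loss.item()          # modern equivalent of loss.data[0]
print(total)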
NgPDat (Ng P Dat) January 25, 2017, 6:57am 4
Works like a charm! Thank you.
I think I have a better understanding of Variable now.
NgPDat (Ng P Dat) January 25, 2017, 12:48pm 5
VSIZE is in kilobytes.
Yes, %MEM is percentage of the system memory.
The whole script for recording memory usage is from here: http://stackoverflow.com/questions/7998302/graphing-a-processs-memory-usage
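In case the link goes stale, here is a minimal sketch of the same idea directly in Python. It assumes the third-party psutil package, which is my addition and not part of the linked script:

import time
import psutil  # assumed installed: pip install psutil

proc = psutil.Process()        # the current process
start = time.time()
with open("mem_log.tsv", "w") as f:
    f.write("seconds\tvsize_kb\tmem_percent\n")
    for _ in range(60):        # one sample per second for a minute
        vsize_kb = proc.memory_info().vms // 1024   # VSIZE in kilobytes
        f.write(f"{time.time() - start:.1f}\t{vsize_kb}\t{proc.memory_percent():.2f}\n")
        time.sleep(1)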
ado_sar (ado sar) January 24, 2025, 11:37pm 6
Why do we add subgraphs to its history? Is it because loss still requires grad after loss.backward()?
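For example (a small sketch on a recent PyTorch, to illustrate what I mean):

import torch

x = torch.randn(5, requires_grad=True)
loss = (x ** 2).sum()
loss.backward()
# backward() frees the graph's intermediate buffers, but loss itself still
# points at its graph node, so a live reference to loss keeps that history alive:
print(loss.requires_grad, loss.grad_fn)  # True <SumBackward0 ...>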