DeepSpeedPlugin with activation checkpoint fails · Lightning-AI/pytorch-lightning · Discussion #9144 (original) (raw)
Thanks @nachshonc!
I've managed to reproduce the same failure without DeepSpeed, using `torch.utils.checkpoint`
and our bug report model:
import deepspeed import torch from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.plugins import DeepSpeedPlugin from torch.utils.data import DataLoader, Dataset
class RandomDataset(Dataset):
    """Fixed-length dataset of random feature vectors for the bug-report model.

    Args:
        size: Number of features in each sample.
        length: Number of samples in the dataset.
    """

    def __init__(self, size, length):
        # Must be ``__init__`` (the page rendering dropped the underscores):
        # a plain ``init`` method is never called on construction, so
        # ``RandomDataset(32, 64)`` would not accept its arguments.
        self.len = length
        # One row per sample, drawn once at construction time.
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        """Return the feature vector stored at row ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of samples given at construction."""
        return self.len
class BoringModel(LightningModule):
    """Minimal LightningModule that reproduces the activation-checkpointing failure.

    The model's only layer is wrapped in ``torch.utils.checkpoint.checkpoint``.
    With the reentrant checkpoint implementation, the checkpointed output
    requires grad only if one of the checkpoint's inputs does. The batch
    tensor does not, so the loss has no graph to backpropagate through. This
    failure is the point of the example and is intentionally left in place.
    """

    def __init__(self):
        # ``__init__`` / ``super().__init__()``: the page rendering had dropped
        # the underscores, which would skip LightningModule initialisation.
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        # Checkpointing the only (and final) layer reproduces the reported
        # error; this is not a recommended use of checkpointing.
        return torch.utils.checkpoint.checkpoint(self.layer, x)

    def training_step(self, batch, batch_idx):
        """Sum the model outputs into a scalar loss, log it, and return it."""
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Log the summed model output on a validation batch."""
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        """Log the summed model output on a test batch."""
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        """Plain SGD over the single linear layer's parameters."""
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
def run():
    """Fit ``BoringModel`` for one epoch on random training and validation sets.

    Both loaders wrap a fresh 64-sample, 32-feature ``RandomDataset`` and
    yield batches of two samples.
    """
    # Keep the creation order (train data, val data, model): each step draws
    # from the global torch RNG.
    loader_opts = {"batch_size": 2}
    train_loader = DataLoader(RandomDataset(32, 64), **loader_opts)
    val_loader = DataLoader(RandomDataset(32, 64), **loader_opts)

    boring = BoringModel()
    trainer = Trainer(max_epochs=1)
    trainer.fit(boring, train_dataloaders=train_loader, val_dataloaders=val_loader)
# Script entry point. The page rendering had reduced this to
# ``if name == "main"``, which raises NameError instead of running ``run()``.
if __name__ == "__main__":
    run()
I think the issue is that the entire model is wrapped in the checkpoint, so none of its activations are kept. The input tensor does not require gradients, so the checkpointed output does not require them either, and the autograd engine has nothing to compute gradients through.
Activation checkpointing only makes sense when you have intermediate layers that produce expensive activations. For example, change the model to look like this:
class BoringModel(LightningModule):
    """Variant that checkpoints only the hidden layer ``layer_h``.

    The output layer ``layer`` runs outside the checkpoint. The loss therefore
    still depends on a non-checkpointed module with trainable parameters, and
    the backward pass has a graph to follow.
    """

    def __init__(self):
        # ``__init__`` / ``super().__init__()``: the page rendering had dropped
        # the underscores.
        super().__init__()
        self.layer_h = torch.nn.Linear(32, 32)
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        # NOTE(review): with the reentrant checkpoint implementation, if ``x``
        # does not require grad, ``layer_h``'s parameters receive no gradients.
        # PyTorch warns "None of the inputs have requires_grad=True" in that case.
        # Passing ``use_reentrant=False`` (available in newer PyTorch releases)
        # avoids this. Confirm against the installed torch version.
        x = torch.utils.checkpoint.checkpoint(self.layer_h, x)
        return self.layer(x)
Activation checkpointing means that during the backward pass the discarded activations must be re-computed. The exception is CPU checkpointing with DeepSpeed, where activations are moved to CPU memory instead of being discarded. It follows that there is no point in checkpointing the final layer: its activations would have to be re-computed immediately at the start of the backward pass, as in this example:
class BoringModel(LightningModule):
    """Counter-example: checkpointing the final layer gives no benefit.

    The output of ``layer`` is needed as soon as the backward pass starts. A
    checkpoint around it only forces an immediate recomputation and saves
    nothing.
    """

    def __init__(self):
        # ``__init__`` / ``super().__init__()``: the page rendering had dropped
        # the underscores.
        super().__init__()
        self.layer_h = torch.nn.Linear(32, 32)
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        x = self.layer_h(x)
        return torch.utils.checkpoint.checkpoint(self.layer, x)  # no point doing this!
We should definitely make the docs clearer on this; I'll open an issue for it :)