BasePredictionWriter — PyTorch Lightning 2.5.1.post0 documentation (original) (raw)
class lightning.pytorch.callbacks.BasePredictionWriter(write_interval='batch')[source]¶
Bases: Callback
Base class to implement how the predictions should be stored.
Parameters:
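write_interval (Literal['batch', 'epoch', 'batch_and_epoch']) – When to write: 'batch' calls write_on_batch_end() after every predict batch, 'epoch' calls write_on_epoch_end() once at the end of the predict epoch, and 'batch_and_epoch' does both.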
Example:
import os

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BasePredictionWriter
from lightning.pytorch.demos.boring_classes import BoringModel


class CustomWriter(BasePredictionWriter):
    """Save predictions under ``output_dir`` as ``.pt`` files.

    Args:
        output_dir: Directory the prediction files are written to (created on demand).
        write_interval: One of ``"batch"``, ``"epoch"`` or ``"batch_and_epoch"``.
    """

    def __init__(self, output_dir, write_interval):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_batch_end(
        self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
    ):
        # One sub-directory per dataloader. ``dataloader_idx`` is an int, so it must be
        # converted to ``str`` before joining it into a path (os.path.join rejects ints).
        batch_dir = os.path.join(self.output_dir, str(dataloader_idx))
        # torch.save does not create missing parent directories.
        os.makedirs(batch_dir, exist_ok=True)
        torch.save(prediction, os.path.join(batch_dir, f"{batch_idx}.pt"))

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # torch.save does not create missing parent directories.
        os.makedirs(self.output_dir, exist_ok=True)
        torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))


pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
trainer = Trainer(callbacks=[pred_writer])
model = BoringModel()
trainer.predict(model, return_predictions=False)
Example:
# multi-device inference example

import os

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BasePredictionWriter
from lightning.pytorch.demos.boring_classes import BoringModel


class CustomWriter(BasePredictionWriter):
    """Save each rank's epoch predictions (and batch indices) under ``output_dir``.

    Args:
        output_dir: Directory the prediction files are written to (created on demand).
        write_interval: One of ``"batch"``, ``"epoch"`` or ``"batch_and_epoch"``.
    """

    def __init__(self, output_dir, write_interval):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # torch.save does not create missing parent directories; exist_ok=True makes this
        # safe when several ranks reach this point at the same time.
        os.makedirs(self.output_dir, exist_ok=True)

        # this will create N (num processes) files in `output_dir` each containing
        # the predictions of its respective rank
        torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))

        # optionally, you can also save `batch_indices` to get the information about the data index
        # from your prediction data
        torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))


# or you can set `write_interval="batch"` and override `write_on_batch_end` to save
# predictions at batch level
pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])
model = BoringModel()
trainer.predict(model, return_predictions=False)
on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]¶
Called when the predict batch ends.
Return type: None
on_predict_epoch_end(trainer, pl_module)[source]¶
Called when the predict epoch ends.
Return type: None
setup(trainer, pl_module, stage)[source]¶
Called when fit, validate, test, predict, or tune begins.
Return type: None
write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)[source]¶
Override with the logic to write a single batch.
Return type: None
write_on_epoch_end(trainer, pl_module, predictions, batch_indices)[source]¶
Override with the logic to write all batches.
Return type: None