LightningDataModule — PyTorch Lightning 2.5.1.post0 documentation

class lightning.pytorch.core.LightningDataModule[source]

Bases: DataHooks, HyperparametersMixin

A DataModule standardizes the training, validation, and test splits, data preparation, and transforms. The main advantage is consistent data splits, data preparation, and transforms across models.

Example:

import lightning as L
import torch
import torch.utils.data as data
from lightning.pytorch.demos.boring_classes import RandomDataset


class MyDataModule(L.LightningDataModule):
    def prepare_data(self):
        # download, IO, etc. Useful with shared filesystems
        # only called on 1 GPU/TPU in distributed
        ...

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        dataset = RandomDataset(1, 100)
        self.train, self.val, self.test = data.random_split(
            dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return data.DataLoader(self.train)

    def val_dataloader(self):
        return data.DataLoader(self.val)

    def test_dataloader(self):
        return data.DataLoader(self.test)

    def on_exception(self, exception):
        # clean up state after the trainer faced an exception
        ...

    def teardown(self, stage):
        # clean up state after the trainer stops, delete files...
        # called on every process in DDP
        ...

prepare_data_per_node

If True, each LOCAL_RANK=0 will call prepare_data(). Otherwise, only NODE_RANK=0, LOCAL_RANK=0 will call it.

allow_zero_length_dataloader_with_multiple_devices

If True, a dataloader with zero length within a local rank is allowed. Default value is False.
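
A minimal sketch of setting both attributes on a custom datamodule in __init__; MyDataModule is a placeholder class name and the chosen values are illustrative only:

import lightning as L


class MyDataModule(L.LightningDataModule):
    def __init__(self):
        super().__init__()
        # run prepare_data() only on NODE_RANK=0, LOCAL_RANK=0
        self.prepare_data_per_node = False
        # tolerate an empty dataloader on some local ranks in multi-device runs
        self.allow_zero_length_dataloader_with_multiple_devices = True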

classmethod from_datasets(train_dataset=None, val_dataset=None, test_dataset=None, predict_dataset=None, batch_size=1, num_workers=0, **datamodule_kwargs)[source]

Create an instance from torch.utils.data.Dataset.

Parameters:

train_dataset (Dataset or iterable of Datasets, optional) – dataset(s) to be used for train_dataloader()

val_dataset (Dataset or iterable of Datasets, optional) – dataset(s) to be used for val_dataloader()

test_dataset (Dataset or iterable of Datasets, optional) – dataset(s) to be used for test_dataloader()

predict_dataset (Dataset or iterable of Datasets, optional) – dataset(s) to be used for predict_dataloader()

batch_size (int) – batch size to use for each dataloader. Default: 1.

num_workers (int) – number of subprocesses to use for data loading; 0 loads data in the main process. Default: 0.

**datamodule_kwargs – additional parameters passed down to the datamodule's __init__.

Return type:

LightningDataModule
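
For illustration, a short usage sketch; RandomDataset is the demo dataset from lightning.pytorch.demos.boring_classes, and the dataset sizes and batch size here are arbitrary:

import lightning as L
from lightning.pytorch.demos.boring_classes import RandomDataset

# build a datamodule directly from plain torch Datasets
dm = L.LightningDataModule.from_datasets(
    train_dataset=RandomDataset(32, 64),
    val_dataset=RandomDataset(32, 64),
    batch_size=4,
    num_workers=0,
)
train_loader = dm.train_dataloader()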

load_from_checkpoint(checkpoint_path, map_location=None, hparams_file=None, **kwargs)[source]

Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint, it stores the arguments passed to __init__ in the checkpoint under "datamodule_hyper_parameters".

Any arguments specified through **kwargs will override args stored in "datamodule_hyper_parameters".

Parameters:

checkpoint_path (str, Path, or file-like object) – path to the checkpoint; this can also be a URL or a file-like object

map_location – remaps storage locations when loading; the behaviour is the same as the map_location argument of torch.load()

hparams_file (str or Path, optional) – path to a .yaml or .csv file containing the hyperparameters, for checkpoints that were saved without them

**kwargs – any extra keyword args needed to init the datamodule; can also be used to override saved hyperparameter values

Return type:

Self

Returns:

LightningDataModule instance with loaded weights and hyperparameters (if available).

Note

load_from_checkpoint is a class method. You must call it on your LightningDataModule class rather than on a LightningDataModule instance, or a TypeError will be raised.

Example:

# load weights without mapping ...
datamodule = MyLightningDataModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# or load weights and hyperparameters from separate files.
datamodule = MyLightningDataModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparams_file='/path/to/hparams_file.yaml'
)

# override some of the params with new values
datamodule = MyLightningDataModule.load_from_checkpoint(
    PATH,
    batch_size=32,
    num_workers=10,
)

load_state_dict(state_dict)[source]

Called when loading a checkpoint. Implement this to reload the datamodule state given the datamodule state_dict (see the sketch after state_dict() below).

Parameters:

state_dict (dict[str, Any]) – the datamodule state returned by state_dict.

Return type:

None

on_exception(exception)[source]

Called when the trainer execution is interrupted by an exception.

Return type:

None
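
As a sketch of what this hook can be used for; the self._conn handle below is hypothetical and not part of the Lightning API:

import lightning as L


class MyDataModule(L.LightningDataModule):
    def on_exception(self, exception):
        # hypothetical cleanup: close a connection that setup() may have opened
        conn = getattr(self, "_conn", None)
        if conn is not None:
            conn.close()
            self._conn = None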

state_dict()[source]

Called when saving a checkpoint. Implement this to generate and save the datamodule state.

Return type:

dict[str, Any]

Returns:

A dictionary containing datamodule state.
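
A minimal sketch of a state_dict() / load_state_dict() pair that round-trips custom state through the checkpoint; current_fold is a hypothetical attribute used only for illustration:

import lightning as L


class MyDataModule(L.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.current_fold = 0  # hypothetical state worth checkpointing

    def state_dict(self):
        # this dict is written into the checkpoint by the Trainer
        return {"current_fold": self.current_fold}

    def load_state_dict(self, state_dict):
        # restore exactly what state_dict() produced
        self.current_fold = state_dict["current_fold"]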