
lightning.fabric.plugins.io.torch_io — PyTorch Lightning 2.5.1.post0 documentation

Source code for lightning.fabric.plugins.io.torch_io

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Any, Callable, Optional

from typing_extensions import override

from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning.fabric.utilities.cloud_io import _atomic_save, get_filesystem
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning.fabric.utilities.types import _PATH

log = logging.getLogger(__name__)

class TorchCheckpointIO(CheckpointIO):
    """CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints
    respectively, common for most use cases.

    .. warning::  This is an :ref:`experimental <versioning:Experimental API>` feature.

    """

    @override
    def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            checkpoint: dict containing model and trainer state
            path: write-target path
            storage_options: not used in ``TorchCheckpointIO.save_checkpoint``

        Raises:
            TypeError:
                If ``storage_options`` arg is passed in

        """
        if storage_options is not None:
            raise TypeError(
                "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
                f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
                " to define how you'd like to use `storage_options`."
            )
        fs = get_filesystem(path)
        fs.makedirs(os.path.dirname(path), exist_ok=True)
        _atomic_save(checkpoint, path)
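A direct-call sketch (illustrative; the module and path below are placeholders): any picklable dict can be saved, and passing ``storage_options`` raises the ``TypeError`` documented above:

    import torch
    from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO

    io = TorchCheckpointIO()
    model = torch.nn.Linear(4, 2)  # placeholder module for illustration
    io.save_checkpoint({"state_dict": model.state_dict()}, "checkpoints/last.ckpt")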

    @override
    def load_checkpoint(
        self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage
    ) -> dict[str, Any]:
        """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files.

        Args:
            path: Path to checkpoint
            map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
                locations.

        Returns: The loaded checkpoint.

        Raises:
            FileNotFoundError: If ``path`` is not found by the ``fsspec`` filesystem

        """
        # Try to read the checkpoint at `path`. If it does not exist, do not restore the checkpoint.
        fs = get_filesystem(path)
        if not fs.exists(path):
            raise FileNotFoundError(f"Checkpoint file not found: {path}")

        return pl_load(path, map_location=map_location)
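Continuing the sketch above (paths and names are placeholders), the same plugin loads the file back, here remapping all storages to CPU via ``map_location``:

    checkpoint = io.load_checkpoint("checkpoints/last.ckpt", map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["state_dict"])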

    @override
    def remove_checkpoint(self, path: _PATH) -> None:
        """Remove checkpoint file from the filesystem.

        Args:
            path: Path to checkpoint

        """
        fs = get_filesystem(path)
        if fs.exists(path):
            fs.rm(path, recursive=True)
            log.debug(f"Removed checkpoint: {path}")