# torch.random — PyTorch 2.7 documentation
#
# Source code for torch.random

# mypy: allow-untyped-defs

import contextlib
import warnings
from collections.abc import Generator

import torch
from torch._C import default_generator

def set_rng_state(new_state: torch.Tensor) -> None:
    r"""Restores the CPU default generator from a saved state.

    .. note:: Only the CPU generator is affected. For CUDA, please use
        :func:`torch.manual_seed`, which works for both CPU and CUDA.

    Args:
        new_state (torch.ByteTensor): The desired state, e.g. a value
            previously returned by :func:`torch.get_rng_state`.
    """
    generator = default_generator
    generator.set_state(new_state)

def get_rng_state() -> torch.Tensor:
    r"""Returns a snapshot of the random number generator state as a
    ``torch.ByteTensor``.

    .. note:: Only the default CPU generator is captured; the state of
        other device generators is not included.

    See also: :func:`torch.random.fork_rng`.
    """
    snapshot = default_generator.get_state()
    return snapshot

def manual_seed(seed) -> torch._C.Generator:
    r"""Sets the seed for generating random numbers on all devices. Returns a
    torch.Generator object.

    Args:
        seed (int): The desired seed. Value must be within the inclusive range
            `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError
            is raised. Negative inputs are remapped to positive values with the formula
            `0x1_0000_0000_0000_0000 + seed` (i.e. ``seed + 2**64``), so ``-1`` maps
            to ``0xffff_ffff_ffff_ffff``.
    """
    seed = int(seed)

    # Each accelerator backend is seeded only if its ``_is_in_bad_fork()``
    # check reports False; otherwise it is skipped.
    import torch.cuda

    if not torch.cuda._is_in_bad_fork():
        torch.cuda.manual_seed_all(seed)

    import torch.mps

    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    import torch.xpu

    if not torch.xpu._is_in_bad_fork():
        torch.xpu.manual_seed_all(seed)

    # Custom (privateuse1) backend, if one is registered on ``torch``.
    _seed_custom_device(seed)

    # The CPU default generator is seeded last; its Generator is returned.
    return default_generator.manual_seed(seed)

def seed() -> int:
    r"""Re-seeds the random number generators on all devices with a
    non-deterministic value.

    The CPU default generator picks the new seed, and the same value is then
    applied to the CUDA, MPS, XPU and custom-device generators.

    Returns:
        int: the 64 bit number used to seed the RNG.
    """
    chosen = default_generator.seed()

    import torch.cuda
    import torch.mps
    import torch.xpu

    # Pair each backend with the callable that seeds it; a backend whose
    # ``_is_in_bad_fork()`` reports True is left untouched.
    seeders = (
        (torch.cuda, torch.cuda.manual_seed_all),
        (torch.mps, torch.mps.manual_seed),
        (torch.xpu, torch.xpu.manual_seed_all),
    )
    for backend, apply_seed in seeders:
        if not backend._is_in_bad_fork():
            apply_seed(chosen)

    _seed_custom_device(chosen)
    return chosen

def _seed_custom_device(seed) -> None: r"""Sets the seed to generate random numbers for custom device.

Args:
    seed (int): The desired seed.

See [Note: support the custom device with privateuse1]
"""
seed = int(seed)
custom_backend_name = torch._C._get_privateuse1_backend_name()
if hasattr(torch, custom_backend_name):
    custom_device_mod = getattr(torch, custom_backend_name)
    _bad_fork_name = "_is_in_bad_fork"
    _seed_all_name = "manual_seed_all"
    if hasattr(custom_device_mod, _bad_fork_name) and hasattr(
        custom_device_mod, _seed_all_name
    ):
        if not getattr(custom_device_mod, _bad_fork_name)():
            getattr(custom_device_mod, _seed_all_name)(seed)
    else:
        message = f"Set seed for `{custom_backend_name}` device does not take effect, please add API's "
        message += f"`{_bad_fork_name}` and `{_seed_all_name}` to `{custom_backend_name}` device module."
        warnings.warn(message, UserWarning, stacklevel=3)

def initial_seed() -> int:
    r"""Returns the seed the CPU default generator was last seeded with, as a
    Python int.

    .. note:: Only the default CPU generator is consulted.
    """
    generator = default_generator
    return generator.initial_seed()

# Set once the "many devices" warning in fork_rng has been emitted, so it is
# shown at most once per process.
_fork_rng_warned_already = False


@contextlib.contextmanager
def fork_rng(
    devices=None,
    enabled=True,
    _caller="fork_rng",
    _devices_kw="devices",
    device_type="cuda",
) -> Generator:
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of Device IDs): devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed.
        enabled (bool): if ``False``, the RNG is not forked.  This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
        device_type (str): device type str, default is `cuda`. As for custom device,
            see details in [Note: support the custom device with privateuse1]
    """
    # Internal arguments:
    #   _caller: the function which called fork_rng, which the user used
    #   _devices_kw: the devices keyword of _caller
    global _fork_rng_warned_already

    # The meta device has no RNG state, so there is nothing to fork.
    if device_type == "meta":
        yield
        return

    # Normalize e.g. "cuda:0" to "cuda" before looking up the device module.
    device_type = torch.device(device_type).type
    device_mod = getattr(torch, device_type, None)
    if device_mod is None:
        raise RuntimeError(
            f"torch has no module of `{device_type}`, you should register "
            + "a module by `torch._register_device_module`."
        )

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = device_mod.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            dev = device_type.upper()
            message = (
                f"{dev} reports that you have {num_devices} available devices, and "
                f"you have used {_caller} without explicitly specifying which devices are being used. "
                f"For safety, we initialize *every* {dev} device by default, which can "
                f"be quite slow if you have a lot of {dev}s. If you know that you are only"
                f" making use of a few {dev} devices, set the environment variable "
                f"{dev}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
                "with the set of devices you are actually using. For example, if you are using CPU only, "
                # This segment was a plain string and printed a literal
                # "device.upper()"; it must interpolate the device name.
                f"set {dev}_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
                f"set {dev}_VISIBLE_DEVICES=0 or devices=[0].  To initialize all devices "
                f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
                f"`range(torch.{device_type}.device_count())`."
            )
            warnings.warn(message)
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    cpu_rng_state = torch.get_rng_state()
    device_rng_states = [device_mod.get_rng_state(device) for device in devices]

    try:
        yield
    finally:
        # Restore even if the body raised, so callers always get the RNG back.
        torch.set_rng_state(cpu_rng_state)
        for device, device_rng_state in zip(devices, device_rng_states):
            device_mod.set_rng_state(device_rng_state, device)