torch.Tensor.to — PyTorch 2.0 documentation
Tensor.to(*args, **kwargs) → Tensor
Performs Tensor dtype and/or device conversion. A torch.dtype and torch.device are inferred from the arguments of self.to(*args, **kwargs).
Here are the ways to call to:
to(dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format) → Tensor
Returns a Tensor with the specified dtype.
Args:
memory_format (torch.memory_format, optional): the desired memory format of the returned Tensor. Default: torch.preserve_format.
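As a minimal sketch of the dtype form, the snippet below converts a CPU tensor, forces a copy, and changes the memory format in the same call. The tensor shape and variable names are illustrative choices. The check that an already-matching tensor comes back as the same object is an assumption based on documented behavior, not something stated in this section.

```python
import torch

x = torch.randn(2, 3, 4, 5)  # float32, default contiguous layout

# Converting to a different dtype yields a new tensor with that dtype.
x64 = x.to(torch.float64)

# Assumption: when dtype and device already match, the same object is returned...
same = x.to(torch.float32)

# ...unless copy=True, which forces a fresh tensor.
fresh = x.to(torch.float32, copy=True)

# memory_format can change in the same call; channels_last needs a 4-D tensor.
x_cl = x.to(torch.float64, memory_format=torch.channels_last)

print(x64.dtype)
print(same is x, fresh is x)
print(x_cl.is_contiguous(memory_format=torch.channels_last))
```

With the default torch.preserve_format, the layout of the input is kept, so passing memory_format is only needed when the layout should change.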
to(device=None, dtype=None, non_blocking=False, copy=False, memory_format=torch.preserve_format) → Tensor
Returns a Tensor with the specified device and (optional) dtype. If dtype is None, it is inferred to be self.dtype. When non_blocking is set, it tries to convert asynchronously with respect to the host if possible, e.g., converting a CPU Tensor with pinned memory to a CUDA Tensor. When copy is set, a new Tensor is created even when the Tensor already matches the desired conversion.
Args:
memory_format (torch.memory_format, optional): the desired memory format of the returned Tensor. Default: torch.preserve_format.
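The device form might be used as in the sketch below. It assumes a machine that may or may not have a GPU, so it falls back to the CPU when CUDA is unavailable. The names src, moved, and moved64 are illustrative only.

```python
import torch

# Use a CUDA device if one is present; otherwise stay on the CPU so the sketch still runs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

src = torch.randn(2, 2)  # float32 on the CPU
if device.type == "cuda":
    src = src.pin_memory()  # pinned host memory is what allows an asynchronous copy

# dtype omitted: the result keeps src.dtype.
moved = src.to(device, non_blocking=True)

# Device and dtype changed in a single call.
moved64 = src.to(device, dtype=torch.float64, non_blocking=True)

if device.type == "cuda":
    torch.cuda.synchronize()  # wait for any in-flight copies before inspecting results

print(moved.device, moved.dtype)
print(moved64.device, moved64.dtype)
```

On a CPU-only machine non_blocking has no visible effect here. It matters only when an asynchronous transfer is actually possible, such as the pinned-memory case described above.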
to(other, non_blocking=False, copy=False) → Tensor
Returns a Tensor with the same torch.dtype and torch.device as the Tensor other. When non_blocking is set, it tries to convert asynchronously with respect to the host if possible, e.g., converting a CPU Tensor with pinned memory to a CUDA Tensor. When copy is set, a new Tensor is created even when the Tensor already matches the desired conversion.
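A small CPU-only sketch of the other form follows. It illustrates that only the dtype and device of other are used, while the shape and values of the source tensor are unchanged. The variable names are illustrative.

```python
import torch

tensor = torch.randn(2, 2)  # float32 on the CPU
other = torch.zeros((), dtype=torch.float64)  # 0-dim tensor; only its dtype and device matter

matched = tensor.to(other)

# Equivalent spelling using the device/dtype form.
explicit = tensor.to(other.device, dtype=other.dtype)

print(matched.dtype, matched.device)
print(matched.shape)
```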
Example:
>>> import torch
>>> tensor = torch.randn(2, 2)  # Initially dtype=float32, device=cpu
>>> tensor.to(torch.float64)
tensor([[-0.5044,  0.0005],
        [ 0.3310, -0.0584]], dtype=torch.float64)

>>> cuda0 = torch.device('cuda:0')
>>> tensor.to(cuda0)
tensor([[-0.5044,  0.0005],
        [ 0.3310, -0.0584]], device='cuda:0')

>>> tensor.to(cuda0, dtype=torch.float64)
tensor([[-0.5044,  0.0005],
        [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0')

>>> other = torch.randn((), dtype=torch.float64, device=cuda0)
>>> tensor.to(other, non_blocking=True)
tensor([[-0.5044,  0.0005],
        [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0')