torch.cat — PyTorch 2.7 documentation

torch.cat(tensors, dim=0, *, out=None) → Tensor

Concatenates the tensors in the given sequence tensors along the dimension dim. All tensors must either have the same shape (except in the concatenating dimension) or be a 1-D empty tensor with size (0,).
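The shape rule above can be illustrated with a minimal sketch. The tensor names (a, b, empty) and their sizes are arbitrary choices for illustration, not part of the original page.

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(4, 3)  # differs from `a` only in dim 0

# Shapes agree everywhere except the concatenating dimension, so this works.
rows = torch.cat((a, b), dim=0)
print(rows.shape)  # torch.Size([6, 3])

# A 1-D empty tensor of size (0,) is accepted and contributes no elements.
empty = torch.empty(0)
same = torch.cat((a, empty), dim=0)
print(same.shape)  # torch.Size([2, 3])

# Concatenating along dim 1 fails, because dim 0 then has to match (2 vs 4).
try:
    torch.cat((a, b), dim=1)
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True
```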

torch.cat() can be seen as an inverse operation for torch.split() and torch.chunk().
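As a small sketch of that inverse relationship, the snippet below splits a tensor and concatenates the pieces back along the same dimension. The input x and the split sizes are assumptions chosen for illustration.

```python
import torch

x = torch.arange(12).reshape(4, 3)

# split: pieces of 2 rows each along dim 0 -> two tensors of shape (2, 3)
pieces = torch.split(x, 2, dim=0)

# chunk: 3 chunks along dim 1 -> three tensors of shape (4, 1)
chunks = torch.chunk(x, 3, dim=1)

# Concatenating along the same dimension recovers the original tensor.
rebuilt_from_split = torch.cat(pieces, dim=0)
rebuilt_from_chunk = torch.cat(chunks, dim=1)

print(torch.equal(rebuilt_from_split, x))  # True
print(torch.equal(rebuilt_from_chunk, x))  # True
```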

torch.cat() can be best understood via examples.

See also

torch.stack() concatenates the given sequence along a new dimension.
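To make the contrast concrete, here is a brief sketch comparing the two on the same inputs. The tensors a and b are illustrative. The final lines show one way to reproduce the stack result with torch.cat by adding a leading dimension first.

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(2, 3)

# cat joins along an existing dimension: (2, 3) + (2, 3) -> (4, 3)
catted = torch.cat((a, b), dim=0)
print(catted.shape)  # torch.Size([4, 3])

# stack inserts a new dimension: two (2, 3) tensors -> (2, 2, 3)
stacked = torch.stack((a, b), dim=0)
print(stacked.shape)  # torch.Size([2, 2, 3])

# The same result via cat, after giving each input a new leading dimension.
via_cat = torch.cat((a.unsqueeze(0), b.unsqueeze(0)), dim=0)
print(torch.equal(stacked, via_cat))  # True
```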

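Parameters

tensors (sequence of Tensors) – Non-empty tensors provided must have the same shape, except in the cat dimension.

dim (int, optional) – the dimension over which the tensors are concatenated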

Keyword Arguments

out (Tensor, optional) – the output tensor.
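A minimal sketch of the out keyword follows, assuming a preallocated buffer (here called buf, a name chosen for illustration) whose shape and dtype already match the result.

```python
import torch

a = torch.zeros(2, 3)
b = torch.ones(2, 3)

# Preallocate a buffer with the shape of the concatenated result.
buf = torch.empty(4, 3)

# The result is written into `buf` instead of a freshly allocated tensor.
result = torch.cat((a, b), dim=0, out=buf)

# The returned tensor shares storage with the buffer.
shares_storage = result.data_ptr() == buf.data_ptr()
print(shares_storage)  # True
print(buf[2:].sum().item())  # 6.0 (the three-by-two block of ones from `b`)
```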

Example:

>>> import torch
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])