PyTorch XLA .data assignment fails when the new tensor is a different shape · Issue #3502 · pytorch/xla

🐛 Bug

In the latest nightly 20220413 PyTorch XLA build, the shape assignment example in #3392 (comment) is broken again. This now breaks the XLA FSDP implementation in #3431.

To Reproduce

  1. Allocate a v3-8 TPU VM with the tpu-vm-pt-1.10 runtime and install the nightly 20220413 environment:
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220413-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220413-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220413-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220413-py3-none-any.whl
  2. Run the example below:
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

x3 = torch.zeros(10, device=device)
y3 = x3.view(-1)
# This should NOT update y3 because `x3.data` is not in-place modified
x3.data = y3[:5] + 1

print(f"y3: {y3}")

which gives

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 655, in __format__
    return object.__format__(self, format_spec)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 338, in __repr__
    return torch._tensor_str._str(self)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor_str.py", line 439, in _str
    return _str_intern(self)
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor_str.py", line 325, in _str_intern
    self = self.to('cpu')
RuntimeError: INVALID_ARGUMENT: From /job:localservice/replica:0/task:0:
2 root error(s) found.
  (0) INVALID_ARGUMENT: Run-time shape mismatch for XRTExecute argument[0] (4423548877118753). Expected element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
  tiles {
    dimensions: 256
  }
}
is_dynamic_dimension: false
; got element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
}
is_dynamic_dimension: false

         [[{{node XRTExecute}}]]
         [[XRTExecute_G12]]
  (1) INVALID_ARGUMENT: Run-time shape mismatch for XRTExecute argument[0] (4423548877118753). Expected element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
  tiles {
    dimensions: 256
  }
}
is_dynamic_dimension: false
; got element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
}
is_dynamic_dimension: false

         [[{{node XRTExecute}}]]
0 successful operations.
0 derived errors ignored.
Recent warning and error logs:
  OP_REQUIRES failed at tpu_execute_op.cc:266 : INVALID_ARGUMENT: Run-time shape mismatch for XRTExecute argument[0] (4423548877118753). Expected element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
  tiles {
    dimensions: 256
  }
}
is_dynamic_dimension: false
; got element_type: F32
dimensions: 10
layout {
  minor_to_major: 0
  format: DENSE
}
is_dynamic_dimension: false

Expected behavior

On the previous nightly 20220408 build

sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220408-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220408-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220408-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220408-py3-none-any.whl

this example worked correctly and printed

y3: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='xla:1')

as expected (consistent with normal PyTorch behavior on CPU/GPU).
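
For reference, the same snippet on CPU (a minimal sketch without XLA, not part of the original report) shows the semantics being relied on: reassigning `x3.data` rebinds `x3` to a new tensor but never writes to the storage that `y3` still views.

import torch

x3 = torch.zeros(10)
y3 = x3.view(-1)      # y3 views the same 10-element storage as x3
x3.data = y3[:5] + 1  # rebinds x3.data to a new 5-element tensor; y3's storage is untouched

print(f"y3: {y3}")    # prints ten zeros, matching the expected XLA output above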

Environment

Additional context

It would be great to add a test case for the example above to guard against future regressions, for example along the lines of the sketch below.
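
A minimal sketch of such a test (the test name, the plain asserts, and the use of xm.mark_step() are assumptions, not an existing test in the repo):

import torch
import torch_xla.core.xla_model as xm


def test_data_assignment_with_different_shape():
    device = xm.xla_device()
    x3 = torch.zeros(10, device=device)
    y3 = x3.view(-1)
    # Rebinding x3.data to a differently shaped tensor must not modify y3,
    # which still views the original 10-element buffer.
    x3.data = y3[:5] + 1
    xm.mark_step()
    assert torch.all(y3.cpu() == torch.zeros(10))
    assert torch.all(x3.cpu() == torch.ones(5))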

cc: @JackCaoG