Handle multiple inplace update input output aliasing (#7023) · pytorch/xla@e3fc033 (original) (raw)
`@@ -29,6 +29,7 @@
`
29
29
`#include "torch_xla/csrc/helpers.h"
`
30
30
`#include "torch_xla/csrc/ops/as_strided.h"
`
31
31
`#include "torch_xla/csrc/ops/as_strided_view_update.h"
`
``
32
`+
#include "torch_xla/csrc/ops/device_data.h"
`
32
33
`#include "torch_xla/csrc/ops/diagonal_view_update.h"
`
33
34
`#include "torch_xla/csrc/ops/einsum_utilities.h"
`
34
35
`#include "torch_xla/csrc/ops/index_ops.h"
`
`@@ -2538,7 +2539,38 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input,
`
2538
2539
`// 1) Aid XLA's InputOutputAlias.
`
2539
2540
`auto input_tensor = bridge::GetXlaTensor(input);
`
2540
2541
`auto output_tensor = bridge::GetXlaTensor(output);
`
2541
``
`-
output_tensor->data()->alias_id = input_tensor->GetUniqueId();
`
``
2542
`+
if (input_tensor->CurrentDataHandle() != nullptr ||
`
``
2543
`+
(input_tensor->CurrentIrValue().node != nullptr &&
`
``
2544
`+
torch_xla::DeviceData::Cast(
`
``
2545
`+
input_tensor->CurrentIrValue().node.get()))) {
`
``
2546
`+
/*
`
``
2547
`+
if input has a XLAData or holds a devicedata node, set alias_id to
`
``
2548
`+
tensor_id. Consider the case.
`
``
2549
+
``
2550
`+
// x.tensor_id = 1, x.alias_id = 1
`
``
2551
`+
x = torch.randn(5,5).to(xla_device())
`
``
2552
`+
// x.tensor_id = 2, x.alias_id should be 1
`
``
2553
`+
x += 1
`
``
2554
`+
xm.mark_step()
`
``
2555
`+
// x.tensor_id =3, x.alias_id should be 2 since input tensor id will be 2
`
``
2556
`+
// for this graph
`
``
2557
`+
x *= 1 of 1
`
``
2558
`+
*/
`
``
2559
`+
output_tensor->data()->alias_id = input_tensor->GetUniqueId();
`
``
2560
`+
} else {
`
``
2561
`+
/*
`
``
2562
`+
Consider the case
`
``
2563
+
``
2564
`+
// x.tensor_id = 1, x.alias_id = 1
`
``
2565
`+
x = torch.randn(5,5).to(xla_device())
`
``
2566
`+
// x.tensor_id = 2, x.alias_id should be 1
`
``
2567
`+
x += 1
`
``
2568
`+
// x.tensor_id = 3, x.alias_id should still be 1
`
``
2569
`+
x * = 2
`
``
2570
`+
xm.mark_step()
`
``
2571
`+
*/
`
``
2572
`+
output_tensor->data()->alias_id = input_tensor->data()->alias_id;
`
``
2573
`+
}
`
2542
2574
``
2543
2575
`// 2) Aid SPMD.
`
2544
2576
` XLATensor::ShardingSpecPtr sharding = input_tensor->sharding_spec();
`