Handle multiple inplace update input output aliasing (#7023) · pytorch/xla@e3fc033

@@ -29,6 +29,7 @@
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/ops/as_strided.h"
 #include "torch_xla/csrc/ops/as_strided_view_update.h"
+#include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/ops/diagonal_view_update.h"
 #include "torch_xla/csrc/ops/einsum_utilities.h"
 #include "torch_xla/csrc/ops/index_ops.h"

@@ -2538,7 +2539,38 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input,
   // 1) Aid XLA's InputOutputAlias.
   auto input_tensor = bridge::GetXlaTensor(input);
   auto output_tensor = bridge::GetXlaTensor(output);
-  output_tensor->data()->alias_id = input_tensor->GetUniqueId();
+  if (input_tensor->CurrentDataHandle() != nullptr ||
+      (input_tensor->CurrentIrValue().node != nullptr &&
+       torch_xla::DeviceData::Cast(
+           input_tensor->CurrentIrValue().node.get()))) {
+    /*
+      If the input has XLAData or holds a DeviceData node, set alias_id to
+      the input's tensor_id. Consider the case:
+
+      // x.tensor_id = 1, x.alias_id = 1
+      x = torch.randn(5,5).to(xla_device())
+      // x.tensor_id = 2, x.alias_id should be 1
+      x += 1
+      xm.mark_step()
+      // x.tensor_id = 3, x.alias_id should be 2 since the input tensor id will be 2
+      // for this graph
+      x *= 2
+    */
+    output_tensor->data()->alias_id = input_tensor->GetUniqueId();
+  } else {
+    /*
+      Consider the case:
+
+      // x.tensor_id = 1, x.alias_id = 1
+      x = torch.randn(5,5).to(xla_device())
+      // x.tensor_id = 2, x.alias_id should be 1
+      x += 1
+      // x.tensor_id = 3, x.alias_id should still be 1
+      x *= 2
+      xm.mark_step()
+    */
+    output_tensor->data()->alias_id = input_tensor->data()->alias_id;
+  }
 
   // 2) Aid SPMD.
   XLATensor::ShardingSpecPtr sharding = input_tensor->sharding_spec();
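The two branches correspond to the two usage patterns described in the added comments: the first fires when the input already carries a data handle or wraps a DeviceData node (its value was materialized by an earlier `xm.mark_step()`), the second when the input is still a pending IR computation inside the current graph. Below is a minimal Python sketch that exercises both patterns, assuming a working torch/torch_xla install; the tensor_id/alias_id values in the comments restate the commit's own comments and are internal bookkeeping, not values printed by any public API.

```python
# Sketch of the two in-place update patterns the aliasing branch targets,
# assuming torch_xla.core.xla_model provides xla_device() and mark_step().
# tensor_id/alias_id annotations mirror the commit's comments; they are
# internal to torch_xla and not exposed here.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Case 1: a mark_step() between in-place updates. After the step, x is backed
# by device data (a DeviceData node), so the next graph's output should alias
# the tensor id x has when that graph starts, not the original one.
x = torch.randn(5, 5).to(device)   # x.tensor_id = 1, x.alias_id = 1
x += 1                             # x.tensor_id = 2, x.alias_id should be 1
xm.mark_step()
x *= 2                             # x.alias_id should now be 2 for this graph
xm.mark_step()

# Case 2: multiple in-place updates inside a single graph. y is still a
# pending IR computation (no device data yet), so alias_id should keep
# pointing at the original buffer.
y = torch.randn(5, 5).to(device)   # y.tensor_id = 1, y.alias_id = 1
y += 1                             # y.tensor_id = 2, y.alias_id should be 1
y *= 2                             # y.tensor_id = 3, y.alias_id should still be 1
xm.mark_step()

print(x)
print(y)
```

Without the else branch, the second pattern would chain the alias to an intermediate tensor id that never corresponds to a real device buffer, which is what breaks input/output aliasing when several in-place updates land in one graph.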