empty_permute decomposition by apbose · Pull Request #2698 · pytorch/TensorRT

With the empty_permute decomposition, the graphs are as follows:
Pre-AOT Autograd graph:=============

graph():
   %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
   %add : [num_users=2] = call_function[target=torch.ops.aten.add](args = (%l_x_, %l_x_), kwargs = {})
   %empty_like_default : [num_users=1] = call_function[target=torch.ops.aten.empty_like.default](args = (%add,), kwargs = {})
   %add_1 : [num_users=1] = call_function[target=operator.add](args = (%empty_like_default, %add), kwargs = {})
   return (add_1,)

Post-AOT Autograd graph:=======

graph():
   %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
   %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%clone, %clone), kwargs = {})
   %empty : [num_users=1] = call_function[target=torch.ops.aten.empty.memory_format](args = ([3, 2],), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
   %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%empty, [0, 1]), kwargs = {})
   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute, %add), kwargs = {})
   return (add_1,)

Graph after constant folding:

graph():
   %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})
   %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})
   return (add_1,)

Post-lowering passes Autograd graph:=======

graph():
   %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})
   %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})
   return (add_1,)
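For reference, the pre-AOT graph above corresponds to a small test module of roughly this shape (reconstructed from the graph dump, not necessarily the repository's exact test):

```python
import torch

class EmptyLikeModel(torch.nn.Module):
    # Mirrors the traced graph: add -> empty_like -> add
    def forward(self, x):
        y = x + x
        z = torch.empty_like(y)  # decomposes to empty_permuted, then empty + permute
        return z + y
```

Only the output shape is deterministic, since `empty_like` returns uninitialized memory.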

Without the decomposition, the graphs are as follows:
Pre-AOT Autograd graph:=============

graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %add : [num_users=2] = call_function[target=torch.ops.aten.add](args = (%l_x_, %l_x_), kwargs = {})
    %empty_like_default : [num_users=1] = call_function[target=torch.ops.aten.empty_like.default](args = (%add,), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%empty_like_default, %add), kwargs = {})
    return (add_1,)

Post-AOT Autograd graph:=======

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%clone, %clone), kwargs = {})
    %empty_permuted : [num_users=1] = call_function[target=torch.ops.aten.empty_permuted.default](args = ([3, 2], [0, 1]), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%empty_permuted, %add), kwargs = {})
    return (add_1,)

Graph after constant folding:

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})
    return (add_1,)

Post-lowering passes Autograd graph:=======

graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})
    return (add_1,)

So empty_like decomposes into empty_permuted, which in turn decomposes into empty.memory_format plus a permute. The test above does not raise an error even though empty.memory_format is unsupported, because constant folding removes the op.
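A decomposition of this shape follows the standard PyTorch lowering of empty_permuted: allocate the backing tensor with the physically ordered sizes, then permute back with the inverse permutation so the logical shape is preserved. A sketch (illustrative names, not necessarily this PR's exact code):

```python
import torch

def empty_permuted_decomposition(size, physical_layout, **kwargs):
    # physical_layout lists logical dims in physical-memory order.
    # Build the inverse permutation: physical position -> logical dim.
    perm = [0] * len(size)
    for physical_pos, logical_dim in enumerate(physical_layout):
        perm[logical_dim] = physical_pos
    # Allocate in physical order, then view back in logical order.
    return torch.empty([size[d] for d in physical_layout], **kwargs).permute(perm)
```

For the graph above this reproduces empty([3, 2]) followed by permute([0, 1]), since the layout is the identity.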

I am working on empty.memory_format in PR #2745