Run decompositions before the quantizer · pytorch/executorch@0a12e33
```diff
@@ -28,6 +28,7 @@
     to_edge,
 )
 from executorch.exir.pass_base import PassResult
+from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
@@ -58,16 +59,33 @@ def convert_pt2(
     Returns a GraphModule with the converted model.
     """
 
+    # Get default decompositions
+    decomp_table = torch.export.default_decompositions()
+    # Select ops to keep
+    ops_to_keep = [
+        torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv2d.default,
+        torch.ops.aten.layer_norm.default,
+        torch.ops.aten.linear.default,
+        torch.ops.aten.matmul.default,
+    ]
+    # Remove decompositions for the ops we want to keep
+    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
+    remove_decompositions(decomp_table, ops_to_keep)
     # Export with dynamo
-    model_gm = torch.export.export_for_training(model, inputs).module()
+    model_gm = (
+        torch.export.export_for_training(model, inputs)
+        .run_decompositions(decomp_table)
+        .module()
+    )
 
-    if model_gm_has_SDPA(model_gm):  # pyre-fixme[6]
+    if model_gm_has_SDPA(model_gm):
         # Decompose SDPA
-        DecomposeScaledDotProductAttention(False)(model_gm)  # pyre-fixme[6]
+        DecomposeScaledDotProductAttention(False)(model_gm)
 
         # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
         # for details).
-        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)  # pyre-fixme[6]
+        result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
         assert result is not None
         model_gm = result.graph_module
 
```
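For context, the export-with-decompositions pattern this commit introduces can be exercised on its own. Below is a minimal sketch, not part of the commit: `TinyModel`, its inputs, and the trimmed `ops_to_keep` list are placeholders chosen for illustration, while the API calls (`torch.export.default_decompositions`, `remove_decompositions`, `export_for_training(...).run_decompositions(...)`) are the ones the diff itself uses.

```python
import torch
from torch._inductor.decomposition import remove_decompositions


class TinyModel(torch.nn.Module):
    """Placeholder model containing ops that would otherwise be decomposed."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.layer_norm(self.linear(x), [4])


model = TinyModel().eval()
inputs = (torch.randn(2, 8),)

# Start from the default decomposition table and drop the entries for ops
# that should survive intact in the exported graph.
decomp_table = torch.export.default_decompositions()
ops_to_keep = [
    torch.ops.aten.linear.default,
    torch.ops.aten.layer_norm.default,
]
remove_decompositions(decomp_table, ops_to_keep)

# Export, run the (filtered) decompositions, and take the GraphModule that
# would then be handed to prepare_pt2e / convert_pt2e.
model_gm = (
    torch.export.export_for_training(model, inputs)
    .run_decompositions(decomp_table)
    .module()
)

# The preserved ops should still show up as aten.linear / aten.layer_norm nodes.
print(model_gm.graph)
```

Running the decompositions at this point, before `prepare_pt2e`/`convert_pt2e`, means the quantizer sees a graph of decomposed core ATen ops except for the handful of ops deliberately kept whole via the filtered table.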