Run decompositions before the quantizer · pytorch/executorch@0a12e33

```diff
@@ -28,6 +28,7 @@
     to_edge,
 )
 from executorch.exir.pass_base import PassResult
+from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
@@ -58,16 +59,33 @@ def convert_pt2(
     Returns a GraphModule with the converted model.
     """
 
+    # Get default decompositions
+    decomp_table = torch.export.default_decompositions()
+    # Select ops to keep
+    ops_to_keep = [
+        torch.ops.aten.conv1d.default,
+        torch.ops.aten.conv2d.default,
+        torch.ops.aten.layer_norm.default,
+        torch.ops.aten.linear.default,
+        torch.ops.aten.matmul.default,
+    ]
+    # Remove decompositions for the ops we want to keep
+    # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
+    remove_decompositions(decomp_table, ops_to_keep)
     # Export with dynamo
-    model_gm = torch.export.export_for_training(model, inputs).module()
+    model_gm = (
+        torch.export.export_for_training(model, inputs)
+        .run_decompositions(decomp_table)
+        .module()
+    )
 
-    if model_gm_has_SDPA(model_gm):  # pyre-fixme[6]
+    if model_gm_has_SDPA(model_gm):
         # Decompose SDPA
-        DecomposeScaledDotProductAttention(False)(model_gm)  # pyre-fixme[6]
+        DecomposeScaledDotProductAttention(False)(model_gm)
 
     # Swap _safe_softmax with _softmax (see https://github.com/pytorch/pytorch/pull/133882
     # for details).
-    result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)  # pyre-fixme[6]
+    result = ReplaceSafeSoftmaxWithSoftmax()(model_gm)
     assert result is not None
     model_gm = result.graph_module
 
```
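For reference, a minimal sketch of the pattern this commit adopts: build the default decomposition table, drop the entries for ops the quantizer should see as single aten nodes, then run the remaining decompositions on the exported program before quantization. `SimpleModel` and its shapes are illustrative placeholders, not part of this commit, and the snippet assumes a PyTorch build where `torch.export.default_decompositions` and `torch.export.export_for_training` are available.

```python
import torch
from torch._inductor.decomposition import remove_decompositions


class SimpleModel(torch.nn.Module):
    """Illustrative model; not from the commit."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x):
        return torch.nn.functional.layer_norm(self.linear(x), [8])


model = SimpleModel()
inputs = (torch.randn(2, 16),)

# Start from the default decomposition table, then keep linear and
# layer_norm as single aten nodes so quantizer patterns can match them.
decomp_table = torch.export.default_decompositions()
remove_decompositions(
    decomp_table,
    [torch.ops.aten.linear.default, torch.ops.aten.layer_norm.default],
)

# Export with dynamo, apply the pruned decompositions, and grab the
# resulting GraphModule, mirroring the chain convert_pt2 now uses.
model_gm = (
    torch.export.export_for_training(model, inputs)
    .run_decompositions(decomp_table)
    .module()
)
print(model_gm.graph)  # linear/layer_norm survive as single nodes
```

Running the decompositions with a pruned table before quantization means lower-level rewrites still happen, while the kept ops (conv, linear, layer_norm, matmul in the commit) stay intact for the quantizer to annotate.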