TorchAO compile + offloading tests by a-r-r-o-w · Pull Request #11697 · huggingface/diffusers
I was able to spend some time on this, and the following diff solves the problem:
```diff
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index ddf97aca5..28454aae9 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -631,11 +631,14 @@ class TorchAoSerializationTest(unittest.TestCase):
 
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoCompileTest(QuantCompileTests):
-    quantization_config = PipelineQuantizationConfig(
-        quant_mapping={"transformer": TorchAoConfig(quant_type="int8_weight_only")},
-    )
+    @property
+    def quantization_config(self):
+        config = PipelineQuantizationConfig(
+            quant_mapping={"transformer": TorchAoConfig(quant_type="int8_weight_only")},
+        )
+        return config
 
     def test_torch_compile(self):
         super()._test_torch_compile(quantization_config=self.quantization_config)
```
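The core of the fix is deferring construction of the quantization config: a class attribute is evaluated while the class body executes (i.e. at import time), whereas a property body only runs when an instance accesses it. A minimal sketch of the difference, assuming the names below (`make_config`, `LazyTest`) which are illustrative and not from diffusers:

```python
def make_config():
    # Stand-in for PipelineQuantizationConfig(...). Assume, as in the real
    # test, that constructing it may require an optional dependency
    # (torchao) that is not installed in every environment.
    return {"transformer": "int8_weight_only"}

# Class-attribute version: make_config() would run as soon as the module
# is imported, before any skip decorator (such as
# @require_torchao_version_greater_or_equal) could intervene:
#
#   class EagerTest:
#       quantization_config = make_config()  # evaluated at import time
#
# Property version: make_config() only runs when a test instance actually
# accesses .quantization_config, so a skipped test never triggers it.
class LazyTest:
    @property
    def quantization_config(self):
        return make_config()

t = LazyTest()
print(t.quantization_config)  # {'transformer': 'int8_weight_only'}
```

This is why moving the config behind a `@property` makes the test module importable even when torchao is absent or incompatible.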
ChatGPT does a nice job of explaining what is happening:
https://chatgpt.com/share/685951bc-7c88-8013-b317-62683d1a1fa9. What I didn't investigate is how the other TorchAO tests avoid being flagged by the same torchao installation errors 🤷