Relax QAT dtype assertion (#692) · pytorch/ao@b523f9f
File tree
2 files changed
- torchao/quantization/prototype/qat
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -422,9 +422,6 @@ def test_qat_4w_primitives(self): | ||
422 | 422 | |
423 | 423 | @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") |
424 | 424 | @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") |
425 | | -# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 |
426 | | -@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32" ) |
427 | | -@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now") |
428 | 425 | def test_qat_4w_linear(self): |
429 | 426 | from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear |
430 | 427 | from torchao.quantization.GPTQ import WeightOnlyInt4Linear |
@@ -453,9 +450,6 @@ def test_qat_4w_linear(self): | ||
453 | 450 | |
454 | 451 | @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower") |
455 | 452 | @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") |
456 | | -# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517 |
457 | | -@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32" ) |
458 | | -@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now") |
459 | 453 | def test_qat_4w_quantizer(self): |
460 | 454 | from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer |
461 | 455 | from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer |
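With the int4-specific skips gone, these tests are gated only by the two remaining decorators: torch version and CUDA availability. For context, here is a hedged, self-contained sketch of that gating pattern (the version flag is hard-coded here; torchao derives it from `torch.__version__`):

```python
import unittest
import torch

# Assumption: torchao computes this flag internally from torch.__version__;
# it is hard-coded here only to keep the sketch self-contained.
TORCH_VERSION_AT_LEAST_2_4 = True
_CUDA_IS_AVAILABLE = torch.cuda.is_available()

class TestQAT4w(unittest.TestCase):
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
    def test_qat_4w_linear(self):
        # body elided; see the diff above for the real assertions
        pass

if __name__ == "__main__":
    unittest.main()
```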
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -36,13 +36,6 @@ def forward( | ||
36 | 36 | block_size: List[int], |
37 | 37 | zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, |
38 | 38 | ) -> torch.Tensor: |
39 | | -# Note: for bf16 inputs, casting them to fp32 has the unexpected |
40 | | -# side effect of reducing memory footprint significantly, presumably |
41 | | -# because bf16 * fp32 kernels are not as memory efficient |
42 | | -assert input.dtype == torch.float32 |
43 | | -assert scales.dtype == torch.float32 |
44 | | -assert zero_points.dtype == torch.int32 |
45 | | - |
46 | 39 | (fq, mask) = fake_quantize_affine_cachemask( |
47 | 40 | input, |
48 | 41 | block_size, |
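The change in this file simply drops the fp32/int32 dtype asserts at the top of the fake-quantize forward, so inputs, scales, and zero points now reach `fake_quantize_affine_cachemask` in whatever dtype the caller provides. A hedged sketch of what this enables, not code from this PR: the `groupsize` value, the toy model, and the `prepare()` call are illustrative assumptions based on the quantizer API the tests above import.

```python
import torch
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer

# The QAT int4 tests run on CUDA; fall back to CPU only for illustration.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(device, torch.bfloat16)

# Assumption: groupsize=32 and prepare() mirror the two-step quantizer API
# exercised in test_qat_4w_quantizer above.
quantizer = Int4WeightOnlyQATQuantizer(groupsize=32)
model = quantizer.prepare(model)  # swaps in fake-quantized QAT linears

x = torch.randn(8, 256, dtype=torch.bfloat16, device=device)
out = model(x)        # previously tripped `assert input.dtype == torch.float32`
out.sum().backward()  # fake quantize remains differentiable via the cached mask
```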