Relax QAT dtype assertion (#692) · pytorch/ao@b523f9f

2 files changed

@@ -422,9 +422,6 @@ def test_qat_4w_primitives(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
-    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32")
-    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_linear(self):
         from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
@@ -453,9 +450,6 @@ def test_qat_4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
-    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32")
-    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
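
With the extra skips removed, both 4-bit QAT tests are expected to run again whenever torch >= 2.4 and CUDA are available. The snippet below is a minimal sketch of the flow these tests exercise, preparing and then converting a model with Int4WeightOnlyQATQuantizer; the toy model, shapes, and groupsize are illustrative assumptions, not code from the PR.

```python
import torch
import torch.nn as nn

from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer


class TinyModel(nn.Module):
    """Hypothetical toy model, only for illustration."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(256, 256, bias=False)

    def forward(self, x):
        return self.linear(x)


if torch.cuda.is_available():
    model = TinyModel().cuda().to(torch.bfloat16)
    quantizer = Int4WeightOnlyQATQuantizer(groupsize=32)
    # prepare() swaps nn.Linear modules for Int4WeightOnlyQATLinear,
    # which fake-quantizes the weights during training
    model = quantizer.prepare(model)
    x = torch.randn(8, 256, dtype=torch.bfloat16, device="cuda")
    out = model(x)
    # convert() swaps in WeightOnlyInt4Linear with packed int4 weights for inference
    model = quantizer.convert(model)
```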
@@ -36,13 +36,6 @@ def forward(
         block_size: List[int],
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     ) -> torch.Tensor:
-        # Note: for bf16 inputs, casting them to fp32 has the unexpected
-        # side effect of reducing memory footprint significantly, presumably
-        # because bf16 * fp32 kernels are not as memory efficient
-        assert input.dtype == torch.float32
-        assert scales.dtype == torch.float32
-        assert zero_points.dtype == torch.int32
-
         (fq, mask) = fake_quantize_affine_cachemask(
             input,
             block_size,
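
The removed assertions are what previously forced fp32 inputs and scales into this forward; with them gone, bf16 tensors can flow straight into fake_quantize_affine_cachemask. Below is a minimal sketch of calling that primitive with bf16 inputs, assuming its positional signature of (input, block_size, scale, zero_point, quant_dtype, quant_min, quant_max, zero_point_domain); the shapes, block size, and int4 range are illustrative.

```python
import torch

from torchao.quantization.quant_primitives import (
    ZeroPointDomain,
    fake_quantize_affine_cachemask,
)

# Illustrative setup: one quantization group per row of a (16, 64) bf16 tensor.
x = torch.randn(16, 64, dtype=torch.bfloat16)
scales = torch.rand(16, 1, dtype=torch.bfloat16) + 0.01
zero_points = torch.zeros(16, 1, dtype=torch.int32)
block_size = (1, 64)

# int4 values carried in an int32 container, mirroring the QAT linear path.
fq, mask = fake_quantize_affine_cachemask(
    x,
    block_size,
    scales,
    zero_points,
    torch.int32,  # quant_dtype
    -8,           # quant_min
    7,            # quant_max
    ZeroPointDomain.INT,
)

# fq is the fake-quantized tensor; mask flags the elements that stayed inside
# [quant_min, quant_max] and is reused for the straight-through backward pass.
```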