Commit 711d001 — pytorch/ao: "remove timing check"
@@ -55,7 +55,6 @@
 import copy
 import tempfile
 import gc
-import time
 from torch.testing._internal.common_utils import TestCase
 
 
@@ -692,25 +691,17 @@ def reset_memory():
 
 reset_memory()
 m = ToyLinearModel()
-time0 = time.perf_counter()
-m.to(device="cuda")
-quantize_(m, int8_weight_only())
-torch.cuda.synchronize()
-time_baseline = time.perf_counter() - time0
+quantize_(m.to(device="cuda"), int8_weight_only())
 memory_baseline = torch.cuda.max_memory_allocated()
 
 del m
 reset_memory()
 m = ToyLinearModel()
-time0 = time.perf_counter()
 quantize_(m, int8_weight_only(), device="cuda")
-torch.cuda.synchronize()
-time_streaming = time.perf_counter() - time0
 memory_streaming = torch.cuda.max_memory_allocated()
 
 for param in m.parameters():
     assert param.is_cuda
-self.assertLess(time_streaming, time_baseline * 1.5)
 self.assertLess(memory_streaming, memory_baseline)
 
 