remove timing check · pytorch/ao@711d001 (original) (raw)

`@@ -55,7 +55,6 @@

`

55

55

`import copy

`

56

56

`import tempfile

`

57

57

`import gc

`

58

``

`-import time`

59

58

`from torch.testing._internal.common_utils import TestCase

`

60

59

``

61

60

``

`@@ -692,25 +691,17 @@ def reset_memory():

`

692

691

``

693

692

`reset_memory()

`

694

693

`m = ToyLinearModel()

`

695

``

`-time0 = time.perf_counter()`

696

``

`-m.to(device="cuda")`

697

``

`-quantize_(m, int8_weight_only())`

698

``

`-torch.cuda.synchronize()`

699

``

`-time_baseline = time.perf_counter() - time0`

``

694

`+quantize_(m.to(device="cuda"), int8_weight_only())`

700

695

`memory_baseline = torch.cuda.max_memory_allocated()

`

701

696

``

702

697

`del m

`

703

698

`reset_memory()

`

704

699

`m = ToyLinearModel()

`

705

``

`-time0 = time.perf_counter()`

706

700

`quantize_(m, int8_weight_only(), device="cuda")

`

707

``

`-torch.cuda.synchronize()`

708

``

`-time_streaming = time.perf_counter() - time0`

709

701

`memory_streaming = torch.cuda.max_memory_allocated()

`

710

702

``

711

703

`for param in m.parameters():

`

712

704

`assert param.is_cuda

`

713

``

`-self.assertLess(time_streaming, time_baseline * 1.5)`

714

705

`self.assertLess(memory_streaming, memory_baseline)

`

715

706

``

716

707

``