# torch.jit.unused -- usage example

import torch
import torch.nn as nn

class MyModule(nn.Module):
    """Toy module demonstrating ``@torch.jit.unused``.

    ``forward`` adds 10 to its input. When ``use_memory_efficient`` is true
    it delegates to ``memory_efficient``, which is marked
    ``@torch.jit.unused`` because its body (a ``pdb`` breakpoint) cannot be
    compiled by TorchScript.
    """

    def __init__(self, use_memory_efficient):
        """Store the flag that selects which code path ``forward`` takes.

        Args:
            use_memory_efficient: If truthy, ``forward`` calls
                ``memory_efficient``; otherwise it computes the result inline.
        """
        super().__init__()
        self.use_memory_efficient = use_memory_efficient

    @torch.jit.unused
    def memory_efficient(self, x):
        """Return ``x + 10`` after dropping into the debugger.

        The decorator asks the TorchScript compiler to leave this method
        uncompiled, so the module can still be scripted even though the body
        is not scriptable.
        """
        # Function-scope import kept on purpose: it is part of the
        # not-scriptable body this example is demonstrating.
        import pdb

        pdb.set_trace()
        return x + 10

    def forward(self, x):
        """Return ``x + 10`` via the path chosen by ``use_memory_efficient``."""
        # Use not-yet-scriptable memory efficient mode
        if self.use_memory_efficient:
            return self.memory_efficient(x)
        else:
            return x + 10

# With the flag off, the unused method is never reached, so the scripted
# module can be saved to disk.
m = torch.jit.script(MyModule(use_memory_efficient=False))
m.save("m.pt")

# With the flag on, scripting still succeeds because the decorated method is
# skipped by the compiler ...
m = torch.jit.script(MyModule(use_memory_efficient=True))

# ... but calling the module reaches the unused method:
# exception raised

m(torch.rand(100))