A few more features_intermediate() models, AttentionExtract helper, related minor cleanup. by rwightman · Pull Request #2168 · huggingface/pytorch-image-models


@@ -3,8 +3,8 @@
 A collection of activations fn and modules with a common interface so that they can
 easily be swapped. All have an `inplace` arg even if not used.

-These activations are not compatible with jit scripting or ONNX export of the model, please use either
-the JIT or basic versions of the activations.
+These activations are not compatible with jit scripting or ONNX export of the model, please use
+basic versions of the activations.

 Hacked together by / Copyright 2020 Ross Wightman
 """


@@ -14,19 +14,17 @@
 from torch.nn import functional as F


-@torch.jit.script
-def swish_jit_fwd(x):
+def swish_fwd(x):
     return x.mul(torch.sigmoid(x))


-@torch.jit.script
-def swish_jit_bwd(x, grad_output):
+def swish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


-class SwishJitAutoFn(torch.autograd.Function):
-    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
+class SwishAutoFn(torch.autograd.Function):
+    """ optimised Swish w/ memory-efficient checkpoint
     Inspired by conversation btw Jeremy Howard & Adam Pazske
     https://twitter.com/jeremyphoward/status/1188251041835315200
     """


@@ -37,123 +35,117 @@ def symbolic(g, x):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return swish_jit_fwd(x)
+        return swish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return swish_jit_bwd(x, grad_output)
+        return swish_bwd(x, grad_output)


 def swish_me(x, inplace=False):
-    return SwishJitAutoFn.apply(x)
+    return SwishAutoFn.apply(x)


 class SwishMe(nn.Module):
     def __init__(self, inplace: bool = False):
         super(SwishMe, self).__init__()

     def forward(self, x):
-        return SwishJitAutoFn.apply(x)
+        return SwishAutoFn.apply(x)


-@torch.jit.script
-def mish_jit_fwd(x):
+def mish_fwd(x):
     return x.mul(torch.tanh(F.softplus(x)))


-@torch.jit.script
-def mish_jit_bwd(x, grad_output):
+def mish_bwd(x, grad_output):
     x_sigmoid = torch.sigmoid(x)
     x_tanh_sp = F.softplus(x).tanh()
     return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


-class MishJitAutoFn(torch.autograd.Function):
+class MishAutoFn(torch.autograd.Function):
     """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
-    A memory efficient, jit scripted variant of Mish
+    A memory efficient variant of Mish
     """

     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return mish_jit_fwd(x)
+        return mish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return mish_jit_bwd(x, grad_output)
+        return mish_bwd(x, grad_output)


 def mish_me(x, inplace=False):
-    return MishJitAutoFn.apply(x)
+    return MishAutoFn.apply(x)


 class MishMe(nn.Module):
     def __init__(self, inplace: bool = False):
         super(MishMe, self).__init__()

     def forward(self, x):
-        return MishJitAutoFn.apply(x)
+        return MishAutoFn.apply(x)


-@torch.jit.script
-def hard_sigmoid_jit_fwd(x, inplace: bool = False):
+def hard_sigmoid_fwd(x, inplace: bool = False):
     return (x + 3).clamp(min=0, max=6).div(6.)


-@torch.jit.script
-def hard_sigmoid_jit_bwd(x, grad_output):
+def hard_sigmoid_bwd(x, grad_output):
     m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
     return grad_output * m


-class HardSigmoidJitAutoFn(torch.autograd.Function):
+class HardSigmoidAutoFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_sigmoid_jit_fwd(x)
+        return hard_sigmoid_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_sigmoid_jit_bwd(x, grad_output)
+        return hard_sigmoid_bwd(x, grad_output)


 def hard_sigmoid_me(x, inplace: bool = False):
-    return HardSigmoidJitAutoFn.apply(x)
+    return HardSigmoidAutoFn.apply(x)


 class HardSigmoidMe(nn.Module):
     def __init__(self, inplace: bool = False):
         super(HardSigmoidMe, self).__init__()

     def forward(self, x):
-        return HardSigmoidJitAutoFn.apply(x)
+        return HardSigmoidAutoFn.apply(x)


-@torch.jit.script
-def hard_swish_jit_fwd(x):
+def hard_swish_fwd(x):
     return x * (x + 3).clamp(min=0, max=6).div(6.)


-@torch.jit.script
-def hard_swish_jit_bwd(x, grad_output):
+def hard_swish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= 3.)
     m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
     return grad_output * m


-class HardSwishJitAutoFn(torch.autograd.Function):
-    """A memory efficient, jit-scripted HardSwish activation"""
+class HardSwishAutoFn(torch.autograd.Function):
+    """A memory efficient HardSwish activation"""

     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_swish_jit_fwd(x)
+        return hard_swish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_swish_jit_bwd(x, grad_output)
+        return hard_swish_bwd(x, grad_output)

     @staticmethod
     def symbolic(g, self):
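
The piecewise gate in `hard_swish_bwd` above can also be cross-checked against PyTorch's built-in `F.hardswish`, which uses the same formula. A rough sketch, not from the PR; the two may legitimately differ exactly at the kinks x = ±3, which random inputs will not hit:

```python
import torch
from torch.nn import functional as F


def hard_swish_bwd(x, grad_output):
    # hand-written gate from this diff: 0 below -3, 1 above 3, x/3 + 0.5 between
    m = torch.ones_like(x) * (x >= 3.)
    m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
    return grad_output * m


x = torch.randn(10_000, requires_grad=True)
F.hardswish(x).backward(torch.ones_like(x))              # autograd gradient
manual = hard_swish_bwd(x.detach(), torch.ones_like(x))  # hand-written gate
assert torch.allclose(x.grad, manual)
```
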


@@ -164,55 +156,53 @@ def symbolic(g, self):
 def hard_swish_me(x, inplace=False):
-    return HardSwishJitAutoFn.apply(x)
+    return HardSwishAutoFn.apply(x)


 class HardSwishMe(nn.Module):
     def __init__(self, inplace: bool = False):
         super(HardSwishMe, self).__init__()

     def forward(self, x):
-        return HardSwishJitAutoFn.apply(x)
+        return HardSwishAutoFn.apply(x)


-@torch.jit.script
-def hard_mish_jit_fwd(x):
+def hard_mish_fwd(x):
     return 0.5 * x * (x + 2).clamp(min=0, max=2)


-@torch.jit.script
-def hard_mish_jit_bwd(x, grad_output):
+def hard_mish_bwd(x, grad_output):
     m = torch.ones_like(x) * (x >= -2.)
     m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
     return grad_output * m


-class HardMishJitAutoFn(torch.autograd.Function):
-    """ A memory efficient, jit scripted variant of Hard Mish
+class HardMishAutoFn(torch.autograd.Function):
+    """ A memory efficient variant of Hard Mish
     Experimental, based on notes by Mish author Diganta Misra at
       https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
     """

     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return hard_mish_jit_fwd(x)
+        return hard_mish_fwd(x)

     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return hard_mish_jit_bwd(x, grad_output)
+        return hard_mish_bwd(x, grad_output)


 def hard_mish_me(x, inplace: bool = False):
-    return HardMishJitAutoFn.apply(x)
+    return HardMishAutoFn.apply(x)


 class HardMishMe(nn.Module):
     def __init__(self, inplace: bool = False):
         super(HardMishMe, self).__init__()

     def forward(self, x):
-        return HardMishJitAutoFn.apply(x)
+        return HardMishAutoFn.apply(x)
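
For context on how these end up being used, a small usage sketch. Assumptions: a timm checkout with this PR merged and the `timm.layers.activations_me` module path; adjust the import for older layouts under `timm.models.layers`.

```python
import torch
# assumed import path; older timm releases expose this under timm.models.layers
from timm.layers.activations_me import SwishMe, MishMe, HardSwishMe

x = torch.randn(2, 8, requires_grad=True)
for act in (SwishMe(), MishMe(), HardSwishMe()):
    y = act(x)          # forward stashes only x for backward
    y.sum().backward()  # backward recomputes the cheap pointwise terms
    x.grad = None       # reset between activations
```
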