umzi2/MoSR (GitHub repository)
MoSR: MambaOut Super-Resolution
This architecture was inspired by MambaOut.
def detect(state): # Get values from state n_block = get_seq_len(state, "gblocks") - 6 in_ch = state["gblocks.0.weight"].shape[1] dim = state["gblocks.0.weight"].shape[0]
# Calculate expansion ratio and convolution ratio
expansion_ratio = (state["gblocks.1.fc1.weight"].shape[0] / state["gblocks.1.fc1.weight"].shape[1]) / 2
conv_ratio = state["gblocks.1.conv.weight"].shape[0] / dim
kernel_size = state["gblocks.1.conv.weight"].shape[2]
# Determine upsampler type and calculate upscale
if "upsampler.init_pos" in state:
upsampler = "dys"
out_ch = state["upsampler.end_conv.weight"].shape[0]
upscale = math.isqrt(state["upsampler.offset.weight"].shape[0] // 8)
elif "upsampler.in_to_k.weight" in state:
upsampler = 'gps'
out_ch = in_ch
upscale = math.isqrt(state['upsampler.in_to_k.weight'].shape[0] // 8 // out_ch)
else:
upsampler = "ps"
out_ch = in_ch
upscale = math.isqrt(state["upsampler.0.weight"].shape[0] // out_ch)
# Print results
print(f""" in_ch: {in_ch}
out_ch: {out_ch}
dim: {dim}
n_block: {n_block}
upsampler: {upsampler}
upscale: {upscale}
kernel_size: {kernel_size}
expansion_ratio: {expansion_ratio}
conv_ratio: {conv_ratio}""")signature = [ 'gblocks.0.weight', 'gblocks.0.bias', 'gblocks.1.norm.weight', 'gblocks.1.norm.bias', 'gblocks.1.fc1.weight', 'gblocks.1.fc1.bias', 'gblocks.1.conv.weight', 'gblocks.1.conv.bias', 'gblocks.1.fc2.weight', 'gblocks.1.fc2.bias', ]
References:
Training code from NeoSR
TODO:
- Release pretrained weights