Asymmetric vqgan by cross-attention · Pull Request #3956 · huggingface/diffusers
I used the original checkpoints from https://github.com/buxiangzhiren/Asymmetric_VQGAN/
To match the keys, I used the following code:
```python
import torch

from diffusers import AsymmetricAutoencoderKL
```
x1.5
ckpt = torch.load("./checkpoints/larger1.5.ckpt", map_location="cpu") vae = AsymmetricAutoencoderKL( in_channels = 3, out_channels = 3, down_block_types = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"), down_block_out_channels = (128, 256, 512, 512), layers_per_down_block = 2, up_block_types = ("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"), up_block_out_channels = (192, 384, 768, 768), layers_per_up_block = 3, act_fn = "silu", latent_channels = 4, norm_num_groups = 32, sample_size = 256, scaling_factor = 0.18215, )
x2
ckpt = torch.load("./checkpoints/larger2.ckpt", map_location="cpu") vae = AsymmetricAutoencoderKL( in_channels = 3, out_channels = 3, down_block_types = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"), down_block_out_channels = (128, 256, 512, 512), layers_per_down_block = 2, up_block_types = ("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"), up_block_out_channels = (256, 512, 1024, 1024), layers_per_up_block = 5, act_fn = "silu", latent_channels = 4, norm_num_groups = 32, sample_size = 256, scaling_factor = 0.18215, )
match keys
```python
# encoder weights: rename the original VQGAN keys to the diffusers Encoder layout
enc_dict = {
    k.replace("encoder.down.", "encoder.down_blocks.")
    .replace("encoder.mid.", "encoder.mid_block.")
    .replace("encoder.norm_out.", "encoder.conv_norm_out.")
    .replace(".downsample.", ".downsamplers.0.")
    .replace(".nin_shortcut.", ".conv_shortcut.")
    .replace(".block.", ".resnets.")
    .replace(".block_1.", ".resnets.0.")
    .replace(".block_2.", ".resnets.1.")
    .replace(".attn_1.k.", ".attentions.0.to_k.")
    .replace(".attn_1.q.", ".attentions.0.to_q.")
    .replace(".attn_1.v.", ".attentions.0.to_v.")
    .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
    .replace(".attn_1.norm.", ".attentions.0.group_norm."): v
    for k, v in ckpt["state_dict"].items()
    if k.startswith("encoder.")
}
# the attention q/k/v/out projections are 1x1 convs in the original checkpoint;
# squeeze the spatial dims so they load as linear weights
for k in enc_dict.keys():
    if (
        k.startswith("encoder.mid_block.attentions.0")
        and k.endswith("weight")
        and ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
    ):
        enc_dict[k] = enc_dict[k][:, :, 0, 0]

# decoder weights: note the up-block order is reversed between the two layouts
dec_dict = {
    k.replace(".norm_out.", ".conv_norm_out.")
    .replace(".up.0.", ".up_blocks.3.")
    .replace(".up.1.", ".up_blocks.2.")
    .replace(".up.2.", ".up_blocks.1.")
    .replace(".up.3.", ".up_blocks.0.")
    .replace(".block.", ".resnets.")
    .replace("mid", "mid_block")
    .replace(".0.upsample.", ".0.upsamplers.0.")
    .replace(".1.upsample.", ".1.upsamplers.0.")
    .replace(".2.upsample.", ".2.upsamplers.0.")
    .replace(".nin_shortcut.", ".conv_shortcut.")
    .replace(".block_1.", ".resnets.0.")
    .replace(".block_2.", ".resnets.1.")
    .replace(".attn_1.k.", ".attentions.0.to_k.")
    .replace(".attn_1.q.", ".attentions.0.to_q.")
    .replace(".attn_1.v.", ".attentions.0.to_v.")
    .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
    .replace(".attn_1.norm.", ".attentions.0.group_norm."): v
    for k, v in ckpt["state_dict"].items()
    if (
        k.startswith("decoder.")
        and not k.startswith("decoder.up_layers.")
        and not k.startswith("decoder.encoder.")
    )
}
for k in dec_dict.keys():
    if (
        k.startswith("decoder.mid_block.attentions.0")
        and k.endswith("weight")
        and ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
    ):
        dec_dict[k] = dec_dict[k][:, :, 0, 0]

# condition encoder weights (the decoder-side branch that consumes the masked image)
cond_enc_dict = {
    k.replace("decoder.up_layers.", "decoder.condition_encoder.up_layers.")
    .replace("decoder.encoder.", "decoder.condition_encoder."): v
    for k, v in ckpt["state_dict"].items()
    if (
        k.startswith("decoder.up_layers.")
        or k.startswith("decoder.encoder.")
    )
}

# quant / post-quant conv weights keep their original names
quant_conv_dict = {k: v for k, v in ckpt["state_dict"].items() if k.startswith("quant_conv.")}
post_quant_conv_dict = {k: v for k, v in ckpt["state_dict"].items() if k.startswith("post_quant_conv.")}
```
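Before the strict load below, a quick diff of the converted keys against what the diffusers model expects can surface any missed rename. This is a minimal sanity-check sketch (assuming the prefixes above cover every VAE key in the checkpoint), reusing the dicts built above:

```python
# Diff converted keys vs. the keys the diffusers model expects, so a missed rename
# is visible before load_state_dict raises.
converted = {**quant_conv_dict, **post_quant_conv_dict, **enc_dict, **dec_dict, **cond_enc_dict}
expected = set(vae.state_dict().keys())
print("missing from conversion:", sorted(expected - set(converted)))
print("unexpected in conversion:", sorted(set(converted) - expected))
```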
```python
vae.load_state_dict({**quant_conv_dict, **post_quant_conv_dict, **enc_dict, **dec_dict, **cond_enc_dict})
```
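Once the state dict loads, a minimal usage sketch (not part of this PR; the output directory and base inpainting checkpoint below are assumptions) is to save the converted VAE and swap it into an inpainting pipeline, which is where the asymmetric decoder's conditioning on the masked image is meant to help:

```python
# Save the converted VAE and plug it into an inpainting pipeline.
# The output path and base checkpoint id are placeholders, not from this PR.
from diffusers import StableDiffusionInpaintPipeline

vae.save_pretrained("./asymmetric-vqgan-x1.5")  # hypothetical output directory

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",  # assumed SD 1.x inpainting base
    vae=vae,
)
```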