GitHub - ChenyangSi/FreeU: FreeU: Free Lunch in Diffusion U-Net (CVPR2024 Oral)

FreeU: Free Lunch in Diffusion U-Net

S-Lab, Nanyang Technological University

Paper | Project Page | Video | Demo

CVPR2024 Oral



We propose FreeU, a method that substantially improves diffusion model sample quality at no cost: no training, no additional parameters, and no increase in memory or sampling time.

📖 For more visual results, check out our Project Page.

Usage

FreeU Code

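import torch
import torch.fft as fft
# Assumes the "ldm" codebase used by the CompVis / Stability-AI Stable Diffusion repos
from ldm.modules.diffusionmodules.openaimodel import UNetModel
from ldm.modules.diffusionmodules.util import timestep_embedding


def Fourier_filter(x, threshold, scale):
    # cuFFT supports half precision only for power-of-two sizes, so filter in float32
    dtype = x.dtype
    x = x.float()

    # FFT over the spatial dims, zero frequency shifted to the center
    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))

    # Scale the (2*threshold x 2*threshold) low-frequency window around the center
    B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W), device=x.device)

    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT back to the spatial domain
    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real

    return x_filtered.to(dtype)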

class Free_UNetModel(UNetModel):
    """
    :param b1: backbone factor of the first stage block of decoder.
    :param b2: backbone factor of the second stage block of decoder.
    :param s1: skip factor of the first stage block of decoder.
    :param s2: skip factor of the second stage block of decoder.
    """

    def __init__(
        self,
        b1,
        b2,
        s1,
        s2,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.b1 = b1
        self.b2 = b2
        self.s1 = s1
        self.s2 = s2

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            hs_ = hs.pop()

            # --------------- FreeU code -----------------------
            # Only operate on the first two decoder stages (1280 and 640 channels,
            # assuming a UNet with 320 base channels such as Stable Diffusion)
            if h.shape[1] == 1280:
                # Channel-wise mean of the backbone features, min-max normalized per sample
                hidden_mean = h.mean(1).unsqueeze(1)
                B = hidden_mean.shape[0]
                hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)

                # Scale the first half of the backbone channels by a factor between 1 and b1
                h[:, :640] = h[:, :640] * ((self.b1 - 1) * hidden_mean + 1)
                # Scale the low-frequency content of the skip features by s1
                hs_ = Fourier_filter(hs_, threshold=1, scale=self.s1)
            if h.shape[1] == 640:
                hidden_mean = h.mean(1).unsqueeze(1)
                B = hidden_mean.shape[0]
                hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
                hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)

                # Same operations for the second stage, with b2 and s2
                h[:, :320] = h[:, :320] * ((self.b2 - 1) * hidden_mean + 1)
                hs_ = Fourier_filter(hs_, threshold=1, scale=self.s2)
            # ---------------------------------------------------------

            h = torch.cat([h, hs_], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
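Written as equations, the FreeU lines in `forward` amount to the two operations below. The symbols ($\bar{h}$, $\alpha$, $\beta$, $\tau$, $c_r$, $c_c$) are introduced here only to restate the code: $\min$ and $\max$ are taken per sample over spatial positions, $b$ is `b1`/`b2` and $s$ is `s1`/`s2` for the 1280-/640-channel stage, and $\tau$ is `threshold` (1 in the code).

```latex
\begin{aligned}
\bar{h} &= \frac{1}{C}\sum_{c=1}^{C} h_{c},
\qquad
\alpha = (b - 1)\,\frac{\bar{h} - \min \bar{h}}{\max \bar{h} - \min \bar{h}} + 1,
\qquad
h_{1:C/2} \leftarrow \alpha \odot h_{1:C/2}, \\
h_{\text{skip}} &\leftarrow \operatorname{Re}\,\mathrm{IFFT}\Big(\mathrm{IFFTshift}\big(\mathrm{FFTshift}(\mathrm{FFT}(h_{\text{skip}})) \odot \beta\big)\Big), \\
\beta_{u,v} &=
\begin{cases}
s, & c_r - \tau \le u < c_r + \tau \ \text{and}\ c_c - \tau \le v < c_c + \tau, \\
1, & \text{otherwise,}
\end{cases}
\qquad
c_r = \lfloor H/2 \rfloor,\; c_c = \lfloor W/2 \rfloor.
\end{aligned}
```

With $b > 1$, the backbone channels are amplified most where the channel-mean activation is largest; with $s < 1$, the lowest-frequency content of the skip features is attenuated.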

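To see what the two FreeU operations do without loading a Stable Diffusion checkpoint, the CPU-only sketch below applies them to random tensors. `fourier_filter` is a device-agnostic copy of `Fourier_filter` above, and `scale_backbone` is a hypothetical helper (not part of the repository) that factors out the backbone-scaling lines of `forward` for an arbitrary channel count; the tensor shapes are arbitrary.

```python
import torch
import torch.fft as fft


def fourier_filter(x, threshold, scale):
    """Scale a (2*threshold x 2*threshold) low-frequency window of x's spatial spectrum."""
    x_freq = fft.fftshift(fft.fftn(x, dim=(-2, -1)), dim=(-2, -1))
    B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W), device=x.device)
    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
    x_freq = fft.ifftshift(x_freq * mask, dim=(-2, -1))
    return fft.ifftn(x_freq, dim=(-2, -1)).real


def scale_backbone(h, b):
    """Hypothetical helper: FreeU backbone scaling of the first half of the channels."""
    h = h.clone()  # avoid modifying the caller's tensor in place
    B = h.shape[0]
    hidden_mean = h.mean(1, keepdim=True)  # (B, 1, H, W) channel-wise mean
    flat = hidden_mean.view(B, -1)
    hmin = flat.min(dim=-1).values.view(B, 1, 1, 1)
    hmax = flat.max(dim=-1).values.view(B, 1, 1, 1)
    hidden_mean = (hidden_mean - hmin) / (hmax - hmin)  # min-max normalized to [0, 1]
    half = h.shape[1] // 2
    h[:, :half] = h[:, :half] * ((b - 1) * hidden_mean + 1)
    return h


torch.manual_seed(0)

# The zero-frequency (DC) bin lies inside the scaled window, so every
# per-channel spatial mean of the skip features is multiplied by `scale`.
skip = torch.randn(2, 4, 8, 8)
filtered = fourier_filter(skip, threshold=1, scale=0.2)
print(torch.allclose(filtered.mean(dim=(-2, -1)), 0.2 * skip.mean(dim=(-2, -1)), atol=1e-5))  # True

# Backbone scaling multiplies the first half of the channels by factors in [1, b].
backbone = torch.randn(2, 4, 8, 8)
scaled = scale_backbone(backbone, b=1.4)
ratio = scaled[:, :2] / backbone[:, :2]
print(ratio.min().item(), ratio.max().item())  # approximately 1.0 and 1.4
```

With `scale=1.0` the filter reduces to an FFT round trip and returns the input unchanged (up to floating-point error), which is a quick sanity check when porting the code to another UNet.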
Parameters

You can adjust these parameters to suit your model, image/video style, or task. The following settings are provided as starting points:

SD1.4: (will be updated soon)

b1: 1.3, b2: 1.4, s1: 0.9, s2: 0.2

SD1.5: (will be updated soon)

b1: 1.5, b2: 1.6, s1: 0.9, s2: 0.2

SD2.1

b1: 1.1, b2: 1.2, s1: 0.9, s2: 0.2

b1: 1.4, b2: 1.6, s1: 0.9, s2: 0.2

SDXL

b1: 1.3, b2: 1.4, s1: 0.9, s2: 0.2
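If you switch between base models, the recommended values above can be kept in a small lookup table. This is a hypothetical convenience sketch: `FREEU_PRESETS` and `freeu_params` are not part of the repository, and SD2.1 keeps both listed settings, selected by `variant`.

```python
# Recommended FreeU factors copied from the list above (SD2.1 lists two settings).
FREEU_PRESETS = {
    "SD1.4": [dict(b1=1.3, b2=1.4, s1=0.9, s2=0.2)],
    "SD1.5": [dict(b1=1.5, b2=1.6, s1=0.9, s2=0.2)],
    "SD2.1": [
        dict(b1=1.1, b2=1.2, s1=0.9, s2=0.2),
        dict(b1=1.4, b2=1.6, s1=0.9, s2=0.2),
    ],
    "SDXL": [dict(b1=1.3, b2=1.4, s1=0.9, s2=0.2)],
}


def freeu_params(model_name, variant=0):
    """Hypothetical helper: return a copy of the b1/b2/s1/s2 settings for a model."""
    return dict(FREEU_PRESETS[model_name][variant])


print(freeu_params("SD2.1", variant=1))  # {'b1': 1.4, 'b2': 1.6, 's1': 0.9, 's2': 0.2}
```

Since `Free_UNetModel` takes `b1`, `b2`, `s1`, `s2` before `*args, **kwargs`, the returned dict could be unpacked into its constructor, e.g. `Free_UNetModel(**freeu_params("SD1.5"), **unet_kwargs)`, where `unet_kwargs` stands for whatever arguments your existing `UNetModel` config already passes.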

Range for More Parameters

When trying additional parameters, consider the following ranges:

Results from the community

If you have tried FreeU and want to share your results, let us know and we can add a link here.

BibTeX

@inproceedings{si2023freeu,
  title={FreeU: Free Lunch in Diffusion U-Net},
  author={Si, Chenyang and Huang, Ziqi and Jiang, Yuming and Liu, Ziwei},
  booktitle={CVPR},
  year={2024}
}

🗞️ License

Distributed under the MIT License. See LICENSE for more information.