Use real-valued instead of complex tensors in Wan2.1 RoPE by mjkvaak-amd · Pull Request #11649 · huggingface/diffusers

Avoids complex tensors in the Wan2.1 RoPE by using their real-valued cosine and sine components instead. This boosts the performance of compiled models (torch.compile with the inductor backend), where complex tensors are not supported.
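For context, a minimal sketch of the scenario this targets (the checkpoint id and loading details below are illustrative, not part of this PR):

```python
import torch
from diffusers import WanPipeline

# Illustrative: compile the Wan transformer with the inductor backend.
# With complex tensors inside RoPE, inductor cannot handle the rotary ops;
# the real-valued cos/sin formulation avoids that limitation.
pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")
pipe.transformer = torch.compile(pipe.transformer, backend="inductor")
```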

To verify that the proposed RoPE and utils behave identically to and as stably as the original, I ran a 100-step training of Wan2.1 (image-to-video) with both the proposed (orange) and the original (blue) implementations. In the attached plot the losses lie exactly on top of each other, but the hover tooltip confirms there are two identical curves.

The script below compares the two implementations. First, the proposed real-valued version:

```python
import torch
from torch import nn
from typing import Tuple

from diffusers.models.embeddings import get_1d_rotary_pos_embed


class WanRotaryPosEmbed(nn.Module):
    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int,
        theta: float = 10000.0,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim
        freqs_dtype = (
            torch.float32 if torch.backends.mps.is_available() else torch.float64
        )

        freqs_cos = []
        freqs_sin = []

        for dim in [t_dim, h_dim, w_dim]:
            freq_cos, freq_sin = get_1d_rotary_pos_embed(
                dim,
                max_seq_len,
                theta,
                use_real=True,
                repeat_interleave_real=True,
                freqs_dtype=freqs_dtype,
            )
            freqs_cos.append(freq_cos)
            freqs_sin.append(freq_sin)

        self.freqs_cos = torch.cat(freqs_cos, dim=1)
        self.freqs_sin = torch.cat(freqs_sin, dim=1)

    def forward(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        self.freqs_cos = self.freqs_cos.to(hidden_states.device)
        self.freqs_sin = self.freqs_sin.to(hidden_states.device)

        split_sizes = [
            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
            self.attention_head_dim // 3,
            self.attention_head_dim // 3,
        ]

        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(
            1, 1, ppf * pph * ppw, -1
        )
        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(
            1, 1, ppf * pph * ppw, -1
        )

        return freqs_cos, freqs_sin


def apply_rotary_emb(
    hidden_states: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
):
    dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
    # Pair up the last dimension: [x0, x1, x2, x3, ...] -> [(x0, x1), (x2, x3), ...]
    x = hidden_states.view(*hidden_states.shape[:-1], -1, 2).to(dtype)
    x1, x2 = x[..., 0], x[..., 1]
    # cos/sin are repeat-interleaved, so each pair shares the same angle
    cos = freqs_cos[..., 0::2]
    sin = freqs_sin[..., 1::2]
    out = torch.empty_like(hidden_states)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```
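Both utilities implement the same per-pair rotation; writing each interleaved pair $(x_1, x_2)$ as a complex number makes the equivalence with the original explicit:

$$
(x_1 + i x_2)(\cos\theta + i\sin\theta) = (x_1\cos\theta - x_2\sin\theta) + i\,(x_1\sin\theta + x_2\cos\theta),
$$

i.e. the real and imaginary parts are exactly the even- and odd-indexed outputs written by `apply_rotary_emb` above.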

For comparison, the original complex-valued implementation:

```python
class WanRotaryPosEmbedOriginal(nn.Module):
    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int,
        theta: float = 10000.0,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim

        freqs = []
        freqs_dtype = (
            torch.float32 if torch.backends.mps.is_available() else torch.float64
        )
        for dim in [t_dim, h_dim, w_dim]:
            freq = get_1d_rotary_pos_embed(
                dim,
                max_seq_len,
                theta,
                use_real=False,
                repeat_interleave_real=False,
                freqs_dtype=freqs_dtype,
            )
            freqs.append(freq)
        self.freqs = torch.cat(freqs, dim=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        freqs = self.freqs.to(hidden_states.device)
        freqs = freqs.split_with_sizes(
            [
                self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
                self.attention_head_dim // 6,
                self.attention_head_dim // 6,
            ],
            dim=1,
        )

        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(
            1, 1, ppf * pph * ppw, -1
        )
        return freqs


def apply_rotary_emb_original(hidden_states: torch.Tensor, freqs: torch.Tensor):
    dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
    x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
    x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
    return x_out.type_as(hidden_states)
```
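This complex path is what trips up inductor. A minimal repro sketch (exact behavior varies by PyTorch version; this snippet is an illustration, not part of the PR):

```python
import torch

def rotate_complex(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Same pattern as apply_rotary_emb_original, reduced to the complex ops.
    x_c = torch.view_as_complex(x.unflatten(-1, (-1, 2)))
    return torch.view_as_real(x_c * freqs).flatten(-2)

# Depending on the PyTorch version, compiling this with inductor either fails
# or graph-breaks / falls back to eager on the complex ops, losing fusion:
compiled = torch.compile(rotate_complex, backend="inductor", fullgraph=True)
```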

The equivalence tests used for verification:

```python
def test_rotary_pos_embed_value_equivalence():
    attention_head_dim = 12
    patch_size = (2, 2, 2)
    max_seq_len = 16
    batch, channels, frames, height, width = 1, attention_head_dim, 8, 8, 8
    hidden_states = torch.randn(batch, channels, frames, height, width)

    rope = WanRotaryPosEmbed(attention_head_dim, patch_size, max_seq_len)
    rope_orig = WanRotaryPosEmbedOriginal(attention_head_dim, patch_size, max_seq_len)

    # New returns (cos, sin), original returns complex
    cos, sin = rope(hidden_states)
    orig = rope_orig(hidden_states)  # shape: (1, 1, N, D/2), complex

    # Remove batch dims for comparison
    cos = cos.squeeze(0).squeeze(0)  # (N, D)
    sin = sin.squeeze(0).squeeze(0)  # (N, D)
    orig = orig.squeeze(0).squeeze(0)  # (N, D/2), complex
    cos_real = cos[:, 0::2]
    sin_real = sin[:, 1::2]

    # Reconstruct complex tensor
    recon = cos_real + 1j * sin_real

    # Compare real and imaginary parts
    assert torch.allclose(recon.real.float(), orig.real.float(), atol=1e-5)
    assert torch.allclose(recon.imag.float(), orig.imag.float(), atol=1e-5)


def test_rotary_emb_equivalence():
    attention_head_dim = 12
    patch_size = (2, 2, 2)
    max_seq_len = 16
    batch, channels, frames, height, width = 1, attention_head_dim, 8, 8, 8
    hidden_states = torch.randn(batch, channels, frames, height, width)

    rope = WanRotaryPosEmbed(attention_head_dim, patch_size, max_seq_len)
    rope_orig = WanRotaryPosEmbedOriginal(attention_head_dim, patch_size, max_seq_len)

    # Get rotary embeddings
    cos, sin = rope(hidden_states)
    freqs = rope_orig(hidden_states)

    # Prepare a fake attention input (B, H, N, D)
    B, H, N, D = cos.shape
    x = torch.randn(B, H, N, D, dtype=torch.float32)

    # Apply both rotary embeddings
    out_orig = apply_rotary_emb_original(x, freqs)
    out_real = apply_rotary_emb(x, cos, sin)

    # Check equivalence
    assert torch.allclose(
        out_real, out_orig, atol=1e-5
    ), "Real-valued rotary embedding does not match original complex version"
```