Use real-valued instead of complex tensors in Wan2.1 RoPE by mjkvaak-amd · Pull Request #11649 · huggingface/diffusers
Avoids complex tensors in the Wan2.1 RoPE by using the real-valued cosine and sine tables instead. This boosts the performance of compiled models (inductor), where complex tensors are not supported.
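The rewrite relies on the fact that a complex multiply by `cos t + i*sin t` is just a real-valued rotation of each channel pair. A quick self-contained check of that identity (illustration only, not part of the PR):

```python
import torch

# (x1 + i*x2) * (cos t + i*sin t)
#   == (x1*cos t - x2*sin t) + i*(x1*sin t + x2*cos t)
x1, x2, t = torch.randn(3, 8, dtype=torch.float64).unbind(0)
complex_out = torch.complex(x1, x2) * torch.polar(torch.ones_like(t), t)
real_out = torch.complex(
    x1 * torch.cos(t) - x2 * torch.sin(t),
    x1 * torch.sin(t) + x2 * torch.cos(t),
)
assert torch.allclose(complex_out, real_out)
```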
To verify that the proposed RoPE and utils behave identically to and as stably as the original, I ran a 100-step training of Wan2.1 (image-to-video) with both the proposed (orange) and the original (blue) implementations: the loss curves lie exactly on top of each other, and the hover tooltip confirms there are two identical curves rather than one. The standalone script below contains both implementations together with equivalence tests:
```python
import torch
from torch import nn
from typing import Tuple

from diffusers.models.embeddings import get_1d_rotary_pos_embed


class WanRotaryPosEmbed(nn.Module):
    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int,
        theta: float = 10000.0,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        # Split the head dim across the temporal (t), height (h), and width (w) axes.
        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim

        freqs_dtype = (
            torch.float32 if torch.backends.mps.is_available() else torch.float64
        )
        freqs_cos = []
        freqs_sin = []
        for dim in [t_dim, h_dim, w_dim]:
            freq_cos, freq_sin = get_1d_rotary_pos_embed(
                dim,
                max_seq_len,
                theta,
                use_real=True,
                repeat_interleave_real=True,
                freqs_dtype=freqs_dtype,
            )
            freqs_cos.append(freq_cos)
            freqs_sin.append(freq_sin)
        self.freqs_cos = torch.cat(freqs_cos, dim=1)
        self.freqs_sin = torch.cat(freqs_sin, dim=1)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        # Lazily move the cached tables to the input's device.
        self.freqs_cos = self.freqs_cos.to(hidden_states.device)
        self.freqs_sin = self.freqs_sin.to(hidden_states.device)

        split_sizes = [
            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
            self.attention_head_dim // 3,
            self.attention_head_dim // 3,
        ]

        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

        # Broadcast each axis' table over the full (frames, height, width) grid.
        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(
            1, 1, ppf * pph * ppw, -1
        )
        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(
            1, 1, ppf * pph * ppw, -1
        )
        return freqs_cos, freqs_sin


def apply_rotary_emb(
    hidden_states: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
):
    dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
    # Pair up even/odd feature channels: one (x1, x2) pair per rotation plane.
    x = hidden_states.view(*hidden_states.shape[:-1], -1, 2).to(dtype)
    x1, x2 = x[..., 0], x[..., 1]
    # The tables were built with repeat_interleave, so every other column is a duplicate.
    cos = freqs_cos[..., 0::2]
    sin = freqs_sin[..., 1::2]
    out = torch.empty_like(hidden_states)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```
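As a quick smoke test (a minimal sketch with illustrative shapes, assuming Wan2.1's usual head dim of 128 and a recent PyTorch), the real-valued path compiles with `fullgraph=True`, i.e. without graph breaks:

```python
# Smoke test: the real-valued rotary application should trace into a single graph.
rope = WanRotaryPosEmbed(attention_head_dim=128, patch_size=(1, 2, 2), max_seq_len=1024)
cos, sin = rope(torch.randn(1, 16, 4, 32, 32))  # (B, C, F, H, W) latent-shaped input
q = torch.randn(1, 8, cos.shape[2], 128)  # fake attention input (B, heads, N, D)
compiled_apply = torch.compile(apply_rotary_emb, fullgraph=True)
print(compiled_apply(q, cos, sin).shape)  # torch.Size([1, 8, 1024, 128])
```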
For reference, the original complex-valued implementation it is compared against:

```python
class WanRotaryPosEmbedOriginal(nn.Module):
    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int,
        theta: float = 10000.0,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim

        freqs = []
        freqs_dtype = (
            torch.float32 if torch.backends.mps.is_available() else torch.float64
        )
        for dim in [t_dim, h_dim, w_dim]:
            freq = get_1d_rotary_pos_embed(
                dim,
                max_seq_len,
                theta,
                use_real=False,
                repeat_interleave_real=False,
                freqs_dtype=freqs_dtype,
            )
            freqs.append(freq)
        self.freqs = torch.cat(freqs, dim=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        freqs = self.freqs.to(hidden_states.device)
        # The complex tables have half the width of the real-valued ones.
        freqs = freqs.split_with_sizes(
            [
                self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
                self.attention_head_dim // 6,
                self.attention_head_dim // 6,
            ],
            dim=1,
        )

        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(
            1, 1, ppf * pph * ppw, -1
        )
        return freqs


def apply_rotary_emb_original(hidden_states: torch.Tensor, freqs: torch.Tensor):
    dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
    # Rotate by multiplying with the complex exponentials: requires view_as_complex.
    x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
    x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
    return x_out.type_as(hidden_states)
```
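For contrast, one way to see what the complex path costs under compilation is `torch._dynamo.explain` (PyTorch >= 2.1 API; the exact outcome depends on the torch build, as inductor either breaks the graph or routes complex ops to eager fallbacks):

```python
import torch._dynamo

rope_orig = WanRotaryPosEmbedOriginal(attention_head_dim=128, patch_size=(1, 2, 2), max_seq_len=1024)
freqs = rope_orig(torch.randn(1, 16, 4, 32, 32))
q = torch.randn(1, 8, freqs.shape[2], 128)
explanation = torch._dynamo.explain(apply_rotary_emb_original)(q, freqs)
print(explanation.graph_break_count, explanation.break_reasons)
```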
The equivalence tests:

```python
def test_rotary_pos_embed_value_equivalence():
    attention_head_dim = 12
    patch_size = (2, 2, 2)
    max_seq_len = 16
    batch, channels, frames, height, width = 1, attention_head_dim, 8, 8, 8
    hidden_states = torch.randn(batch, channels, frames, height, width)

    rope = WanRotaryPosEmbed(attention_head_dim, patch_size, max_seq_len)
    rope_orig = WanRotaryPosEmbedOriginal(attention_head_dim, patch_size, max_seq_len)

    # New returns (cos, sin), original returns complex
    cos, sin = rope(hidden_states)
    orig = rope_orig(hidden_states)  # shape: (1, 1, N, D/2), complex

    # Remove batch dims for comparison
    cos = cos.squeeze(0).squeeze(0)  # (N, D)
    sin = sin.squeeze(0).squeeze(0)  # (N, D)
    orig = orig.squeeze(0).squeeze(0)  # (N, D/2), complex

    # Drop the repeat-interleaved duplicate columns
    cos_real = cos[:, 0::2]
    sin_real = sin[:, 1::2]

    # Reconstruct complex tensor
    recon = cos_real + 1j * sin_real

    # Compare real and imaginary parts
    assert torch.allclose(recon.real.float(), orig.real.float(), atol=1e-5)
    assert torch.allclose(recon.imag.float(), orig.imag.float(), atol=1e-5)


def test_rotary_emb_equivalence():
    attention_head_dim = 12
    patch_size = (2, 2, 2)
    max_seq_len = 16
    batch, channels, frames, height, width = 1, attention_head_dim, 8, 8, 8
    hidden_states = torch.randn(batch, channels, frames, height, width)

    rope = WanRotaryPosEmbed(attention_head_dim, patch_size, max_seq_len)
    rope_orig = WanRotaryPosEmbedOriginal(attention_head_dim, patch_size, max_seq_len)

    # Get rotary embeddings
    cos, sin = rope(hidden_states)
    freqs = rope_orig(hidden_states)

    # Prepare a fake attention input (B, H, N, D)
    B, H, N, D = cos.shape
    x = torch.randn(B, H, N, D, dtype=torch.float32)

    # Apply both rotary embeddings
    out_orig = apply_rotary_emb_original(x, freqs)
    out_real = apply_rotary_emb(x, cos, sin)

    # Check equivalence
    assert torch.allclose(
        out_real, out_orig, atol=1e-5
    ), "Real-valued rotary embedding does not match original complex version"
```
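The checks are picked up by pytest via their names, or can be run as a plain script by appending:

```python
if __name__ == "__main__":
    test_rotary_pos_embed_value_equivalence()
    test_rotary_emb_equivalence()
    print("Both equivalence checks passed.")
```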