Rewrite AuraFlowPatchEmbed.pe_selection_index_based_on_dim to be torch.compile compatible by AstraliteHeart · Pull Request #11297 · huggingface/diffusers (original) (raw)

import torch import rich

class TestPosEmbed: def init(self, pos_embed_max_size: int = 9216, embed_dim: int = 768, device='cpu'): """ Initialize with a dummy positional embedding parameter. pos_embed_max_size must be a perfect square (here, 9216 yields a 96x96 grid). """ self.pos_embed = torch.randn(1, pos_embed_max_size, embed_dim, device=device)

def pe_selection_index_based_on_dim_orig(self, h_p: int, w_p: int) -> torch.Tensor:
    """
    Original implementation using torch.narrow.
    h_p: number of patches in height.
    w_p: number of patches in width.
    """
    total_pe = self.pos_embed.shape[1]
    grid_size = int(total_pe ** 0.5)
    assert grid_size * grid_size == total_pe, "pos_embed_max_size must be a perfect square"

    # Create a grid of indices.
    original_pe_indexes = torch.arange(total_pe, device=self.pos_embed.device).view(grid_size, grid_size)
    
    # Compute starting indices using Python arithmetic.
    starth = (grid_size - h_p) // 2
    startw = (grid_size - w_p) // 2
    
    # Use narrow to select the center region.
    narrowed = original_pe_indexes.narrow(0, starth, h_p)
    narrowed = narrowed.narrow(1, startw, w_p)
    return narrowed.flatten()

def pe_selection_index_based_on_dim_new(self, h_p: int, w_p: int) -> torch.Tensor:
    """
    New implementation using inlined slicing arithmetic.
    h_p: number of patches in height.
    w_p: number of patches in width.
    """
    total_pe = int(self.pos_embed.shape[1])
    grid_size = int(total_pe ** 0.5)
    assert grid_size * grid_size == total_pe, "pos_embed_max_size must be a perfect square"

    # Compute starting indices (using Python ints).
    start_h = (grid_size - h_p) // 2
    start_w = (grid_size - w_p) // 2

    # Create a grid of indices.
    pe_grid = torch.arange(total_pe, device=self.pos_embed.device).view(grid_size, grid_size)
    # Select the central region and flatten.
    selected_pe = pe_grid[start_h: start_h + h_p, start_w: start_w + w_p].flatten()
    return selected_pe

def run_tests(): torch.manual_seed(42)

patch_size = 16
# Use pos_embed_max_size = 9216 -> a 96x96 grid.
pos_embed_max_size = 9216
embed_dim = 768

tester = TestPosEmbed(pos_embed_max_size=pos_embed_max_size, embed_dim=embed_dim, device='cpu')

resolutions = [
    (224, 224),
    (224, 256),
    (256, 224),
    (256, 256),
    (256, 320),
    (320, 256),
    (384, 384),
    (384, 512),
    (512, 384),
    (512, 512),
    (640, 640),
    (768, 768),
    (1024, 1024),
    (1280, 1280),
    (1536, 1536),
    (1536, 1024),
    (1024, 1536),
    (1280, 768),
    (1536, 768),
    (768, 1536),
    (1536, 896),
    (896, 1536),
]

for (height, width) in resolutions:
    h_p = height // patch_size
    w_p = width // patch_size

    orig_indices = tester.pe_selection_index_based_on_dim_orig(h_p, w_p)
    new_indices = tester.pe_selection_index_based_on_dim_new(h_p, w_p)
    
    match = torch.equal(orig_indices, new_indices)
    
    match_color = "green" if match else "red"
    rich.print(f"[cyan]Resolution: {height} x {width}[/cyan] | [yellow]Patch grid: {h_p} x {w_p}[/yellow] | Match: [{match_color}]{match}[/{match_color}]")

if name == "main": run_tests()

Screenshot 2025-04-11 at 10 04 10 PM