Rewrite AuraFlowPatchEmbed.pe_selection_index_based_on_dim to be torch.compile compatible by AstraliteHeart · Pull Request #11297 · huggingface/diffusers (original) (raw)
import rich
import torch
class TestPosEmbed:
    """Harness holding a dummy positional embedding and two index selectors.

    Both ``pe_selection_index_based_on_dim_*`` methods view the sequence axis
    of ``pos_embed`` as a square grid of indices and return, flattened in
    row-major order, the indices of an ``h_p x w_p`` window whose top-left
    corner is at ``((grid - h_p) // 2, (grid - w_p) // 2)``.
    """

    # Helper class, not a pytest test case despite the ``Test`` prefix.
    __test__ = False

    def __init__(self, pos_embed_max_size: int = 9216, embed_dim: int = 768, device="cpu"):
        """
        Initialize with a dummy positional embedding parameter.

        pos_embed_max_size must be a perfect square (the default, 9216,
        yields a 96x96 grid).

        Args:
            pos_embed_max_size: number of positions along the sequence axis.
            embed_dim: size of the last (feature) axis.
            device: device passed to ``torch.randn`` for the embedding.
        """
        # Shape is (1, pos_embed_max_size, embed_dim); only shape[1] and the
        # device are read by the selection methods below.
        self.pos_embed = torch.randn(1, pos_embed_max_size, embed_dim, device=device)

    def pe_selection_index_based_on_dim_orig(self, h_p: int, w_p: int) -> torch.Tensor:
        """
        Original implementation using torch.narrow.

        Args:
            h_p: number of patches in height.
            w_p: number of patches in width.

        Returns:
            1-D tensor of the selected grid indices, flattened row by row.

        Raises:
            AssertionError: if ``pos_embed.shape[1]`` is not a perfect square.
        """
        total_pe = self.pos_embed.shape[1]
        grid_size = int(total_pe ** 0.5)
        assert grid_size * grid_size == total_pe, "pos_embed_max_size must be a perfect square"

        # Create a grid of indices.
        original_pe_indexes = torch.arange(total_pe, device=self.pos_embed.device).view(grid_size, grid_size)

        # Compute starting indices using Python arithmetic.
        starth = (grid_size - h_p) // 2
        startw = (grid_size - w_p) // 2

        # Use narrow to select the center region: rows first, then columns.
        narrowed = original_pe_indexes.narrow(0, starth, h_p)
        narrowed = narrowed.narrow(1, startw, w_p)
        return narrowed.flatten()

    def pe_selection_index_based_on_dim_new(self, h_p: int, w_p: int) -> torch.Tensor:
        """
        New implementation using inlined slicing arithmetic.

        Args:
            h_p: number of patches in height.
            w_p: number of patches in width.

        Returns:
            1-D tensor of the selected grid indices, flattened row by row.

        Raises:
            AssertionError: if ``pos_embed.shape[1]`` is not a perfect square.
        """
        total_pe = int(self.pos_embed.shape[1])
        grid_size = int(total_pe ** 0.5)
        assert grid_size * grid_size == total_pe, "pos_embed_max_size must be a perfect square"

        # Compute starting indices (using Python ints).
        start_h = (grid_size - h_p) // 2
        start_w = (grid_size - w_p) // 2

        # Create a grid of indices.
        pe_grid = torch.arange(total_pe, device=self.pos_embed.device).view(grid_size, grid_size)

        # Select the central region and flatten.
        selected_pe = pe_grid[start_h: start_h + h_p, start_w: start_w + w_p].flatten()
        return selected_pe
def run_tests():
    """Print, per resolution, whether the two index selectors agree.

    Each (height, width) pair is converted to a patch grid with a patch size
    of 16, both ``TestPosEmbed`` selection methods are evaluated on it, and
    one colour-tagged line reporting the comparison is written via ``rich``.
    """
    torch.manual_seed(42)

    patch_size = 16
    # pos_embed_max_size = 9216 -> a 96x96 grid.
    tester = TestPosEmbed(pos_embed_max_size=9216, embed_dim=768, device='cpu')

    resolutions = (
        (224, 224), (224, 256), (256, 224), (256, 256),
        (256, 320), (320, 256),
        (384, 384), (384, 512), (512, 384), (512, 512),
        (640, 640), (768, 768),
        (1024, 1024), (1280, 1280), (1536, 1536),
        (1536, 1024), (1024, 1536),
        (1280, 768), (1536, 768), (768, 1536),
        (1536, 896), (896, 1536),
    )

    for height, width in resolutions:
        h_p, w_p = height // patch_size, width // patch_size

        # Reference (narrow-based) result is computed first, then the sliced one.
        match = torch.equal(
            tester.pe_selection_index_based_on_dim_orig(h_p, w_p),
            tester.pe_selection_index_based_on_dim_new(h_p, w_p),
        )

        if match:
            match_color = "green"
        else:
            match_color = "red"

        rich.print(f"[cyan]Resolution: {height} x {width}[/cyan] | [yellow]Patch grid: {h_p} x {w_p}[/yellow] | Match: [{match_color}]{match}[/{match_color}]")
# Run the comparison only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_tests()