get_1d_sincos_pos_embed_from_grid uses float64, which is unavailable on MPS

Describe the bug

For some reason, the positional embeddings used by PixArt now require float64 (fp64). This pattern has been discouraged because Mac users on the MPS backend have no fp64 support at all; not even emulation is currently available.
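A minimal probe for the underlying limitation, independent of PixArt: the sketch below (the helper name `supports_float64` is hypothetical) only tries to allocate a float64 tensor on a given device, mirroring the float64 `torch.arange` call in the function quoted under Reproduction. On an MPS device this is expected to fail; since the exact exception type may differ between PyTorch versions, both `TypeError` and `RuntimeError` are caught.

```python
import torch


def supports_float64(device: str) -> bool:
    """Return True if a float64 tensor can be allocated on `device`.

    Hypothetical helper for illustration only; it mirrors the
    `torch.arange(..., dtype=torch.float64)` call used to build `omega`.
    """
    try:
        torch.arange(4, device=device, dtype=torch.float64)
        return True
    except (TypeError, RuntimeError):
        return False


if __name__ == "__main__":
    print("cpu:", supports_float64("cpu"))
    if torch.backends.mps.is_available():
        # Expected to print False on Apple Silicon, since MPS lacks float64.
        print("mps:", supports_float64("mps"))
```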

Reproduction

```python
import torch
from diffusers.utils import deprecate
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid_np


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
    """
    This function generates 1D positional embeddings from a grid.

    Args:
        embed_dim (`int`): The embedding dimension `D`
        pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`

    Returns:
        `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
    """
    if output_type == "np":
        deprecation_message = (
            "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
            " `from_numpy` is no longer required."
            " Pass `output_type='pt'` to use the new version now."
        )
        deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
        return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # This float64 allocation is what fails on MPS, which has no float64 support.
    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.outer(pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, embed_dim // 2 :], emb[:, : embed_dim // 2]], dim=1)

    return emb
```
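One possible workaround, sketched under assumptions and not taken from upstream `diffusers`: keep float64 wherever the device supports it and fall back to float32 only on MPS. The helper names `pick_pos_embed_dtype` and `sincos_pos_embed_1d_mps_safe` are hypothetical. The math matches the function above: `omega_i = 1 / 10000^(2i/D)` for `i` in `0..D/2-1`, every position is multiplied by every `omega_i`, and the sine and cosine halves are concatenated.

```python
import torch


def pick_pos_embed_dtype(device: torch.device) -> torch.dtype:
    # Hypothetical helper: MPS has no float64, so use float32 there only.
    return torch.float32 if device.type == "mps" else torch.float64


def sincos_pos_embed_1d_mps_safe(
    embed_dim: int, pos: torch.Tensor, flip_sin_to_cos: bool = False
) -> torch.Tensor:
    """Sinusoidal 1D positional embeddings of shape (M, embed_dim); MPS-safe sketch."""
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    dtype = pick_pos_embed_dtype(pos.device)

    # omega_i = 1 / 10000^(2i / D) for i in [0, D/2)
    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
    omega = omega / (embed_dim / 2.0)
    omega = 1.0 / 10000**omega  # (D/2,)

    # Cast positions so both torch.outer operands share the chosen dtype.
    pos = pos.reshape(-1).to(dtype)  # (M,)
    out = torch.outer(pos, omega)  # (M, D/2)

    emb = torch.cat([torch.sin(out), torch.cos(out)], dim=1)  # (M, D)
    if flip_sin_to_cos:
        # Put the cosine half first, matching the original flip_sin_to_cos behavior.
        emb = torch.cat([emb[:, embed_dim // 2 :], emb[:, : embed_dim // 2]], dim=1)
    return emb
```

On CPU and CUDA this still returns float64; on MPS it returns float32, the widest floating-point type that backend offers. Whether the float32 rounding difference has any visible effect on PixArt outputs is not verified here.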

Logs

System Info

diffusers latest

Who can help?

No response