[tests] Test attention backends by sayakpaul · Pull Request #12388 · huggingface/diffusers
Hi! I'm implementing a new attention backend, and in preparation for that I tried the unit tests from this PR. I was working in an environment with a torch version different from the nightly 2.10.0.dev20250924+cu128 indicated in the unit test file. In my environment the native backend produces results that diverge numerically from the expected values, as seen in the following pytest output snippet:
_____________________________________________________________________________________________________________________ test_forward_with_compile[native] ______________________________________________________________________________________________________________________
output = FluxPipelineOutput(images=tensor([[[[0.0391, 0.0391, 0.0410, ..., 0.2090, 0.2090, 0.2070],
[0.0391, 0.0586,...8],
[0.0879, 0.0801, 0.0801, ..., 0.2930, 0.2891, 0.3184]]]],
device='cuda:0', dtype=torch.bfloat16))
expected_slice = tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344,
0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066],
dtype=torch.bfloat16)
    def _check_if_slices_match(output, expected_slice):
        img = output.images.detach().cpu()
        generated_slice = img.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
>       assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
E assert False
E + where False = <built-in method allclose of type object at 0x7f4b31c218c0>(tensor([0.0391, 0.0391, 0.0410, 0.0488, 0.0449, 0.0566, 0.0586, 0.0566, 0.2422,\n 0.2539, 0.2656, 0.2871, 0.2969, 0.2930, 0.2891, 0.3184],\n dtype=torch.bfloat16), tensor(
[0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344,\n 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066],\n dtype=torch.bfloat16), atol=0.0001)
E + where <built-in method allclose of type object at 0x7f4b31c218c0> = torch.allclose
tests/others/test_attention_backends.py:103: AssertionError
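For context, reproducing the comparison with the slices copied verbatim from the failure output above shows the worst-case element-wise difference is roughly 1e-2, which is two orders of magnitude above the test's atol=1e-4 (a minimal sketch, not part of the test file):

```python
import torch

# Slices copied from the pytest failure output above.
generated_slice = torch.tensor(
    [0.0391, 0.0391, 0.0410, 0.0488, 0.0449, 0.0566, 0.0586, 0.0566,
     0.2422, 0.2539, 0.2656, 0.2871, 0.2969, 0.2930, 0.2891, 0.3184],
    dtype=torch.bfloat16,
)
expected_slice = torch.tensor(
    [0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605,
     0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066],
    dtype=torch.bfloat16,
)

# Largest element-wise deviation; this is what atol has to bound for allclose to pass.
print((generated_slice - expected_slice).abs().max())  # ~0.0117 on these values
```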
There were also differences in the eager-mode tests.
Is it expected that the values diverge between torch versions? Could there be a better way to test than comparing exact numerical values, if those values are expected to vary between versions?
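For example, if the expected slices are known to drift across torch versions, one alternative (just a sketch, not how the tests in this PR are written) could be to treat the native backend's output computed in the same environment as the reference, and compare the other backends against it with a slightly looser tolerance. Here `run_pipeline` is a hypothetical helper that runs the pipeline with a given attention backend and returns the flattened image tensor:

```python
import torch

def assert_backend_matches_native(run_pipeline, backend, atol=1e-3, rtol=1e-3):
    """Compare a backend's output slice against the native backend's output
    produced in the same environment, instead of against hard-coded values.

    `run_pipeline(backend)` is a hypothetical helper that runs the pipeline
    with the given attention backend and returns the flattened image tensor.
    """
    reference = run_pipeline("native")
    candidate = run_pipeline(backend)

    ref_slice = torch.cat([reference[:8], reference[-8:]])
    cand_slice = torch.cat([candidate[:8], candidate[-8:]])

    # torch.testing.assert_close reports the max absolute/relative error on
    # failure, which is more informative than a bare assert on torch.allclose.
    torch.testing.assert_close(cand_slice, ref_slice, atol=atol, rtol=rtol)
```

Something along these lines would keep the tests sensitive to backend-specific bugs while being tolerant of version-to-version changes in the underlying kernels.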