[tests] add tests to check for graph breaks and recompilation in pipelines during torch.compile() by sayakpaul · Pull Request #11085 · huggingface/diffusers (original) (raw)

This is good. Just to reiterate - graph breaks and recompilation are different/orthogonal concepts.

if you want to ensure that your model has no graph breaks, fullgraph=True is enough.

if you want to ensure that your model does not recompile but has graph breaks, you can use

torch._dynamo.config.recompile_limit = 1
torch._dynamo.config.fail_on_recompile_limit_hit = True

If you want to ensure that your model has no graph breaks and no recompilations, you can use

model = torch.compile(model, fullgraph=True)
torch._dynamo.config.recompile_limit = 1

Here, fullgraph=True internally ensures that it raises an error if the total number of compilations exceed recompile_limit.

What you have in this PR is also fine. You want no graph break and no recompilations. So you are using fullgraph=True and

with torch._dynamo.config.patch(error_on_recompile=True):

This works too.