[tests] add tests to check for graph breaks and recompilation in pipelines during torch.compile() by sayakpaul · Pull Request #11085 · huggingface/diffusers (original) (raw)
This is good. Just to reiterate - graph breaks and recompilation are different/orthogonal concepts.
if you want to ensure that your model has no graph breaks, fullgraph=True is enough.
if you want to ensure that your model does not recompile but has graph breaks, you can use
torch._dynamo.config.recompile_limit = 1
torch._dynamo.config.fail_on_recompile_limit_hit = True
If you want to ensure that your model has no graph breaks and no recompilations, you can use
model = torch.compile(model, fullgraph=True)
torch._dynamo.config.recompile_limit = 1
Here, fullgraph=True internally ensures that it raises an error if the total number of compilations exceed recompile_limit.
What you have in this PR is also fine. You want no graph break and no recompilations. So you are using fullgraph=True and
with torch._dynamo.config.patch(error_on_recompile=True):
This works too.