fix: improve numerical stability by conditionally using float32 in Anima by kohya-ss · Pull Request #2302 · kohya-ss/sd-scripts

re-fix for #2293 and #2297


This pull request refactors how float32 precision is handled during forward passes in anima_models.py. Instead of inferring the precision inside each block, a use_fp32 flag is now determined once in the main forward method and then explicitly passed down through all relevant forward calls and custom forward wrappers. This makes the precision logic clearer and more consistent, especially when float16 inputs need upcasting for numerical stability.
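The pattern described above can be sketched as follows. This is a minimal, hypothetical illustration of "decide the flag once, pass it down explicitly", assuming PyTorch; the class and flag names are illustrative and not the actual sd-scripts code.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Illustrative sub-block: receives the precision flag instead of inferring it."""

    def forward(self, x, use_fp32):
        if use_fp32:
            # Run the numerically sensitive op in float32, then cast back.
            return torch.softmax(x.float(), dim=-1).to(x.dtype)
        return torch.softmax(x, dim=-1)


class Model(nn.Module):
    """Illustrative top-level model: determines use_fp32 once in forward."""

    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList(Block() for _ in range(2))

    def forward(self, x):
        # Decide once here, then thread the flag through every block's forward.
        use_fp32 = x.dtype == torch.float16
        for block in self.blocks:
            x = block(x, use_fp32)
        return x


model = Model()
out = model(torch.ones(2, 4, dtype=torch.float16))
```

Threading an explicit flag avoids each block re-deriving the precision from its (possibly already-cast) input, which is the inconsistency the refactor removes.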

Precision handling improvements:

Model forward pass changes: