fix: improve numerical stability by conditionally using float32 in Anima by kohya-ss · Pull Request #2302 · kohya-ss/sd-scripts
This pull request refactors how float32 precision is handled during forward passes in the `anima_models.py` model code. Instead of inferring precision inside each block, a `use_fp32` flag is now determined once in the main forward method and then explicitly passed down through all relevant forward calls and custom forward wrappers. This makes the precision logic clearer and more consistent, particularly for float16 inputs, where upcasting to float32 preserves numerical stability.
Precision handling improvements:
- Added a `use_fp32` argument to the `forward`, `_forward`, and related methods in model blocks, allowing explicit control over whether float32 precision is used for computations. [1] [2] [3]
- Updated all calls to block forward methods and custom forward wrappers to pass the `use_fp32` flag, ensuring consistent precision handling throughout the model. [1] [2] [3] [4]
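To illustrate the block-side half of this pattern, here is a minimal hedged sketch (the class and layer names are illustrative, not the actual `anima_models.py` code): the block no longer infers precision itself, it receives `use_fp32` from the caller and upcasts only when told to.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical model block: the caller decides precision and passes
    use_fp32 explicitly instead of the block inferring it from its input."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, use_fp32: bool = False) -> torch.Tensor:
        orig_dtype = x.dtype
        if use_fp32:
            # Upcast for the numerically sensitive computation...
            x = x.float()
        h = self.proj(self.norm(x))
        # ...then cast back so the rest of the model sees the original dtype.
        return h.to(orig_dtype)
```

Because the flag is an explicit argument, a float16 activation can be routed through float32 math while the block's output dtype still matches its input dtype.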
Model forward pass changes:
- In `forward_mini_train_dit`, the `use_fp32` flag is now set once based on the input tensor's dtype and passed to all block and final layer calls, improving clarity and reducing duplicated logic.
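The decide-once-then-thread-through approach can be sketched as follows. This is a simplified, hypothetical stand-in (a toy model with plain linear layers, not the real `forward_mini_train_dit`): the flag is computed a single time from the input dtype and then passed to every block and to the final layer.

```python
import torch
import torch.nn as nn


class MiniDiT(nn.Module):
    """Toy stand-in for the pattern described above: compute use_fp32 once,
    then pass it explicitly to each block and the final layer."""

    def __init__(self, dim: int, depth: int):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        self.final_layer = nn.Linear(dim, dim)

    def _run(self, layer: nn.Module, x: torch.Tensor, use_fp32: bool) -> torch.Tensor:
        # Each call receives the flag explicitly instead of re-inferring it.
        if use_fp32:
            return layer(x.float()).to(x.dtype)
        return layer(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Decided once for the whole pass: upcast only for float16 inputs,
        # whose narrow dynamic range makes them prone to overflow/underflow.
        use_fp32 = x.dtype == torch.float16
        for block in self.blocks:
            x = self._run(block, x, use_fp32)
        return self._run(self.final_layer, x, use_fp32)
```

Centralizing the decision in `forward` means a dtype-policy change touches one line, and every block is guaranteed to apply the same policy within a single pass.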