UNetMotionModel.dtype is really expensive to call, is it possible to cache it during inference? · Issue #9520 · huggingface/diffusers
What API design would you like to have changed or added to the library? Why?
We are using class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin), and its forward() implementation calls self.dtype, which is very expensive.
From my profiling trace, each call to self.dtype takes 6-10 ms.
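To reproduce a per-call number like this in isolation, one can time the property directly instead of reading it off a full trace. The sketch below is torch-free: ToyModel and FakeParam are hypothetical stand-ins that mimic the tuple-based lookup discussed in this issue, not diffusers classes, so the absolute numbers will not match the 6-10 ms above. With a real model one would pass lambda: unet.dtype (unet being a hypothetical UNetMotionModel instance).

```python
import time
from dataclasses import dataclass


@dataclass
class FakeParam:
    # Hypothetical stand-in for a torch parameter; only the dtype attribute matters here.
    dtype: str


class ToyModel:
    """Torch-free stand-in that mimics a tuple-based dtype lookup."""

    def __init__(self, num_params: int, dtype: str = "float16"):
        self._params = [FakeParam(dtype) for _ in range(num_params)]

    def parameters(self):
        # Yield parameters lazily, the way nn.Module.parameters() does.
        yield from self._params

    @property
    def dtype(self) -> str:
        # Materializes every parameter just to read the first one's dtype.
        return tuple(self.parameters())[0].dtype


def time_per_call_ms(fn, n: int = 100) -> float:
    """Average wall-clock time of fn() in milliseconds over n calls."""
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - start) * 1000.0 / n


model = ToyModel(num_params=5000)
print(f"dtype lookup: {time_per_call_ms(lambda: model.dtype):.3f} ms per call")
```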
Can we somehow cache it to save time?
I took a look at the ModelMixin.dtype property: it collects all of the model's parameters into a tuple only to check the first parameter's dtype. I don't think it makes sense to do this every time, right?
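A minimal torch-free sketch of the difference, assuming the cost is dominated by materializing the tuple: CountingModel is a hypothetical stand-in that counts how many parameters each lookup pulls from the generator. The tuple version touches every parameter, while a lazy next() stops at the first one. On a real nn.Module the analogous one-liner would be next(model.parameters()).dtype, though it raises StopIteration on a model with no parameters, so a library version would need a fallback.

```python
from types import SimpleNamespace


class CountingModel:
    """Hypothetical stand-in that records how many parameters a lookup touches."""

    def __init__(self, num_params, dtype="float16"):
        self._params = [SimpleNamespace(dtype=dtype) for _ in range(num_params)]
        self.visited = 0

    def parameters(self):
        for p in self._params:
            self.visited += 1  # counts every parameter the caller actually pulls
            yield p


def dtype_via_tuple(model):
    # Pattern described in the issue: build a tuple of all parameters, keep only the first.
    return tuple(model.parameters())[0].dtype


def dtype_via_next(model):
    # Lazy alternative: stops after the first parameter
    # (raises StopIteration if the model has no parameters).
    return next(model.parameters()).dtype


m = CountingModel(num_params=10_000)
assert dtype_via_tuple(m) == "float16"
print("tuple():", m.visited)  # tuple(): 10000
m.visited = 0
assert dtype_via_next(m) == "float16"
print("next():", m.visited)  # next(): 1
```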
What use case would this enable or better enable? Can you give us a code example?
We are using this model for video generation, so inference runs repeatedly. Is it easy to optimize away this ~10 ms of latency?
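One possible shape for a cache, sketched under the assumption that every dtype change goes through to(): CachedDtypeModel and run_denoising_loop are hypothetical names, not diffusers APIs. The expensive scan runs once, later reads return the cached value, and to() clears the cache. In PyTorch, .to(), .half() and .float() all route through nn.Module._apply, so a real implementation would more likely invalidate there; dtypes changed by other means (for example, replacing parameters directly) would not be caught by this sketch.

```python
from types import SimpleNamespace


class CachedDtypeModel:
    """Hypothetical stand-in: caches the dtype and invalidates it when parameters are re-cast."""

    def __init__(self, num_params, dtype="float32"):
        self._params = [SimpleNamespace(dtype=dtype) for _ in range(num_params)]
        self._dtype_cache = None
        self.full_scans = 0  # how many times the expensive lookup actually ran

    def parameters(self):
        yield from self._params

    @property
    def dtype(self):
        if self._dtype_cache is None:
            self.full_scans += 1
            # Expensive path, only taken when the cache is empty.
            self._dtype_cache = tuple(self.parameters())[0].dtype
        return self._dtype_cache

    def to(self, dtype):
        # Any operation that changes parameter dtypes must drop the cache.
        for p in self._params:
            p.dtype = dtype
        self._dtype_cache = None
        return self


def run_denoising_loop(model, num_steps):
    """Hypothetical inference loop that reads model.dtype on every step, like forward() does."""
    seen = []
    for _ in range(num_steps):
        seen.append(model.dtype)
    return seen


model = CachedDtypeModel(num_params=10_000).to("float16")
dtypes = run_denoising_loop(model, num_steps=50)
print(set(dtypes), model.full_scans)  # {'float16'} 1
```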
Thanks!