Refactor QwenEmbedRope to only use the LRU cache for RoPE caching 路 huggingface/diffusers@52cf252 (original) (raw)
`@@ -180,7 +180,6 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
`
180
180
` ],
`
181
181
`dim=1,
`
182
182
` )
`
183
``
`-
self.rope_cache = {}
`
184
183
``
185
184
`# DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
`
186
185
`self.scale_rope = scale_rope
`
`@@ -195,10 +194,20 @@ def rope_params(self, index, dim, theta=10000):
`
195
194
`freqs = torch.polar(torch.ones_like(freqs), freqs)
`
196
195
`return freqs
`
197
196
``
198
``
`-
def forward(self, video_fhw, txt_seq_lens, device):
`
``
197
`+
def forward(
`
``
198
`+
self,
`
``
199
`+
video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
`
``
200
`+
txt_seq_lens: List[int],
`
``
201
`+
device: torch.device,
`
``
202
`+
) -> Tuple[torch.Tensor, torch.Tensor]:
`
199
203
`"""
`
200
``
`-
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
`
201
``
`-
txt_length: [bs] a list of 1 integers representing the length of the text
`
``
204
`+
Args:
`
``
205
`` +
video_fhw (Tuple[int, int, int] or List[Tuple[int, int, int]]):
``
``
206
`+
A list of 3 integers [frame, height, width] representing the shape of the video.
`
``
207
`` +
txt_seq_lens (List[int]):
``
``
208
`+
A list of integers of length batch_size representing the length of each text prompt.
`
``
209
`` +
device: (torch.device):
``
``
210
`+
The device on which to perform the RoPE computation.
`
202
211
` """
`
203
212
`if self.pos_freqs.device != device:
`
204
213
`self.pos_freqs = self.pos_freqs.to(device)
`
`@@ -213,14 +222,8 @@ def forward(self, video_fhw, txt_seq_lens, device):
`
213
222
`max_vid_index = 0
`
214
223
`for idx, fhw in enumerate(video_fhw):
`
215
224
`frame, height, width = fhw
`
216
``
`-
rope_key = f"{idx}{height}{width}"
`
217
``
-
218
``
`-
if not torch.compiler.is_compiling():
`
219
``
`-
if rope_key not in self.rope_cache:
`
220
``
`-
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
`
221
``
`-
video_freq = self.rope_cache[rope_key]
`
222
``
`-
else:
`
223
``
`-
video_freq = self._compute_video_freqs(frame, height, width, idx)
`
``
225
`+
# RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
`
``
226
`+
video_freq = self._compute_video_freqs(frame, height, width, idx)
`
224
227
`video_freq = video_freq.to(device)
`
225
228
`vid_freqs.append(video_freq)
`
226
229
``
`@@ -235,8 +238,8 @@ def forward(self, video_fhw, txt_seq_lens, device):
`
235
238
``
236
239
`return vid_freqs, txt_freqs
`
237
240
``
238
``
`-
@functools.lru_cache(maxsize=None)
`
239
``
`-
def _compute_video_freqs(self, frame, height, width, idx=0):
`
``
241
`+
@functools.lru_cache(maxsize=128)
`
``
242
`+
def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
`
240
243
`seq_lens = frame * height * width
`
241
244
`freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
`
242
245
`freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
`