Refactor QwenEmbedRope to only use the LRU cache for RoPE caching · huggingface/diffusers@52cf252

```diff
@@ -180,7 +180,6 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
             ],
             dim=1,
         )
-        self.rope_cache = {}
 
         # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
         self.scale_rope = scale_rope
```
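The surviving comment explains why `pos_freqs` stays a plain tensor attribute rather than a registered buffer: module-wide dtype casts also convert buffers, and casting a complex tensor to a real dtype silently drops the imaginary part. A minimal sketch of that pitfall (illustrative, not taken from the commit):

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        # A complex64 buffer, e.g. the output of torch.polar.
        self.register_buffer("freqs", torch.polar(torch.ones(4), torch.arange(4.0)))

m = M().to(torch.float16)  # a module-wide cast converts buffers too
print(m.freqs.dtype)       # torch.float16 -- the imaginary part is gone
```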

```diff
@@ -195,10 +194,20 @@ def rope_params(self, index, dim, theta=10000):
         freqs = torch.polar(torch.ones_like(freqs), freqs)
         return freqs
 
-    def forward(self, video_fhw, txt_seq_lens, device):
+    def forward(
+        self,
+        video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
+        txt_seq_lens: List[int],
+        device: torch.device,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
-            txt_length: [bs] a list of 1 integers representing the length of the text
+        Args:
+            video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
+                A list of 3 integers [frame, height, width] representing the shape of the video.
+            txt_seq_lens (`List[int]`):
+                A list of integers of length batch_size representing the length of each text prompt.
+            device: (`torch.device`):
+                The device on which to perform the RoPE computation.
         """
         if self.pos_freqs.device != device:
             self.pos_freqs = self.pos_freqs.to(device)
```
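With the annotated signature, a call reads as in the sketch below. The constructor arguments and sequence shapes are illustrative assumptions, not values taken from this diff:

```python
import torch

# Hypothetical values -- QwenEmbedRope's real constructor arguments are
# defined elsewhere in diffusers and are not shown in this commit.
rope = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)

video_fhw = [(1, 32, 32)]   # one latent image: 1 frame, 32x32 patch grid
txt_seq_lens = [77]         # one prompt of 77 text tokens
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vid_freqs, txt_freqs = rope(video_fhw, txt_seq_lens, device)
```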

```diff
@@ -213,14 +222,8 @@ def forward(self, video_fhw, txt_seq_lens, device):
         max_vid_index = 0
         for idx, fhw in enumerate(video_fhw):
             frame, height, width = fhw
-            rope_key = f"{idx}_{height}_{width}"
-
-            if not torch.compiler.is_compiling():
-                if rope_key not in self.rope_cache:
-                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
-                video_freq = self.rope_cache[rope_key]
-            else:
-                video_freq = self._compute_video_freqs(frame, height, width, idx)
+            # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
+            video_freq = self._compute_video_freqs(frame, height, width, idx)
             video_freq = video_freq.to(device)
             vid_freqs.append(video_freq)
 
```
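This hunk also drops the `torch.compiler.is_compiling()` special case: the call site no longer branches on compilation, since caching is now internal to `_compute_video_freqs`. Note that the hand-written key omitted `frame`, whereas `lru_cache` keys on the full argument tuple automatically, as the sketch below (not from the commit) demonstrates:

```python
import functools

# lru_cache derives its key from the complete argument tuple, so there is
# no hand-written key string to keep in sync with the function's inputs.
@functools.lru_cache(maxsize=128)
def freqs_for(frame: int, height: int, width: int, idx: int = 0) -> tuple:
    return (frame, height, width, idx)

assert freqs_for(1, 32, 32) is freqs_for(1, 32, 32)      # repeat call: cache hit
assert freqs_for(2, 32, 32) is not freqs_for(1, 32, 32)  # frame is now part of the key
```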

```diff
@@ -235,8 +238,8 @@ def forward(self, video_fhw, txt_seq_lens, device):
 
         return vid_freqs, txt_freqs
 
-    @functools.lru_cache(maxsize=None)
-    def _compute_video_freqs(self, frame, height, width, idx=0):
+    @functools.lru_cache(maxsize=128)
+    def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
         seq_lens = frame * height * width
         freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
         freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
```
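Two details of the new decorator are worth noting. First, `maxsize=None` would let the cache grow without bound, while `maxsize=128` caps it with LRU eviction. Second, as a general property of `functools.lru_cache` on an instance method (not something stated in the commit), `self` is part of the cache key, so cached entries hold a reference to the module until they are evicted; a bounded `maxsize` also limits how long those references live. A small sketch of the eviction behavior:

```python
import functools

@functools.lru_cache(maxsize=2)  # deliberately tiny to make eviction visible
def square(n: int) -> int:
    print(f"computing {n}")
    return n * n

square(1); square(2)  # two misses fill the cache
square(1)             # hit; entry 2 becomes least recently used
square(3)             # miss; evicts 2
square(2)             # prints "computing 2" again -- it was evicted
print(square.cache_info())  # CacheInfo(hits=1, misses=4, maxsize=2, currsize=2)
```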