Support lcm models. · comfyanonymous/ComfyUI@002aefa (original) (raw)

1

1

`import folder_paths

`

2

2

`import comfy.sd

`

3

3

`import comfy.model_sampling

`

``

4

`+

import torch

`

``

5

+

``

6

`+

class LCM(comfy.model_sampling.EPS):

`

``

7

`+

def calculate_denoised(self, sigma, model_output, model_input):

`

``

8

`+

timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))

`

``

9

`+

sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))

`

``

10

`+

x0 = model_input - model_output * sigma

`

``

11

+

``

12

`+

sigma_data = 0.5

`

``

13

`+

scaled_timestep = timestep * 10.0 #timestep_scaling

`

``

14

+

``

15

`+

c_skip = sigma_data2 / (scaled_timestep2 + sigma_data**2)

`

``

16

`+

c_out = scaled_timestep / (scaled_timestep2 + sigma_data2) ** 0.5

`

``

17

+

``

18

`+

return c_out * x0 + c_skip * model_input

`

``

19

+

``

20

`+

class ModelSamplingDiscreteLCM(torch.nn.Module):

`

``

21

`+

def init(self):

`

``

22

`+

super().init()

`

``

23

`+

self.sigma_data = 1.0

`

``

24

`+

timesteps = 1000

`

``

25

`+

beta_start = 0.00085

`

``

26

`+

beta_end = 0.012

`

``

27

+

``

28

`+

betas = torch.linspace(beta_start0.5, beta_end0.5, timesteps, dtype=torch.float32) ** 2

`

``

29

`+

alphas = 1.0 - betas

`

``

30

`+

alphas_cumprod = torch.cumprod(alphas, dim=0)

`

``

31

+

``

32

`+

original_timesteps = 50

`

``

33

`+

self.skip_steps = timesteps // original_timesteps

`

``

34

+

``

35

+

``

36

`+

alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)

`

``

37

`+

for x in range(original_timesteps):

`

``

38

`+

alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]

`

``

39

+

``

40

`+

sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5

`

``

41

`+

self.set_sigmas(sigmas)

`

``

42

+

``

43

`+

def set_sigmas(self, sigmas):

`

``

44

`+

self.register_buffer('sigmas', sigmas)

`

``

45

`+

self.register_buffer('log_sigmas', sigmas.log())

`

``

46

+

``

47

`+

@property

`

``

48

`+

def sigma_min(self):

`

``

49

`+

return self.sigmas[0]

`

``

50

+

``

51

`+

@property

`

``

52

`+

def sigma_max(self):

`

``

53

`+

return self.sigmas[-1]

`

``

54

+

``

55

`+

def timestep(self, sigma):

`

``

56

`+

log_sigma = sigma.log()

`

``

57

`+

dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]

`

``

58

`+

return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)

`

``

59

+

``

60

`+

def sigma(self, timestep):

`

``

61

`+

t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))

`

``

62

`+

low_idx = t.floor().long()

`

``

63

`+

high_idx = t.ceil().long()

`

``

64

`+

w = t.frac()

`

``

65

`+

log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]

`

``

66

`+

return log_sigma.exp()

`

``

67

+

``

68

`+

def percent_to_sigma(self, percent):

`

``

69

`+

return self.sigma(torch.tensor(percent * 999.0))

`

4

70

``

5

71

``

6

72

`def rescale_zero_terminal_snr_sigmas(sigmas):

`

`@@ -26,7 +92,7 @@ class ModelSamplingDiscrete:

`

26

92

`@classmethod

`

27

93

`def INPUT_TYPES(s):

`

28

94

`return {"required": { "model": ("MODEL",),

`

29

``

`-

"sampling": (["eps", "v_prediction"],),

`

``

95

`+

"sampling": (["eps", "v_prediction", "lcm"],),

`

30

96

`"zsnr": ("BOOLEAN", {"default": False}),

`

31

97

` }}

`

32

98

``

`@@ -38,17 +104,22 @@ def INPUT_TYPES(s):

`

38

104

`def patch(self, model, sampling, zsnr):

`

39

105

`m = model.clone()

`

40

106

``

``

107

`+

sampling_base = comfy.model_sampling.ModelSamplingDiscrete

`

41

108

`if sampling == "eps":

`

42

109

`sampling_type = comfy.model_sampling.EPS

`

43

110

`elif sampling == "v_prediction":

`

44

111

`sampling_type = comfy.model_sampling.V_PREDICTION

`

``

112

`+

elif sampling == "lcm":

`

``

113

`+

sampling_type = LCM

`

``

114

`+

sampling_base = ModelSamplingDiscreteLCM

`

45

115

``

46

``

`-

class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, sampling_type):

`

``

116

`+

class ModelSamplingAdvanced(sampling_base, sampling_type):

`

47

117

`pass

`

48

118

``

49

119

`model_sampling = ModelSamplingAdvanced()

`

50

120

`if zsnr:

`

51

121

`model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))

`

``

122

+

52

123

`m.add_object_patch("model_sampling", model_sampling)

`

53

124

`return (m, )

`

54

125

``