Support lcm models. · comfyanonymous/ComfyUI@002aefa (original) (raw)
1
1
`import folder_paths
`
2
2
`import comfy.sd
`
3
3
`import comfy.model_sampling
`
``
4
`+
import torch
`
``
5
+
``
6
`+
class LCM(comfy.model_sampling.EPS):
`
``
7
`+
def calculate_denoised(self, sigma, model_output, model_input):
`
``
8
`+
timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
`
``
9
`+
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
`
``
10
`+
x0 = model_input - model_output * sigma
`
``
11
+
``
12
`+
sigma_data = 0.5
`
``
13
`+
scaled_timestep = timestep * 10.0 #timestep_scaling
`
``
14
+
``
15
`+
c_skip = sigma_data2 / (scaled_timestep2 + sigma_data**2)
`
``
16
`+
c_out = scaled_timestep / (scaled_timestep2 + sigma_data2) ** 0.5
`
``
17
+
``
18
`+
return c_out * x0 + c_skip * model_input
`
``
19
+
``
20
`+
class ModelSamplingDiscreteLCM(torch.nn.Module):
`
``
21
`+
def init(self):
`
``
22
`+
super().init()
`
``
23
`+
self.sigma_data = 1.0
`
``
24
`+
timesteps = 1000
`
``
25
`+
beta_start = 0.00085
`
``
26
`+
beta_end = 0.012
`
``
27
+
``
28
`+
betas = torch.linspace(beta_start0.5, beta_end0.5, timesteps, dtype=torch.float32) ** 2
`
``
29
`+
alphas = 1.0 - betas
`
``
30
`+
alphas_cumprod = torch.cumprod(alphas, dim=0)
`
``
31
+
``
32
`+
original_timesteps = 50
`
``
33
`+
self.skip_steps = timesteps // original_timesteps
`
``
34
+
``
35
+
``
36
`+
alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
`
``
37
`+
for x in range(original_timesteps):
`
``
38
`+
alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
`
``
39
+
``
40
`+
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
`
``
41
`+
self.set_sigmas(sigmas)
`
``
42
+
``
43
`+
def set_sigmas(self, sigmas):
`
``
44
`+
self.register_buffer('sigmas', sigmas)
`
``
45
`+
self.register_buffer('log_sigmas', sigmas.log())
`
``
46
+
``
47
`+
@property
`
``
48
`+
def sigma_min(self):
`
``
49
`+
return self.sigmas[0]
`
``
50
+
``
51
`+
@property
`
``
52
`+
def sigma_max(self):
`
``
53
`+
return self.sigmas[-1]
`
``
54
+
``
55
`+
def timestep(self, sigma):
`
``
56
`+
log_sigma = sigma.log()
`
``
57
`+
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
`
``
58
`+
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
`
``
59
+
``
60
`+
def sigma(self, timestep):
`
``
61
`+
t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
`
``
62
`+
low_idx = t.floor().long()
`
``
63
`+
high_idx = t.ceil().long()
`
``
64
`+
w = t.frac()
`
``
65
`+
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
`
``
66
`+
return log_sigma.exp()
`
``
67
+
``
68
`+
def percent_to_sigma(self, percent):
`
``
69
`+
return self.sigma(torch.tensor(percent * 999.0))
`
4
70
``
5
71
``
6
72
`def rescale_zero_terminal_snr_sigmas(sigmas):
`
`@@ -26,7 +92,7 @@ class ModelSamplingDiscrete:
`
26
92
`@classmethod
`
27
93
`def INPUT_TYPES(s):
`
28
94
`return {"required": { "model": ("MODEL",),
`
29
``
`-
"sampling": (["eps", "v_prediction"],),
`
``
95
`+
"sampling": (["eps", "v_prediction", "lcm"],),
`
30
96
`"zsnr": ("BOOLEAN", {"default": False}),
`
31
97
` }}
`
32
98
``
`@@ -38,17 +104,22 @@ def INPUT_TYPES(s):
`
38
104
`def patch(self, model, sampling, zsnr):
`
39
105
`m = model.clone()
`
40
106
``
``
107
`+
sampling_base = comfy.model_sampling.ModelSamplingDiscrete
`
41
108
`if sampling == "eps":
`
42
109
`sampling_type = comfy.model_sampling.EPS
`
43
110
`elif sampling == "v_prediction":
`
44
111
`sampling_type = comfy.model_sampling.V_PREDICTION
`
``
112
`+
elif sampling == "lcm":
`
``
113
`+
sampling_type = LCM
`
``
114
`+
sampling_base = ModelSamplingDiscreteLCM
`
45
115
``
46
``
`-
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, sampling_type):
`
``
116
`+
class ModelSamplingAdvanced(sampling_base, sampling_type):
`
47
117
`pass
`
48
118
``
49
119
`model_sampling = ModelSamplingAdvanced()
`
50
120
`if zsnr:
`
51
121
`model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
`
``
122
+
52
123
`m.add_object_patch("model_sampling", model_sampling)
`
53
124
`return (m, )
`
54
125
``