Merge pull request #8958 from MrCheeze/variations-model · Rewinged/stable-diffusion-webui@f1db987 (original) (raw)
`@@ -70,8 +70,13 @@ def before_sample(self, x, ts, cond, unconditional_conditioning):
`
70
70
``
71
71
`# Have to unwrap the inpainting conditioning here to perform pre-processing
`
72
72
`image_conditioning = None
`
``
73
`+
uc_image_conditioning = None
`
73
74
`if isinstance(cond, dict):
`
74
``
`-
image_conditioning = cond["c_concat"][0]
`
``
75
`+
if self.conditioning_key == "crossattn-adm":
`
``
76
`+
image_conditioning = cond["c_adm"]
`
``
77
`+
uc_image_conditioning = unconditional_conditioning["c_adm"]
`
``
78
`+
else:
`
``
79
`+
image_conditioning = cond["c_concat"][0]
`
75
80
`cond = cond["c_crossattn"][0]
`
76
81
`unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
`
77
82
``
`@@ -98,8 +103,12 @@ def before_sample(self, x, ts, cond, unconditional_conditioning):
`
98
103
`# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
`
99
104
`# Note that they need to be lists because it just concatenates them later.
`
100
105
`if image_conditioning is not None:
`
101
``
`-
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
`
102
``
`-
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
`
``
106
`+
if self.conditioning_key == "crossattn-adm":
`
``
107
`+
cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
`
``
108
`+
unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
`
``
109
`+
else:
`
``
110
`+
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
`
``
111
`+
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
`
103
112
``
104
113
`return x, ts, cond, unconditional_conditioning
`
105
114
``
`@@ -176,8 +185,12 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,
`
176
185
``
177
186
`# Wrap the conditioning models with additional image conditioning for inpainting model
`
178
187
`if image_conditioning is not None:
`
179
``
`-
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
`
180
``
`-
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
`
``
188
`+
if self.conditioning_key == "crossattn-adm":
`
``
189
`+
conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
`
``
190
`+
unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
`
``
191
`+
else:
`
``
192
`+
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
`
``
193
`+
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
`
181
194
``
182
195
`samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
`
183
196
``
`@@ -195,8 +208,12 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima
`
195
208
`# Wrap the conditioning models with additional image conditioning for inpainting model
`
196
209
`# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
`
197
210
`if image_conditioning is not None:
`
198
``
`-
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
`
199
``
`-
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
`
``
211
`+
if self.conditioning_key == "crossattn-adm":
`
``
212
`+
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
`
``
213
`+
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
`
``
214
`+
else:
`
``
215
`+
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
`
``
216
`+
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
`
200
217
``
201
218
`samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
`
202
219
``