Merge pull request #8958 from MrCheeze/variations-model · Rewinged/stable-diffusion-webui@f1db987 (original) (raw)

`@@ -70,8 +70,13 @@ def before_sample(self, x, ts, cond, unconditional_conditioning):

`

70

70

``

71

71

`# Have to unwrap the inpainting conditioning here to perform pre-processing

`

72

72

`image_conditioning = None

`

``

73

`+

uc_image_conditioning = None

`

73

74

`if isinstance(cond, dict):

`

74

``

`-

image_conditioning = cond["c_concat"][0]

`

``

75

`+

if self.conditioning_key == "crossattn-adm":

`

``

76

`+

image_conditioning = cond["c_adm"]

`

``

77

`+

uc_image_conditioning = unconditional_conditioning["c_adm"]

`

``

78

`+

else:

`

``

79

`+

image_conditioning = cond["c_concat"][0]

`

75

80

`cond = cond["c_crossattn"][0]

`

76

81

`unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]

`

77

82

``

`@@ -98,8 +103,12 @@ def before_sample(self, x, ts, cond, unconditional_conditioning):

`

98

103

`# Wrap the image conditioning back up since the DDIM code can accept the dict directly.

`

99

104

`# Note that they need to be lists because it just concatenates them later.

`

100

105

`if image_conditioning is not None:

`

101

``

`-

cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}

`

102

``

`-

unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}

`

``

106

`+

if self.conditioning_key == "crossattn-adm":

`

``

107

`+

cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}

`

``

108

`+

unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}

`

``

109

`+

else:

`

``

110

`+

cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}

`

``

111

`+

unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}

`

103

112

``

104

113

`return x, ts, cond, unconditional_conditioning

`

105

114

``

`@@ -176,8 +185,12 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,

`

176

185

``

177

186

`# Wrap the conditioning models with additional image conditioning for inpainting model

`

178

187

`if image_conditioning is not None:

`

179

``

`-

conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}

`

180

``

`-

unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}

`

``

188

`+

if self.conditioning_key == "crossattn-adm":

`

``

189

`+

conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}

`

``

190

`+

unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}

`

``

191

`+

else:

`

``

192

`+

conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}

`

``

193

`+

unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}

`

181

194

``

182

195

`samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))

`

183

196

``

`@@ -195,8 +208,12 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima

`

195

208

`# Wrap the conditioning models with additional image conditioning for inpainting model

`

196

209

`# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape

`

197

210

`if image_conditioning is not None:

`

198

``

`-

conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}

`

199

``

`-

unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}

`

``

211

`+

if self.conditioning_key == "crossattn-adm":

`

``

212

`+

conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}

`

``

213

`+

unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}

`

``

214

`+

else:

`

``

215

`+

conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}

`

``

216

`+

unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}

`

200

217

``

201

218

`samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])

`

202

219

``