DDIM support returned for img2img
This commit is contained in:
parent
9427e4e290
commit
53e7616b51
79
webui.py
79
webui.py
|
@ -94,7 +94,7 @@ samplers = [
|
|||
SamplerData('DDIM', lambda: VanillaStableDiffusionSampler(DDIMSampler)),
|
||||
SamplerData('PLMS', lambda: VanillaStableDiffusionSampler(PLMSSampler)),
|
||||
]
|
||||
samplers_for_img2img = [x for x in samplers if x.name != 'DDIM' and x.name != 'PLMS']
|
||||
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
|
||||
|
||||
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
|
||||
|
||||
|
@ -835,9 +835,37 @@ class StableDiffusionProcessing:
|
|||
raise NotImplementedError()
|
||||
|
||||
|
||||
def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
|
||||
if sampler_wrapper.mask is not None:
|
||||
img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
|
||||
x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
|
||||
|
||||
return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
|
||||
|
||||
|
||||
class VanillaStableDiffusionSampler:
|
||||
def __init__(self, constructor):
|
||||
self.sampler = constructor(sd_model)
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_ddim
|
||||
self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
|
||||
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
|
||||
|
||||
self.sampler.make_schedule(ddim_num_steps=p.steps, ddim_eta=0.0, verbose=False)
|
||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(device), noise=noise)
|
||||
|
||||
self.mask = p.mask
|
||||
self.nmask = p.nmask
|
||||
self.init_latent = p.init_latent
|
||||
|
||||
samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
|
||||
samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
|
||||
|
@ -864,6 +892,27 @@ class KDiffusionSampler:
|
|||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
|
||||
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
|
||||
sigmas = self.model_wrap.get_sigmas(p.steps)
|
||||
noise = noise * sigmas[p.steps - t_enc - 1]
|
||||
|
||||
xi = x + noise
|
||||
|
||||
if p.mask is not None:
|
||||
if p.inpainting_fill == 2:
|
||||
xi = xi * p.mask + noise * p.nmask
|
||||
elif p.inpainting_fill == 3:
|
||||
xi = xi * p.mask
|
||||
|
||||
sigma_sched = sigmas[p.steps - t_enc - 1:]
|
||||
|
||||
def mask_cb(v):
|
||||
v["denoised"][:] = v["denoised"][:] * p.nmask + p.init_latent * p.mask
|
||||
|
||||
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=mask_cb if p.mask is not None else None)
|
||||
|
||||
|
||||
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
|
||||
sigmas = self.model_wrap.get_sigmas(p.steps)
|
||||
x = x * sigmas[0]
|
||||
|
@ -1246,39 +1295,20 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||
self.original_mask = self.original_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)).convert('L')
|
||||
|
||||
latmask = self.original_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
|
||||
latmask = np.moveaxis(np.array(latmask, dtype=np.float), 2, 0) / 255
|
||||
latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255
|
||||
latmask = latmask[0]
|
||||
latmask = np.tile(latmask[None], (4, 1, 1))
|
||||
|
||||
self.mask = torch.asarray(1.0 - latmask).to(device).type(sd_model.dtype)
|
||||
self.nmask = torch.asarray(latmask).to(device).type(sd_model.dtype)
|
||||
|
||||
|
||||
|
||||
def sample(self, x, conditioning, unconditional_conditioning):
|
||||
t_enc = int(min(self.denoising_strength, 0.999) * self.steps)
|
||||
|
||||
sigmas = self.sampler.model_wrap.get_sigmas(self.steps)
|
||||
noise = x * sigmas[self.steps - t_enc - 1]
|
||||
xi = self.init_latent + noise
|
||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
|
||||
|
||||
if self.mask is not None:
|
||||
if self.inpainting_fill == 2:
|
||||
xi = xi * self.mask + noise * self.nmask
|
||||
elif self.inpainting_fill == 3:
|
||||
xi = xi * self.mask
|
||||
samples = samples * self.nmask + self.init_latent * self.mask
|
||||
|
||||
sigma_sched = sigmas[self.steps - t_enc - 1:]
|
||||
|
||||
def mask_cb(v):
|
||||
v["denoised"][:] = v["denoised"][:] * self.nmask + self.init_latent * self.mask
|
||||
|
||||
samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False, callback=mask_cb if self.mask is not None else None)
|
||||
|
||||
if self.mask is not None:
|
||||
samples_ddim = samples_ddim * self.nmask + self.init_latent * self.mask
|
||||
|
||||
return samples_ddim
|
||||
return samples
|
||||
|
||||
|
||||
def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, prompt_matrix, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
|
||||
|
@ -1544,6 +1574,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
|
|||
if have_realesrgan and RealESRGAN_upscaling != 1.0:
|
||||
image = upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
|
||||
|
||||
os.makedirs(outpath, exist_ok=True)
|
||||
base_count = len(os.listdir(outpath))
|
||||
save_image(image, outpath, f"{base_count:05}", None, '', opts.samples_format, short_filename=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue