From 6d0a8dcd892f7ad9b399fed6edbad6ede13c5f69 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 29 Nov 2023 17:42:07 -0500 Subject: [PATCH] Implement zero terminal SNR schedule option --- modules/processing.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index ac58ef869..c88eec705 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -863,6 +863,34 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" + + def rescale_zero_terminal_snr_abar(alphas_cumprod): + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= (alphas_bar_sqrt_T) + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas_bar[-1] = 4.8973451890853435e-08 + return alphas_bar + + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) + + if opts.use_downcasted_alpha_bar: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) + if opts.sd_noise_schedule == "Zero Terminal SNR": + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + print("rescaling noise schedule for zero snr") + p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)