Only apply ztSNR related code if alphas_cumprod exists

This commit is contained in:
catboxanon 2023-11-29 18:33:32 -05:00
parent ffa7f8201d
commit de79597ab9
1 changed files with 9 additions and 8 deletions

View File

@ -882,15 +882,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
alphas_bar[-1] = 4.8973451890853435e-08 alphas_bar[-1] = 4.8973451890853435e-08
return alphas_bar return alphas_bar
p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device) if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
if opts.use_downcasted_alpha_bar: if opts.use_downcasted_alpha_bar:
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device) p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
if opts.sd_noise_schedule == "Zero Terminal SNR": if opts.sd_noise_schedule == "Zero Terminal SNR":
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
print("rescaling noise schedule for zero snr") print("rescaling noise schedule for zero snr")
p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device) p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)