send noisy latent into refiner without adding noise

AUTOMATIC1111 2023-08-07 12:10:16 +03:00
parent 3f82820612
commit 686598387f
2 changed files with 19 additions and 16 deletions
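
In outline: before this commit, the refiner's img2img pass started from the base pass's denoised latent and re-noised it with fresh random tensors; after it, the sampler records its last noisy intermediate (d['x']) at the switch step, and the refiner continues directly from that tensor with zero added noise. A minimal sketch of the new handoff, with illustrative names standing in for the webui internals (not the actual API):

import torch

def refiner_handoff(noisy_latent: torch.Tensor, stopped_at: int, steps: int):
    # noisy_latent stands in for sampler.noisy_output, the raw d['x']
    # captured at the step where the base pass was interrupted
    x = torch.zeros_like(noisy_latent)  # no fresh noise is added
    # share of the schedule left for the refiner to denoise
    denoising_strength = 1.0 - (stopped_at + 1) / steps
    remaining_steps = max(1, steps - stopped_at - 1)
    return x, denoising_strength, remaining_steps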

modules/processing.py

@@ -384,11 +384,11 @@ class StableDiffusionProcessing:
         shared.state.nextjob()
 
         stopped_at = self.sampler.stop_at
+        noisy_output = self.sampler.noisy_output
         self.sampler = None
 
         a_is_sdxl = shared.sd_model.is_sdxl
-        decoded_samples = decode_latent_batch(shared.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+        decoded_noisy = decode_latent_batch(shared.sd_model, noisy_output, target_device=devices.cpu, check_for_nans=True)
 
         refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
         if refiner_checkpoint_info is None:
@@ -408,21 +408,21 @@ class StableDiffusionProcessing:
         b_is_sdxl = shared.sd_model.is_sdxl
         if a_is_sdxl != b_is_sdxl:
-            decoded_samples = torch.stack(decoded_samples).float()
-            decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
-            latent = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method), shared.sd_model)
+            decoded_noisy = torch.stack(decoded_noisy).float()
+            decoded_noisy = torch.clamp((decoded_noisy + 1.0) / 2.0, min=0.0, max=1.0)
+            noisy_latent = images_tensor_to_samples(decoded_noisy, approximation_indexes.get(opts.sd_vae_encode_method), shared.sd_model)
         else:
-            latent = samples
+            noisy_latent = noisy_output
 
-        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+        x = torch.zeros_like(noisy_latent)
 
         with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
             denoising_strength = self.denoising_strength
-            self.denoising_strength = 1.0 - stopped_at / self.steps
-            self.image_conditioning = txt2img_image_conditioning(shared.sd_model, latent, self.width, self.height)
+            self.denoising_strength = 1.0 - (stopped_at + 1) / self.steps
+            self.image_conditioning = txt2img_image_conditioning(shared.sd_model, noisy_latent, self.width, self.height)
             self.sampler = sd_samplers.create_sampler(self.sampler_name, shared.sd_model)
-            samples = self.sampler.sample_img2img(self, latent, x, self.c, self.uc, image_conditioning=self.image_conditioning, steps=max(1, self.steps - stopped_at - 1))
+            samples = self.sampler.sample_img2img(self, noisy_latent, x, self.c, self.uc, image_conditioning=self.image_conditioning, steps=max(1, self.steps - stopped_at - 1))
             self.denoising_strength = denoising_strength
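
As a quick check of the step accounting above, using illustrative values of 20 steps and a refiner switch at 0.8 (sd_refiner_switch_at appears in the process_images_inner hunk below):

steps = 20
switch_at = 0.8

stop_at = max(1, int(switch_at * steps - 1))      # 15
denoising_strength = 1.0 - (stop_at + 1) / steps  # 0.2
refiner_steps = max(1, steps - stop_at - 1)       # 4

So the refiner covers 4 of the 20 scheduled steps at denoising strength 0.2; the new (stopped_at + 1) numerator counts the step the base sampler had already completed at the stop point, which the old stopped_at / self.steps formula omitted.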
@@ -823,6 +823,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if state.interrupted:
                 break
 
+            sd_models.reload_model_weights()  # model can be changed for example by refiner
+
             p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
             p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
             p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
@@ -862,10 +864,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
-            if have_refiner:
-                p.sampler.stop_at = max(1, int(shared.opts.sd_refiner_switch_at * p.steps - 1))
-
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
+                p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
+
+                if have_refiner:
+                    p.sampler.stop_at = max(1, int(shared.opts.sd_refiner_switch_at * p.steps - 1))
+
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
             if opts.sd_vae_decode_method != 'Full':
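
The reordering in this hunk matters because stop_at is an attribute of the sampler object: sampler creation moves out of the init() methods (see the two hunks below) and into the per-batch loop, so the refiner switch point can only be armed after the fresh sampler exists. A reduced stand-alone illustration with dummy classes (not the actual webui types):

class DummySampler:
    def __init__(self):
        self.stop_at = None  # armed per batch, read by the step callback

def create_sampler():
    return DummySampler()  # a fresh object every batch

steps, switch_at, have_refiner = 20, 0.8, True

sampler = create_sampler()
if have_refiner:
    # must come after creation; arming the previous batch's discarded
    # sampler would have no effect here
    sampler.stop_at = max(1, int(switch_at * steps - 1))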
@@ -1056,8 +1060,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.hr_uc = None
 
     def init(self, all_prompts, all_seeds, all_subseeds):
-        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
-
         if self.enable_hr:
             if self.hr_checkpoint_name:
                 self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
@@ -1355,7 +1357,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         self.image_conditioning = None
 
     def init(self, all_prompts, all_seeds, all_subseeds):
-        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
 
         crop_region = None
         image_mask = self.image_mask

modules/sd_samplers_kdiffusion.py

@@ -276,6 +276,7 @@ class KDiffusionSampler:
         self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
         self.sampler_noises = None
         self.stop_at = None
+        self.noisy_output = None
         self.eta = None
         self.config = None  # set by the function calling the constructor
         self.last_latent = None
@@ -297,6 +298,7 @@ class KDiffusionSampler:
         if opts.live_preview_content == "Combined":
             sd_samplers_common.store_latent(latent)
         self.last_latent = latent
+        self.noisy_output = d['x']
 
         if self.stop_at is not None and step > self.stop_at:
             raise sd_samplers_common.InterruptedException
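
For context: k-diffusion invokes this callback once per sampler step with a dict that includes 'i' (the step index), 'x' (the current noisy latent), and 'denoised' (the model's clean prediction). The hunk above stores 'x' alongside the existing 'denoised' so the refiner can resume from the noisy state. A self-contained sketch of the pattern, simplified from the class above:

class InterruptedException(Exception):
    pass

class Sampler:
    def __init__(self):
        self.stop_at = None
        self.last_latent = None   # denoised prediction, used for live previews
        self.noisy_output = None  # raw noisy latent, handed to the refiner

    def callback_state(self, d):
        self.last_latent = d['denoised']
        self.noisy_output = d['x']
        if self.stop_at is not None and d['i'] > self.stop_at:
            raise InterruptedException  # ends the base pass at the switch point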