this time for sure

This commit is contained in:
AUTOMATIC1111 2023-07-30 15:30:33 +03:00
parent a64fbe8928
commit cc53db6652
1 changed files with 13 additions and 3 deletions

View File

@ -538,8 +538,12 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x
class DecodedSamples(list):
already_decoded = True
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
samples = []
samples = DecodedSamples()
for i in range(batch.shape[0]):
sample = decode_first_stage(model, batch[i:i + 1])[0]
@ -793,7 +797,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim
else:
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@ -1161,9 +1169,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
self.is_hr_pass = False
return samples
return decoded_samples
def close(self):
super().close()