hr conditioning
This commit is contained in:
parent
f774a8d24e
commit
a9f0e7d536
|
@ -235,7 +235,7 @@ class StableDiffusionProcessing:
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
def sample(self, conditioning, unconditional_conditioning, hr_conditioning, hr_uconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
@ -516,25 +516,25 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
else:
|
else:
|
||||||
p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
|
p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
|
||||||
|
|
||||||
# if type(p) == StableDiffusionProcessingTxt2Img:
|
if type(p) == StableDiffusionProcessingTxt2Img:
|
||||||
# if p.enable_hr and p.is_hr_pass:
|
if p.enable_hr and p.is_hr_pass:
|
||||||
# logging.info("Running hr pass with custom prompt")
|
logging.info("Running hr pass with custom prompt")
|
||||||
# if p.hr_prompt:
|
if p.hr_prompt:
|
||||||
# if type(p.prompt) == list:
|
if type(p.prompt) == list:
|
||||||
# p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.hr_prompt]
|
p.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.hr_prompt]
|
||||||
# else:
|
else:
|
||||||
# p.all_prompts = p.batch_size * p.n_iter * [
|
p.all_hr_prompts = p.batch_size * p.n_iter * [
|
||||||
# shared.prompt_styles.apply_styles_to_prompt(p.hr_prompt, p.styles)]
|
shared.prompt_styles.apply_styles_to_prompt(p.hr_prompt, p.styles)]
|
||||||
# logging.info(p.all_prompts)
|
logging.info(p.all_prompts)
|
||||||
#
|
|
||||||
# if p.hr_negative_prompt:
|
if p.hr_negative_prompt:
|
||||||
# if type(p.negative_prompt) == list:
|
if type(p.negative_prompt) == list:
|
||||||
# p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in
|
p.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in
|
||||||
# p.hr_negative_prompt]
|
p.hr_negative_prompt]
|
||||||
# else:
|
else:
|
||||||
# p.all_negative_prompts = p.batch_size * p.n_iter * [
|
p.all_hr_negative_prompts = p.batch_size * p.n_iter * [
|
||||||
# shared.prompt_styles.apply_negative_styles_to_prompt(p.hr_negative_prompt, p.styles)]
|
shared.prompt_styles.apply_negative_styles_to_prompt(p.hr_negative_prompt, p.styles)]
|
||||||
# logging.info(p.all_negative_prompts)
|
logging.info(p.all_negative_prompts)
|
||||||
|
|
||||||
if type(seed) == list:
|
if type(seed) == list:
|
||||||
p.all_seeds = seed
|
p.all_seeds = seed
|
||||||
|
@ -607,6 +607,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
|
|
||||||
prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
|
|
||||||
|
if type(p) == StableDiffusionProcessingTxt2Img:
|
||||||
|
if p.enable_hr:
|
||||||
|
hr_prompts = p.all_hr_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
|
hr_negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
|
|
||||||
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
||||||
|
|
||||||
|
@ -620,6 +626,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
|
|
||||||
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
|
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
|
||||||
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
|
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
|
||||||
|
if type(p) == StableDiffusionProcessingTxt2Img:
|
||||||
|
if p.enable_hr:
|
||||||
|
hr_uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps,
|
||||||
|
cached_uc)
|
||||||
|
hr_c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps,
|
||||||
|
cached_c)
|
||||||
|
|
||||||
if len(model_hijack.comments) > 0:
|
if len(model_hijack.comments) > 0:
|
||||||
for comment in model_hijack.comments:
|
for comment in model_hijack.comments:
|
||||||
|
@ -629,7 +641,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
|
if type(p) == StableDiffusionProcessingTxt2Img:
|
||||||
|
if p.enable_hr:
|
||||||
|
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, hr_conditioning=hr_c, hr_uconditional_conditioning=hr_uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
|
||||||
|
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds,
|
||||||
|
subseeds=subseeds,
|
||||||
|
subseed_strength=p.subseed_strength, prompts=prompts)
|
||||||
|
else:
|
||||||
|
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds,
|
||||||
|
subseeds=subseeds,
|
||||||
|
subseed_strength=p.subseed_strength, prompts=prompts)
|
||||||
|
|
||||||
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
|
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
|
||||||
for x in x_samples_ddim:
|
for x in x_samples_ddim:
|
||||||
|
@ -744,6 +765,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
self.hr_sampler = hr_sampler
|
self.hr_sampler = hr_sampler
|
||||||
self.hr_prompt = hr_prompt if hr_prompt != '' else self.prompt
|
self.hr_prompt = hr_prompt if hr_prompt != '' else self.prompt
|
||||||
self.hr_negative_prompt = hr_negative_prompt if hr_negative_prompt != '' else self.negative_prompt
|
self.hr_negative_prompt = hr_negative_prompt if hr_negative_prompt != '' else self.negative_prompt
|
||||||
|
self.all_hr_prompts = None
|
||||||
|
self.all_hr_negative_prompts = None
|
||||||
|
|
||||||
if firstphase_width != 0 or firstphase_height != 0:
|
if firstphase_width != 0 or firstphase_height != 0:
|
||||||
self.hr_upscale_to_x = self.width
|
self.hr_upscale_to_x = self.width
|
||||||
|
@ -817,7 +840,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
if self.hr_upscaler is not None:
|
if self.hr_upscaler is not None:
|
||||||
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
|
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
|
||||||
|
|
||||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
def sample(self, conditioning, unconditional_conditioning, hr_conditioning, hr_uconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||||
|
|
||||||
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
||||||
|
@ -830,9 +853,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
if not self.enable_hr:
|
if not self.enable_hr:
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
self.prompt = self.hr_prompt
|
|
||||||
self.negative_prompt = self.hr_negative_prompt
|
|
||||||
|
|
||||||
target_width = self.hr_upscale_to_x
|
target_width = self.hr_upscale_to_x
|
||||||
target_height = self.hr_upscale_to_y
|
target_height = self.hr_upscale_to_y
|
||||||
|
|
||||||
|
@ -904,7 +924,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
x = None
|
x = None
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
samples = self.sampler.sample_img2img(self, samples, noise, hr_conditioning, hr_unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue