allow first pass and hires pass to use a single prompt to do different prompt editing, hires is 1.0..2.0:
relative time range is [1..2] absolute time range is [steps+1..steps+hire_steps], e.g. with 30 steps and 20 hires steps, '20' is 2/3rds through first pass, and 40 is halfway through hires pass
This commit is contained in:
parent
5a38a9c0ee
commit
d1ba46b6e1
|
@ -319,12 +319,14 @@ class StableDiffusionProcessing:
|
||||||
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
|
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
|
||||||
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
|
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
|
||||||
|
|
||||||
def cached_params(self, required_prompts, steps, extra_network_data):
|
def cached_params(self, required_prompts, steps, hires_steps, extra_network_data, use_old_scheduling):
|
||||||
"""Returns parameters that invalidate the cond cache if changed"""
|
"""Returns parameters that invalidate the cond cache if changed"""
|
||||||
|
|
||||||
return (
|
return (
|
||||||
required_prompts,
|
required_prompts,
|
||||||
steps,
|
steps,
|
||||||
|
hires_steps,
|
||||||
|
use_old_scheduling,
|
||||||
opts.CLIP_stop_at_last_layers,
|
opts.CLIP_stop_at_last_layers,
|
||||||
shared.sd_model.sd_checkpoint_info,
|
shared.sd_model.sd_checkpoint_info,
|
||||||
extra_network_data,
|
extra_network_data,
|
||||||
|
@ -334,7 +336,7 @@ class StableDiffusionProcessing:
|
||||||
self.height,
|
self.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
|
def get_conds_with_caching(self, function, required_prompts, steps, hires_steps, caches, extra_network_data):
|
||||||
"""
|
"""
|
||||||
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
||||||
using a cache to store the result if the same arguments have been used before.
|
using a cache to store the result if the same arguments have been used before.
|
||||||
|
@ -347,7 +349,7 @@ class StableDiffusionProcessing:
|
||||||
caches is a list with items described above.
|
caches is a list with items described above.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cached_params = self.cached_params(required_prompts, steps, extra_network_data)
|
cached_params = self.cached_params(required_prompts, steps, hires_steps, extra_network_data, shared.opts.use_old_scheduling)
|
||||||
|
|
||||||
for cache in caches:
|
for cache in caches:
|
||||||
if cache[0] is not None and cached_params == cache[0]:
|
if cache[0] is not None and cached_params == cache[0]:
|
||||||
|
@ -356,7 +358,7 @@ class StableDiffusionProcessing:
|
||||||
cache = caches[0]
|
cache = caches[0]
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
cache[1] = function(shared.sd_model, required_prompts, steps)
|
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
|
||||||
|
|
||||||
cache[0] = cached_params
|
cache[0] = cached_params
|
||||||
return cache[1]
|
return cache[1]
|
||||||
|
@ -367,8 +369,9 @@ class StableDiffusionProcessing:
|
||||||
|
|
||||||
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
|
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
|
||||||
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
|
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
|
||||||
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
|
self.firstpass_steps = self.steps * self.step_multiplier
|
||||||
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
|
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.firstpass_steps, None, [self.cached_uc], self.extra_network_data)
|
||||||
|
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.firstpass_steps, None, [self.cached_c], self.extra_network_data)
|
||||||
|
|
||||||
def parse_extra_network_prompts(self):
|
def parse_extra_network_prompts(self):
|
||||||
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
|
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
|
||||||
|
@ -1225,8 +1228,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
|
hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
|
||||||
hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
|
hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
|
||||||
|
|
||||||
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
|
hires_steps = (self.hr_second_pass_steps or self.steps) * self.step_multiplier
|
||||||
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
|
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, hires_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
|
||||||
|
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, hires_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
|
||||||
|
|
||||||
def setup_conds(self):
|
def setup_conds(self):
|
||||||
super().setup_conds()
|
super().setup_conds()
|
||||||
|
|
|
@ -26,7 +26,7 @@ plain: /([^\\\[\]():|]|\\.)+/
|
||||||
%import common.SIGNED_NUMBER -> NUMBER
|
%import common.SIGNED_NUMBER -> NUMBER
|
||||||
""")
|
""")
|
||||||
|
|
||||||
def get_learned_conditioning_prompt_schedules(prompts, steps):
|
def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
|
||||||
"""
|
"""
|
||||||
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
|
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
|
||||||
>>> g("test")
|
>>> g("test")
|
||||||
|
@ -57,17 +57,38 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||||
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
|
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||||
>>> g("[fe|||]male")
|
>>> g("[fe|||]male")
|
||||||
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
|
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||||
|
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
|
||||||
|
>>> g("a [b:.5] c")
|
||||||
|
[[10, 'a b c']]
|
||||||
|
>>> g("a [b:1.5] c")
|
||||||
|
[[5, 'a c'], [10, 'a b c']]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if hires_steps is None or use_old_scheduling:
|
||||||
|
int_offset = 0
|
||||||
|
flt_offset = 0
|
||||||
|
steps = base_steps
|
||||||
|
else:
|
||||||
|
int_offset = base_steps
|
||||||
|
flt_offset = 1.0
|
||||||
|
steps = hires_steps
|
||||||
|
|
||||||
def collect_steps(steps, tree):
|
def collect_steps(steps, tree):
|
||||||
res = [steps]
|
res = [steps]
|
||||||
|
|
||||||
class CollectSteps(lark.Visitor):
|
class CollectSteps(lark.Visitor):
|
||||||
def scheduled(self, tree):
|
def scheduled(self, tree):
|
||||||
tree.children[-2] = float(tree.children[-2])
|
s = tree.children[-2]
|
||||||
if tree.children[-2] < 1:
|
v = float(s)
|
||||||
tree.children[-2] *= steps
|
if use_old_scheduling:
|
||||||
tree.children[-2] = min(steps, int(tree.children[-2]))
|
v = v*steps if v<1 else v
|
||||||
|
else:
|
||||||
|
if "." in s:
|
||||||
|
v = (v - flt_offset) * steps
|
||||||
|
else:
|
||||||
|
v = (v - int_offset)
|
||||||
|
tree.children[-2] = min(steps, int(v))
|
||||||
|
if tree.children[-2] >= 1:
|
||||||
res.append(tree.children[-2])
|
res.append(tree.children[-2])
|
||||||
|
|
||||||
def alternate(self, tree):
|
def alternate(self, tree):
|
||||||
|
@ -134,7 +155,7 @@ class SdConditioning(list):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
|
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
|
||||||
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
|
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
|
||||||
and the sampling step at which this condition is to be replaced by the next one.
|
and the sampling step at which this condition is to be replaced by the next one.
|
||||||
|
|
||||||
|
@ -154,7 +175,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
|
||||||
"""
|
"""
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
|
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
|
||||||
cache = {}
|
cache = {}
|
||||||
|
|
||||||
for prompt, prompt_schedule in zip(prompts, prompt_schedules):
|
for prompt, prompt_schedule in zip(prompts, prompt_schedules):
|
||||||
|
@ -229,7 +250,7 @@ class MulticondLearnedConditioning:
|
||||||
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
|
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
|
||||||
|
|
||||||
|
|
||||||
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
|
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
|
||||||
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
|
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
|
||||||
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|
||||||
|
|
||||||
|
@ -238,7 +259,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
|
||||||
|
|
||||||
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
|
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
|
||||||
|
|
||||||
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
|
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
for indexes in res_indexes:
|
for indexes in res_indexes:
|
||||||
|
|
|
@ -516,6 +516,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||||
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
|
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
|
||||||
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
||||||
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
|
||||||
|
"use_old_scheduling": OptionInfo(False, "Use old prompt where first pass and hires both used the same timeline, and < 1 meant relative and >= 1 meant absolute"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('interrogate', "Interrogate"), {
|
options_templates.update(options_section(('interrogate', "Interrogate"), {
|
||||||
|
|
Loading…
Reference in New Issue