xy_grid: Refactor confirm functions

This commit is contained in:
Milly 2022-10-10 02:20:35 +09:00 committed by AUTOMATIC1111
parent 7dba1c07cb
commit 2fffd4bddc
1 changed files with 39 additions and 34 deletions

View File

@ -77,12 +77,26 @@ def apply_sampler(p, x, xs):
p.sampler_index = sampler_index p.sampler_index = sampler_index
def confirm_samplers(p, xs):
samplers_dict = build_samplers_dict(p)
for x in xs:
if x.lower() not in samplers_dict.keys():
raise RuntimeError(f"Unknown sampler: {x}")
def apply_checkpoint(p, x, xs): def apply_checkpoint(p, x, xs):
info = modules.sd_models.get_closet_checkpoint_match(x) info = modules.sd_models.get_closet_checkpoint_match(x)
assert info is not None, f'Checkpoint for {x} not found' if info is None:
raise RuntimeError(f"Unknown checkpoint: {x}")
modules.sd_models.reload_model_weights(shared.sd_model, info) modules.sd_models.reload_model_weights(shared.sd_model, info)
def confirm_checkpoints(p, xs):
for x in xs:
if modules.sd_models.get_closet_checkpoint_match(x) is None:
raise RuntimeError(f"Unknown checkpoint: {x}")
def apply_hypernetwork(p, x, xs): def apply_hypernetwork(p, x, xs):
if x.lower() in ["", "none"]: if x.lower() in ["", "none"]:
name = None name = None
@ -93,7 +107,7 @@ def apply_hypernetwork(p, x, xs):
hypernetwork.load_hypernetwork(name) hypernetwork.load_hypernetwork(name)
def confirm_hypernetworks(xs): def confirm_hypernetworks(p, xs):
for x in xs: for x in xs:
if x.lower() in ["", "none"]: if x.lower() in ["", "none"]:
continue continue
@ -135,29 +149,29 @@ def str_permutations(x):
return x return x
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"]) AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"]) AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
axis_options = [ axis_options = [
AxisOption("Nothing", str, do_nothing, format_nothing), AxisOption("Nothing", str, do_nothing, format_nothing, None),
AxisOption("Seed", int, apply_field("seed"), format_value_add_label), AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None),
AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label), AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None),
AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label), AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None),
AxisOption("Steps", int, apply_field("steps"), format_value_add_label), AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None),
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label), AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None),
AxisOption("Prompt S/R", str, apply_prompt, format_value), AxisOption("Prompt S/R", str, apply_prompt, format_value, None),
AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list), AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None),
AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value), AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value), AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label), AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
AxisOption("Eta", float, apply_field("eta"), format_value_add_label), AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
] ]
@ -283,19 +297,10 @@ class Script(scripts.Script):
valslist = list(permutations(valslist)) valslist = list(permutations(valslist))
valslist = [opt.type(x) for x in valslist] valslist = [opt.type(x) for x in valslist]
# Confirm options are valid before starting # Confirm options are valid before starting
if opt.label == "Sampler": if opt.confirm:
samplers_dict = build_samplers_dict(p) opt.confirm(p, valslist)
for sampler_val in valslist:
if sampler_val.lower() not in samplers_dict.keys():
raise RuntimeError(f"Unknown sampler: {sampler_val}")
elif opt.label == "Checkpoint name":
for ckpt_val in valslist:
if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None:
raise RuntimeError(f"Checkpoint for {ckpt_val} not found")
elif opt.label == "Hypernetwork":
confirm_hypernetworks(valslist)
return valslist return valslist