put refiner into main UI, into the new accordions section
add VAE from main model into infotext, not from refiner model option to make scripts UI without gr.Group fix inconsistencies with refiner when usings samplers that do more denoising than steps
This commit is contained in:
parent
26c92f056a
commit
64311faa68
|
@ -373,9 +373,10 @@ class StableDiffusionProcessing:
|
|||
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
|
||||
|
||||
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
|
||||
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
|
||||
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
|
||||
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
|
||||
total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
|
||||
self.step_multiplier = total_steps // self.steps
|
||||
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
|
||||
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
|
||||
|
||||
def get_conds(self):
|
||||
return self.c, self.uc
|
||||
|
@ -579,8 +580,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||
"Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
|
||||
"VAE hash": sd_vae.get_loaded_vae_hash() if opts.add_model_hash_to_info else None,
|
||||
"VAE": sd_vae.get_loaded_vae_name() if opts.add_model_name_to_info else None,
|
||||
"VAE hash": p.loaded_vae_hash if opts.add_model_hash_to_info else None,
|
||||
"VAE": p.loaded_vae_name if opts.add_model_name_to_info else None,
|
||||
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
|
||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
|
@ -669,6 +670,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
if p.tiling is None:
|
||||
p.tiling = opts.tiling
|
||||
|
||||
p.loaded_vae_name = sd_vae.get_loaded_vae_name()
|
||||
p.loaded_vae_hash = sd_vae.get_loaded_vae_hash()
|
||||
|
||||
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
||||
modules.sd_hijack.model_hijack.clear_comments()
|
||||
|
||||
|
@ -1188,8 +1192,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||
hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
|
||||
hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
|
||||
|
||||
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
|
||||
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
|
||||
sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)
|
||||
steps = self.hr_second_pass_steps or self.steps
|
||||
total_steps = sampler_config.total_steps(steps) if sampler_config else steps
|
||||
|
||||
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, total_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
|
||||
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, total_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
|
||||
|
||||
def setup_conds(self):
|
||||
super().setup_conds()
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
import gradio as gr
|
||||
|
||||
from modules import scripts, sd_models
|
||||
from modules.ui_common import create_refresh_button
|
||||
from modules.ui_components import InputAccordion
|
||||
|
||||
|
||||
class ScriptRefiner(scripts.Script):
|
||||
section = "accordions"
|
||||
create_group = False
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def title(self):
|
||||
return "Refiner"
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
|
||||
with gr.Row():
|
||||
refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation")
|
||||
create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))
|
||||
|
||||
refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
|
||||
|
||||
def lookup_checkpoint(title):
|
||||
info = sd_models.get_closet_checkpoint_match(title)
|
||||
return None if info is None else info.title
|
||||
|
||||
self.infotext_fields = [
|
||||
(enable_refiner, lambda d: 'Refiner' in d),
|
||||
(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
|
||||
(refiner_switch_at, 'Refiner switch at'),
|
||||
]
|
||||
|
||||
return enable_refiner, refiner_checkpoint, refiner_switch_at
|
||||
|
||||
def before_process(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
|
||||
# the actual implementation is in sd_samplers_common.py, apply_refiner
|
||||
|
||||
p.refiner_checkpoint_info = None
|
||||
p.refiner_switch_at = None
|
||||
|
||||
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
|
||||
return
|
||||
|
||||
refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(refiner_checkpoint)
|
||||
if refiner_checkpoint_info is None:
|
||||
raise Exception(f'Could not find checkpoint with name {refiner_checkpoint}')
|
||||
|
||||
p.refiner_checkpoint_info = refiner_checkpoint_info
|
||||
p.refiner_switch_at = refiner_switch_at
|
|
@ -37,7 +37,10 @@ class Script:
|
|||
is_img2img = False
|
||||
|
||||
group = None
|
||||
"""A gr.Group component that has all script's UI inside it"""
|
||||
"""A gr.Group component that has all script's UI inside it."""
|
||||
|
||||
create_group = True
|
||||
"""If False, for alwayson scripts, a group component will not be created."""
|
||||
|
||||
infotext_fields = None
|
||||
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
||||
|
@ -232,6 +235,7 @@ class Script:
|
|||
"""
|
||||
pass
|
||||
|
||||
|
||||
current_basedir = paths.script_path
|
||||
|
||||
|
||||
|
@ -250,7 +254,7 @@ postprocessing_scripts_data = []
|
|||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||
|
||||
|
||||
def list_scripts(scriptdirname, extension):
|
||||
def list_scripts(scriptdirname, extension, *, include_extensions=True):
|
||||
scripts_list = []
|
||||
|
||||
basedir = os.path.join(paths.script_path, scriptdirname)
|
||||
|
@ -258,8 +262,9 @@ def list_scripts(scriptdirname, extension):
|
|||
for filename in sorted(os.listdir(basedir)):
|
||||
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
||||
|
||||
for ext in extensions.active():
|
||||
scripts_list += ext.list_files(scriptdirname, extension)
|
||||
if include_extensions:
|
||||
for ext in extensions.active():
|
||||
scripts_list += ext.list_files(scriptdirname, extension)
|
||||
|
||||
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||
|
||||
|
@ -288,7 +293,7 @@ def load_scripts():
|
|||
postprocessing_scripts_data.clear()
|
||||
script_callbacks.clear_callbacks()
|
||||
|
||||
scripts_list = list_scripts("scripts", ".py")
|
||||
scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False)
|
||||
|
||||
syspath = sys.path
|
||||
|
||||
|
@ -429,10 +434,13 @@ class ScriptRunner:
|
|||
if script.alwayson and script.section != section:
|
||||
continue
|
||||
|
||||
with gr.Group(visible=script.alwayson) as group:
|
||||
self.create_script_ui(script)
|
||||
if script.create_group:
|
||||
with gr.Group(visible=script.alwayson) as group:
|
||||
self.create_script_ui(script)
|
||||
|
||||
script.group = group
|
||||
script.group = group
|
||||
else:
|
||||
self.create_script_ui(script)
|
||||
|
||||
def prepare_ui(self):
|
||||
self.inputs = [None]
|
||||
|
|
|
@ -147,6 +147,9 @@ re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
|
|||
|
||||
|
||||
def get_closet_checkpoint_match(search_string):
|
||||
if not search_string:
|
||||
return None
|
||||
|
||||
checkpoint_info = checkpoint_aliases.get(search_string, None)
|
||||
if checkpoint_info is not None:
|
||||
return checkpoint_info
|
||||
|
|
|
@ -45,6 +45,11 @@ class CFGDenoiser(torch.nn.Module):
|
|||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.steps = None
|
||||
"""number of steps as specified by user in UI"""
|
||||
|
||||
self.total_steps = None
|
||||
"""expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
|
||||
|
||||
self.step = 0
|
||||
self.image_cfg_scale = None
|
||||
self.padded_cond_uncond = False
|
||||
|
@ -56,7 +61,6 @@ class CFGDenoiser(torch.nn.Module):
|
|||
def inner_model(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
|
|
@ -7,7 +7,16 @@ from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, s
|
|||
from modules.shared import opts, state
|
||||
import k_diffusion.sampling
|
||||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
|
||||
class SamplerData(SamplerDataTuple):
|
||||
def total_steps(self, steps):
|
||||
if self.options.get("second_order", False):
|
||||
steps = steps * 2
|
||||
|
||||
return steps
|
||||
|
||||
|
||||
def setup_img2img_steps(p, steps=None):
|
||||
|
@ -131,31 +140,26 @@ def replace_torchsde_browinan():
|
|||
replace_torchsde_browinan()
|
||||
|
||||
|
||||
def apply_refiner(sampler):
|
||||
completed_ratio = sampler.step / sampler.steps
|
||||
def apply_refiner(cfg_denoiser):
|
||||
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||
|
||||
if completed_ratio <= shared.opts.sd_refiner_switch_at:
|
||||
if refiner_switch_at is not None and completed_ratio <= refiner_switch_at:
|
||||
return False
|
||||
|
||||
if shared.opts.sd_refiner_checkpoint == "None":
|
||||
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
|
||||
return False
|
||||
|
||||
if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint:
|
||||
return False
|
||||
|
||||
refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
|
||||
if refiner_checkpoint_info is None:
|
||||
raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
|
||||
|
||||
sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
||||
sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at
|
||||
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
||||
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
|
||||
|
||||
with sd_models.SkipWritingToConfig():
|
||||
sd_models.reload_model_weights(info=refiner_checkpoint_info)
|
||||
|
||||
devices.torch_gc()
|
||||
sampler.p.setup_conds()
|
||||
sampler.update_inner_model()
|
||||
cfg_denoiser.p.setup_conds()
|
||||
cfg_denoiser.update_inner_model()
|
||||
|
||||
return True
|
||||
|
||||
|
@ -192,7 +196,7 @@ class Sampler:
|
|||
self.sampler_noises = None
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.config = None # set by the function calling the constructor
|
||||
self.config: SamplerData = None # set by the function calling the constructor
|
||||
self.last_latent = None
|
||||
self.s_min_uncond = None
|
||||
self.s_churn = 0.0
|
||||
|
@ -208,6 +212,7 @@ class Sampler:
|
|||
self.p = None
|
||||
self.model_wrap_cfg = None
|
||||
self.sampler_extra_args = None
|
||||
self.options = {}
|
||||
|
||||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
|
@ -220,6 +225,7 @@ class Sampler:
|
|||
|
||||
def launch_sampling(self, steps, func):
|
||||
self.model_wrap_cfg.steps = steps
|
||||
self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
|
|
|
@ -64,9 +64,10 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
|
|||
|
||||
|
||||
class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
def __init__(self, funcname, sd_model):
|
||||
def __init__(self, funcname, sd_model, options=None):
|
||||
super().__init__(funcname)
|
||||
|
||||
self.options = options or {}
|
||||
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
||||
|
||||
self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
|
||||
|
|
|
@ -69,8 +69,8 @@ def reload_hypernetworks():
|
|||
ui_reorder_categories_builtin_items = [
|
||||
"inpaint",
|
||||
"sampler",
|
||||
"accordions",
|
||||
"checkboxes",
|
||||
"hires_fix",
|
||||
"dimensions",
|
||||
"cfg",
|
||||
"seed",
|
||||
|
@ -86,7 +86,7 @@ def ui_reorder_categories():
|
|||
|
||||
sections = {}
|
||||
for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
|
||||
if isinstance(script.section, str):
|
||||
if isinstance(script.section, str) and script.section not in ui_reorder_categories_builtin_items:
|
||||
sections[script.section] = 1
|
||||
|
||||
yield from sections
|
||||
|
|
|
@ -140,8 +140,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
|
||||
"tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"),
|
||||
"sd_refiner_checkpoint": OptionInfo("None", "Refiner checkpoint", gr.Dropdown, lambda: {"choices": ["None"] + shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext="Refiner").info("switch to another model in the middle of generation"),
|
||||
"sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Refiner switch at').info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
||||
|
|
|
@ -438,35 +438,38 @@ def create_ui():
|
|||
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
||||
pass
|
||||
|
||||
elif category == "hires_fix":
|
||||
with InputAccordion(False, label="Hires. fix") as enable_hr:
|
||||
with enable_hr.extra():
|
||||
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0)
|
||||
elif category == "accordions":
|
||||
with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"):
|
||||
with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr:
|
||||
with enable_hr.extra():
|
||||
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0)
|
||||
|
||||
with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
|
||||
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
|
||||
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
|
||||
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
|
||||
with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
|
||||
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
|
||||
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
|
||||
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
|
||||
|
||||
with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
|
||||
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
|
||||
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
|
||||
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
|
||||
with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
|
||||
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
|
||||
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
|
||||
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
|
||||
|
||||
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
|
||||
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
|
||||
|
||||
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
||||
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
||||
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
||||
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
||||
|
||||
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
|
||||
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
|
||||
|
||||
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
|
||||
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
|
||||
|
||||
scripts.scripts_txt2img.setup_ui_for_section(category)
|
||||
|
||||
elif category == "batch":
|
||||
if not opts.dimensions_and_batch_together:
|
||||
|
@ -482,7 +485,7 @@ def create_ui():
|
|||
with FormGroup(elem_id="txt2img_script_container"):
|
||||
custom_inputs = scripts.scripts_txt2img.setup_ui()
|
||||
|
||||
else:
|
||||
if category not in {"accordions"}:
|
||||
scripts.scripts_txt2img.setup_ui_for_section(category)
|
||||
|
||||
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
|
||||
|
@ -794,6 +797,10 @@ def create_ui():
|
|||
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
||||
pass
|
||||
|
||||
elif category == "accordions":
|
||||
with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"):
|
||||
scripts.scripts_img2img.setup_ui_for_section(category)
|
||||
|
||||
elif category == "batch":
|
||||
if not opts.dimensions_and_batch_together:
|
||||
with FormRow(elem_id="img2img_column_batch"):
|
||||
|
@ -836,7 +843,8 @@ def create_ui():
|
|||
inputs=[],
|
||||
outputs=[inpaint_controls, mask_alpha],
|
||||
)
|
||||
else:
|
||||
|
||||
if category not in {"accordions"}:
|
||||
scripts.scripts_img2img.setup_ui_for_section(category)
|
||||
|
||||
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
|
||||
|
|
|
@ -87,13 +87,23 @@ class InputAccordion(gr.Checkbox):
|
|||
self.accordion_id = f"input-accordion-{InputAccordion.global_index}"
|
||||
InputAccordion.global_index += 1
|
||||
|
||||
kwargs['elem_id'] = self.accordion_id + "-checkbox"
|
||||
kwargs['visible'] = False
|
||||
super().__init__(value, **kwargs)
|
||||
kwargs_checkbox = {
|
||||
**kwargs,
|
||||
"elem_id": f"{self.accordion_id}-checkbox",
|
||||
"visible": False,
|
||||
}
|
||||
super().__init__(value, **kwargs_checkbox)
|
||||
|
||||
self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self])
|
||||
|
||||
self.accordion = gr.Accordion(kwargs.get('label', 'Accordion'), open=value, elem_id=self.accordion_id, elem_classes=['input-accordion'])
|
||||
kwargs_accordion = {
|
||||
**kwargs,
|
||||
"elem_id": self.accordion_id,
|
||||
"label": kwargs.get('label', 'Accordion'),
|
||||
"elem_classes": ['input-accordion'],
|
||||
"open": value,
|
||||
}
|
||||
self.accordion = gr.Accordion(**kwargs_accordion)
|
||||
|
||||
def extra(self):
|
||||
"""Allows you to put something into the label of the accordion.
|
||||
|
|
32
style.css
32
style.css
|
@ -166,16 +166,6 @@ a{
|
|||
color: var(--button-secondary-text-color-hover);
|
||||
}
|
||||
|
||||
.checkboxes-row{
|
||||
margin-bottom: 0.5em;
|
||||
margin-left: 0em;
|
||||
}
|
||||
.checkboxes-row > div{
|
||||
flex: 0;
|
||||
white-space: nowrap;
|
||||
min-width: auto !important;
|
||||
}
|
||||
|
||||
button.custom-button{
|
||||
border-radius: var(--button-large-radius);
|
||||
padding: var(--button-large-padding);
|
||||
|
@ -352,7 +342,7 @@ div.block.gradio-accordion {
|
|||
}
|
||||
|
||||
div.dimensions-tools{
|
||||
min-width: 0 !important;
|
||||
min-width: 1.6em !important;
|
||||
max-width: fit-content;
|
||||
flex-direction: column;
|
||||
place-content: center;
|
||||
|
@ -1012,10 +1002,28 @@ div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-chi
|
|||
}
|
||||
|
||||
div.block.input-accordion{
|
||||
margin-bottom: 0.4em;
|
||||
|
||||
}
|
||||
|
||||
.input-accordion-extra{
|
||||
flex: 0 0 auto !important;
|
||||
margin: 0 0.5em 0 auto;
|
||||
}
|
||||
|
||||
div.accordions > div.input-accordion{
|
||||
min-width: fit-content !important;
|
||||
}
|
||||
|
||||
div.accordions > div.gradio-accordion .label-wrap span{
|
||||
white-space: nowrap;
|
||||
margin-right: 0.25em;
|
||||
}
|
||||
|
||||
div.accordions{
|
||||
gap: 0.5em;
|
||||
}
|
||||
|
||||
div.accordions > div.input-accordion.input-accordion-open{
|
||||
flex: 1 auto;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue