Have upscale button use the same seed as hires fix.

This commit is contained in:
AUTOMATIC1111 2024-01-04 19:47:00 +03:00
parent f903b4dda3
commit 15ec54dd96
5 changed files with 53 additions and 20 deletions

View File

@ -91,6 +91,9 @@ class Script:
setup_for_ui_only = False
"""If true, the script setup will only be run in Gradio UI, not in API"""
controls = None
"""A list of controls retured by the ui()."""
def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@ -624,6 +627,7 @@ class ScriptRunner:
import modules.api.models as api_models
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
script.controls = controls
if controls is None:
return
@ -918,6 +922,23 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running setup: {script.filename}", exc_info=True)
def set_named_arg(self, args, script_type, arg_elem_id, value):
script = next((x for x in self.scripts if type(x).__name__ == script_type), None)
if script is None:
return
for i, control in enumerate(script.controls):
if arg_elem_id in control.elem_id:
index = script.args_from + i
if isinstance(args, list):
args[index] = value
return args
elif isinstance(args, tuple):
return args[:index] + (value,) + args[index+1:]
else:
return None
scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None

View File

@ -1,3 +1,4 @@
import json
from contextlib import closing
import modules.scripts
@ -9,12 +10,19 @@ from modules.ui import plaintext_to_html
import gradio as gr
def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, *args):
def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):
assert len(gallery) > 0, 'No image to upscale'
assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'
geninfo = json.loads(generation_info)
all_seeds = geninfo["all_seeds"]
image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0]
image = infotext_utils.image_from_url_text(image_info)
gallery_index_from_end = len(gallery) - gallery_index
image.seed = all_seeds[-gallery_index_from_end if gallery_index_from_end < len(all_seeds) + 1 else 0]
return txt2img(id_task, request, *args, firstpass_image=image)
@ -22,6 +30,10 @@ def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str
override_settings = create_override_settings_dict(override_settings_texts)
if firstpass_image is not None:
seed = getattr(firstpass_image, 'seed', None)
if seed:
args = modules.scripts.scripts_txt2img.set_named_arg(args, 'ScriptSeed', 'seed', seed)
enable_hr = True
batch_size = 1
n_iter = 1

View File

@ -405,8 +405,8 @@ def create_ui():
txt2img_outputs = [
output_panel.gallery,
output_panel.generation_info,
output_panel.infotext,
output_panel.html_info,
output_panel.html_log,
]
@ -424,7 +424,7 @@ def create_ui():
output_panel.button_upscale.click(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']),
_js="submit_txt2img_upscale",
inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component] + txt2img_inputs[1:],
inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component, output_panel.generation_info] + txt2img_inputs[1:],
outputs=txt2img_outputs,
show_progress=False,
)
@ -437,8 +437,8 @@ def create_ui():
inputs=[dummy_component],
outputs=[
output_panel.gallery,
output_panel.generation_info,
output_panel.infotext,
output_panel.html_info,
output_panel.html_log,
],
show_progress=False,
@ -766,8 +766,8 @@ def create_ui():
] + custom_inputs,
outputs=[
output_panel.gallery,
output_panel.generation_info,
output_panel.infotext,
output_panel.html_info,
output_panel.html_log,
],
show_progress=False,
@ -807,8 +807,8 @@ def create_ui():
inputs=[dummy_component],
outputs=[
output_panel.gallery,
output_panel.generation_info,
output_panel.infotext,
output_panel.html_info,
output_panel.html_log,
],
show_progress=False,

View File

@ -108,8 +108,8 @@ def save_files(js_data, images, do_make_zip, index):
@dataclasses.dataclass
class OutputPanel:
gallery = None
generation_info = None
infotext = None
html_info = None
html_log = None
button_upscale = None
@ -175,17 +175,17 @@ Requested path was: {f}
download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
with gr.Group():
res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
res.infotext = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
res.generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
if tabname == 'txt2img' or tabname == 'img2img':
generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
generation_info_button.click(
fn=update_generation_info,
_js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
inputs=[res.infotext, res.html_info, res.html_info],
outputs=[res.html_info, res.html_info],
inputs=[res.generation_info, res.infotext, res.infotext],
outputs=[res.infotext, res.infotext],
show_progress=False,
)
@ -193,10 +193,10 @@ Requested path was: {f}
fn=call_queue.wrap_gradio_call(save_files),
_js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
inputs=[
res.infotext,
res.generation_info,
res.gallery,
res.html_info,
res.html_info,
res.infotext,
res.infotext,
],
outputs=[
download_files,
@ -209,10 +209,10 @@ Requested path was: {f}
fn=call_queue.wrap_gradio_call(save_files),
_js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
inputs=[
res.infotext,
res.generation_info,
res.gallery,
res.html_info,
res.html_info,
res.infotext,
res.infotext,
],
outputs=[
download_files,
@ -221,8 +221,8 @@ Requested path was: {f}
)
else:
res.infotext = gr.HTML(elem_id=f'html_info_x_{tabname}')
res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
res.generation_info = gr.HTML(elem_id=f'html_info_x_{tabname}')
res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
res.html_log = gr.HTML(elem_id=f'html_log_{tabname}')
paste_field_names = []

View File

@ -49,7 +49,7 @@ def create_ui():
],
outputs=[
output_panel.gallery,
output_panel.infotext,
output_panel.generation_info,
output_panel.html_log,
],
show_progress=False,