Merge branch 'dev' into api_thread_safe

This commit is contained in:
kurisu_u 2023-12-30 21:47:59 +08:00 committed by GitHub
commit d05f9e8124
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 368 additions and 85 deletions

View File

@ -0,0 +1,98 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: True
in_channels: 9
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
context_dim: 2048
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: hidden
layer_idx: 11
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
freeze: True
layer: penultimate
always_return_pooled: True
legacy: False
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

View File

@ -1,3 +1,4 @@
import gradio as gr
import logging
import os
import re
@ -314,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
emb_db.skipped_embeddings[name] = embedding
if failed_to_load_networks:
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
sd_hijack.model_hijack.comments.append(lora_not_found_message)
if shared.opts.lora_not_found_warning_console:
print(f'\n{lora_not_found_message}\n')
if shared.opts.lora_not_found_gradio_warning:
gr.Warning(lora_not_found_message)
purge_networks_from_memory()

View File

@ -39,6 +39,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
"lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
"lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
}))

View File

@ -317,8 +317,13 @@ class Api:
script_args[script.args_from:script.args_to] = ui_default_values
return script_args
def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
script_args = default_script_args.copy()
if input_script_args is not None:
for index, value in input_script_args.items():
script_args[index] = value
# position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
if selectable_scripts:
script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args
@ -340,14 +345,88 @@ class Api:
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
return script_args
def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
"""Processes `infotext` field from the `request`, and sets other fields of the `request` accoring to what's in infotext.
If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.
Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.
"""
if not request.infotext:
return {}
possible_fields = generation_parameters_copypaste.paste_fields[tabname]["fields"]
set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this
params = generation_parameters_copypaste.parse_generation_parameters(request.infotext)
def get_field_value(field, params):
value = field.function(params) if field.function else params.get(field.label)
if value is None:
return None
if field.api in request.__fields__:
target_type = request.__fields__[field.api].type_
else:
target_type = type(field.component.value)
if target_type == type(None):
return None
if isinstance(value, dict) and value.get('__type__') == 'generic_update': # this is a gradio.update rather than a value
value = value.get('value')
if value is not None and not isinstance(value, target_type):
value = target_type(value)
return value
for field in possible_fields:
if not field.api:
continue
if field.api in set_fields:
continue
value = get_field_value(field, params)
if value is not None:
setattr(request, field.api, value)
if request.override_settings is None:
request.override_settings = {}
overriden_settings = generation_parameters_copypaste.get_override_settings(params)
for _, setting_name, value in overriden_settings:
if setting_name not in request.override_settings:
request.override_settings[setting_name] = value
if script_runner is not None and mentioned_script_args is not None:
indexes = {v: i for i, v in enumerate(script_runner.inputs)}
script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)
for field, index in script_fields:
value = get_field_value(field, params)
if value is None:
continue
mentioned_script_args[index] = value
return params
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
task_id = txt2imgreq.force_task_id or create_task_id("txt2img")
script_runner = scripts.scripts_txt2img
with self.txt2img_script_arg_init_lock:
if not script_runner.scripts:
script_runner.initialize_scripts(False)
ui.create_ui()
infotext_script_args = {}
self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
if not self.default_script_arg_txt2img:
self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)
@ -364,8 +443,9 @@ class Api:
args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
args.pop('infotext', None)
script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
@ -409,10 +489,15 @@ class Api:
mask = decode_base64_to_image(mask)
script_runner = scripts.scripts_img2img
with self.img2img_script_arg_init_lock:
if not script_runner.scripts:
script_runner.initialize_scripts(True)
ui.create_ui()
infotext_script_args = {}
self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)
if not self.default_script_arg_img2img:
self.default_script_arg_img2img = self.init_default_script_args(script_runner)
selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)
@ -432,7 +517,7 @@ class Api:
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)
send_images = args.pop('send_images', True)
args.pop('save_images', None)

View File

@ -108,6 +108,7 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
{"key": "force_task_id", "type": str, "default": None},
{"key": "infotext", "type": str, "default": None},
]
).generate_model()
@ -126,6 +127,7 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
{"key": "save_images", "type": bool, "default": False},
{"key": "alwayson_scripts", "type": dict, "default": {}},
{"key": "force_task_id", "type": str, "default": None},
{"key": "infotext", "type": str, "default": None},
]
).generate_model()

View File

@ -28,6 +28,19 @@ class ParamBinding:
self.paste_field_names = paste_field_names or []
class PasteField(tuple):
def __new__(cls, component, target, *, api=None):
return super().__new__(cls, (component, target))
def __init__(self, component, target, *, api=None):
super().__init__()
self.api = api
self.component = component
self.label = target if isinstance(target, str) else None
self.function = target if callable(target) else None
paste_fields: dict[str, dict] = {}
registered_param_bindings: list[ParamBinding] = []
@ -84,6 +97,12 @@ def image_from_url_text(filedata):
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
if fields:
for i in range(len(fields)):
if not isinstance(fields[i], PasteField):
fields[i] = PasteField(*fields[i])
paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
# backwards compatibility for existing extensions
@ -371,6 +390,48 @@ def create_override_settings_dict(text_pairs):
return res
def get_override_settings(params, *, skip_fields=None):
"""Returns a list of settings overrides from the infotext parameters dictionary.
This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns
a list of tuples containing the parameter name, setting name, and new value cast to correct type.
It checks for conditions before adding an override:
- ignores settings that match the current value
- ignores parameter keys present in skip_fields argument.
Example input:
{"Clip skip": "2"}
Example output:
[("Clip skip", "CLIP_stop_at_last_layers", 2)]
"""
res = []
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
if param_name in (skip_fields or {}):
continue
v = params.get(param_name, None)
if v is None:
continue
if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
continue
v = shared.opts.cast_value(setting_name, v)
current_value = getattr(shared.opts, setting_name, None)
if v == current_value:
continue
res.append((param_name, setting_name, v))
return res
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
@ -412,29 +473,9 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
already_handled_fields = {key: 1 for _, key in paste_fields}
def paste_settings(params):
vals = {}
vals = get_override_settings(params, skip_fields=already_handled_fields)
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
if param_name in already_handled_fields:
continue
v = params.get(param_name, None)
if v is None:
continue
if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
continue
v = shared.opts.cast_value(setting_name, v)
current_value = getattr(shared.opts, setting_name, None)
if v == current_value:
continue
vals[param_name] = v
vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))

View File

@ -28,5 +28,6 @@ models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
config_states_dir = os.path.join(script_path, "config_states")
default_output_dir = os.path.join(data_path, "output")
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')

View File

@ -113,6 +113,21 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
else:
sd = sd_model.model.state_dict()
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
if diffusion_model_input is not None:
if diffusion_model_input.shape[1] == 9:
# The "masked-image" in this case will just be all 0.5 since the entire image is masked.
image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
image_conditioning = images_tensor_to_samples(image_conditioning,
approximation_indexes.get(opts.sd_vae_encode_method))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
return image_conditioning
# Dummy zero conditioning if we're not using inpainting or unclip models.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
@ -371,6 +386,12 @@ class StableDiffusionProcessing:
if self.sampler.conditioning_key == "crossattn-adm":
return self.unclip_image_conditioning(source_image)
sd = self.sampler.model_wrap.inner_model.model.state_dict()
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
if diffusion_model_input is not None:
if diffusion_model_input.shape[1] == 9:
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
# Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@ -1135,7 +1156,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
if self.hr_checkpoint_name:
if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':
self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
if self.hr_checkpoint_info is None:
@ -1482,7 +1503,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
# Save init image
if opts.save_init_img:
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)
image = images.flatten(img, opts.img2img_background_color)

View File

@ -1,6 +1,7 @@
import gradio as gr
from modules import scripts, sd_models
from modules.generation_parameters_copypaste import PasteField
from modules.ui_common import create_refresh_button
from modules.ui_components import InputAccordion
@ -31,9 +32,9 @@ class ScriptRefiner(scripts.ScriptBuiltinUI):
return None if info is None else info.title
self.infotext_fields = [
(enable_refiner, lambda d: 'Refiner' in d),
(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
(refiner_switch_at, 'Refiner switch at'),
PasteField(enable_refiner, lambda d: 'Refiner' in d),
PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
]
return enable_refiner, refiner_checkpoint, refiner_switch_at

View File

@ -3,6 +3,7 @@ import json
import gradio as gr
from modules import scripts, ui, errors
from modules.generation_parameters_copypaste import PasteField
from modules.shared import cmd_opts
from modules.ui_components import ToolButton
@ -51,12 +52,12 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])
self.infotext_fields = [
(self.seed, "Seed"),
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
(subseed, "Variation seed"),
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
PasteField(self.seed, "Seed", api="seed"),
PasteField(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
PasteField(subseed, "Variation seed", api="subseed"),
PasteField(subseed_strength, "Variation seed strength", api="subseed_strength"),
PasteField(seed_resize_from_w, "Seed resize from-1", api="seed_resize_from_h"),
PasteField(seed_resize_from_h, "Seed resize from-2", api="seed_resize_from_w"),
]
self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')

View File

@ -566,7 +566,12 @@ class ScriptRunner:
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
for script_data in auto_processing_scripts + scripts_data:
script = script_data.script_class()
try:
script = script_data.script_class()
except Exception:
errors.report(f"Error # failed to initialize Script {script_data.module}: ", exc_info=True)
continue
script.filename = script_data.path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img

View File

@ -15,6 +15,7 @@ config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@ -71,7 +72,10 @@ def guess_model_config_from_state_dict(sd, filename):
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
return config_sdxl
if diffusion_model_input.shape[1] == 9:
return config_sdxl_inpainting
else:
return config_sdxl
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
return config_sdxl_refiner
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:

View File

@ -34,6 +34,12 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
sd = self.model.state_dict()
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
if diffusion_model_input is not None:
if diffusion_model_input.shape[1] == 9:
x = torch.cat([x] + cond['c_concat'], dim=1)
return self.model(x, t, cond)

View File

@ -1,7 +1,8 @@
import os
import gradio as gr
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir, default_output_dir # noqa: F401
from modules.shared_cmd_options import cmd_opts
from modules.options import options_section, OptionInfo, OptionHTML, categories
@ -74,14 +75,14 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
options_templates.update(options_section(('saving-paths', "Paths for saving", "saving"), {
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
"outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
"outdir_txt2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-images')), 'Output directory for txt2img images', component_args=hide_dirs),
"outdir_img2img_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-images')), 'Output directory for img2img images', component_args=hide_dirs),
"outdir_extras_samples": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'extras-images')), 'Output directory for images from extras tab', component_args=hide_dirs),
"outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
"outdir_txt2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-grids')), 'Output directory for txt2img grids', component_args=hide_dirs),
"outdir_img2img_grids": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-grids')), 'Output directory for img2img grids', component_args=hide_dirs),
"outdir_save": OptionInfo(util.truncate_path(os.path.join(data_path, 'log', 'images')), "Directory for saving images using the Save button", component_args=hide_dirs),
"outdir_init_images": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'init-images')), "Directory for saving init images when using img2img", component_args=hide_dirs),
}))
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory", "saving"), {

View File

@ -28,7 +28,7 @@ import modules.textual_inversion.textual_inversion as textual_inversion
import modules.shared as shared
from modules import prompt_parser
from modules.sd_hijack import model_hijack
from modules.generation_parameters_copypaste import image_from_url_text
from modules.generation_parameters_copypaste import image_from_url_text, PasteField
create_setting_component = ui_settings.create_setting_component
@ -436,28 +436,28 @@ def create_ui():
)
txt2img_paste_fields = [
(toprow.prompt, "Prompt"),
(toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
(sampler_name, "Sampler"),
(cfg_scale, "CFG scale"),
(width, "Size-1"),
(height, "Size-2"),
(batch_size, "Batch size"),
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
(hr_scale, "Hires upscale"),
(hr_upscaler, "Hires upscaler"),
(hr_second_pass_steps, "Hires steps"),
(hr_resize_x, "Hires resize-1"),
(hr_resize_y, "Hires resize-2"),
(hr_checkpoint_name, "Hires checkpoint"),
(hr_sampler_name, "Hires sampler"),
(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
(hr_prompt, "Hires prompt"),
(hr_negative_prompt, "Hires negative prompt"),
(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
PasteField(toprow.prompt, "Prompt", api="prompt"),
PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"),
PasteField(steps, "Steps", api="steps"),
PasteField(sampler_name, "Sampler", api="sampler_name"),
PasteField(cfg_scale, "CFG scale", api="cfg_scale"),
PasteField(width, "Size-1", api="width"),
PasteField(height, "Size-2", api="height"),
PasteField(batch_size, "Batch size", api="batch_size"),
PasteField(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update(), api="styles"),
PasteField(denoising_strength, "Denoising strength", api="denoising_strength"),
PasteField(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d), api="enable_hr"),
PasteField(hr_scale, "Hires upscale", api="hr_scale"),
PasteField(hr_upscaler, "Hires upscaler", api="hr_upscaler"),
PasteField(hr_second_pass_steps, "Hires steps", api="hr_second_pass_steps"),
PasteField(hr_resize_x, "Hires resize-1", api="hr_resize_x"),
PasteField(hr_resize_y, "Hires resize-2", api="hr_resize_y"),
PasteField(hr_checkpoint_name, "Hires checkpoint", api="hr_checkpoint_name"),
PasteField(hr_sampler_name, "Hires sampler", api="hr_sampler_name"),
PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
PasteField(hr_prompt, "Hires prompt", api="hr_prompt"),
PasteField(hr_negative_prompt, "Hires negative prompt", api="hr_negative_prompt"),
PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
*scripts.scripts_txt2img.infotext_fields
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)

View File

@ -1,17 +1,12 @@
import os
import gradio as gr
from modules import localization, shared, scripts
from modules.paths import script_path, data_path, cwd
from modules import localization, shared, scripts, util
from modules.paths import script_path, data_path
def webpath(fn):
if fn.startswith(cwd):
web_path = os.path.relpath(fn, cwd)
else:
web_path = os.path.abspath(fn)
return f'file={web_path}?{os.path.getmtime(fn)}'
return f'file={util.truncate_path(fn)}?{os.path.getmtime(fn)}'
def javascript_html():

View File

@ -2,7 +2,7 @@ import os
import re
from modules import shared
from modules.paths_internal import script_path
from modules.paths_internal import script_path, cwd
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
@ -56,3 +56,13 @@ def ldm_print(*args, **kwargs):
return
print(*args, **kwargs)
def truncate_path(target_path, base_path=cwd):
abs_target, abs_base = os.path.abspath(target_path), os.path.abspath(base_path)
try:
if os.path.commonpath([abs_target, abs_base]) == abs_base:
return os.path.relpath(abs_target, abs_base)
except ValueError:
pass
return abs_target

View File

@ -4,7 +4,7 @@ import gradio as gr
class ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing):
name = "Caption"
order = 4000
order = 4040
def ui(self):
with ui_components.InputAccordion(False, label="Caption") as enable:

View File

@ -6,7 +6,7 @@ import gradio as gr
class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing):
name = "Create flipped copies"
order = 4000
order = 4030
def ui(self):
with ui_components.InputAccordion(False, label="Create flipped copies") as enable:

View File

@ -7,7 +7,7 @@ from modules.textual_inversion import autocrop
class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing):
name = "Auto focal point crop"
order = 4000
order = 4010
def ui(self):
with ui_components.InputAccordion(False, label="Auto focal point crop") as enable:

View File

@ -28,7 +28,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr
class ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing):
name = "Auto-sized crop"
order = 4000
order = 4020
def ui(self):
with ui_components.InputAccordion(False, label="Auto-sized crop") as enable:

View File

@ -476,6 +476,8 @@ class Script(scripts.Script):
fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown])
def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode):
axis_type = axis_type or 0 # if axle type is None set to 0
choices = self.current_axis_options[axis_type].choices
has_choices = choices is not None
@ -526,6 +528,8 @@ class Script(scripts.Script):
return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode]
def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode):
x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0 # if axle type is None set to 0
if not no_fixed_seeds:
modules.processing.fix_seed(p)

View File

@ -39,7 +39,7 @@ def api_only():
print(f"Startup time: {startup_timer.summary()}.")
api.launch(
server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
server_name=initialize_util.gradio_server_name(),
port=cmd_opts.port if cmd_opts.port else 7861,
root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""
)