Merge remote-tracking branch 'origin/master'
This commit is contained in:
commit
56c83e453a
|
@ -0,0 +1,99 @@
|
|||
# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
|
||||
# See more details in LICENSE.
|
||||
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: modules.models.diffusion.ddpm_edit.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: edited
|
||||
cond_stage_key: edit
|
||||
# image_size: 64
|
||||
# image_size: 32
|
||||
image_size: 16
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: hybrid
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: true
|
||||
load_ema: true
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 0 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 8
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 128
|
||||
num_workers: 1
|
||||
wrap: false
|
||||
validation:
|
||||
target: edit_dataset.EditDataset
|
||||
params:
|
||||
path: data/clip-filtered-dataset
|
||||
cache_dir: data/
|
||||
cache_name: data_10k
|
||||
split: val
|
||||
min_text_sim: 0.2
|
||||
min_image_sim: 0.75
|
||||
min_direction_sim: 0.2
|
||||
max_samples_per_prompt: 1
|
||||
min_resize_res: 512
|
||||
max_resize_res: 512
|
||||
crop_res: 512
|
||||
output_as_edit: False
|
||||
real_input: True
|
|
@ -1,8 +1,7 @@
|
|||
model:
|
||||
base_learning_rate: 1.0e-4
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
base_learning_rate: 7.5e-05
|
||||
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
|
||||
params:
|
||||
parameterization: "v"
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
|
@ -12,29 +11,36 @@ model:
|
|||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false
|
||||
conditioning_key: crossattn
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: hybrid # important
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False # we set this to false because this is an inference only config
|
||||
finetune_keys: null
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
use_checkpoint: True
|
||||
use_fp16: True
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_head_channels: 64 # need to fix for flash-attn
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
|
@ -43,7 +49,6 @@ model:
|
|||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
#attn_type: "vanilla-xformers"
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
|
@ -62,7 +67,4 @@ model:
|
|||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
|
||||
params:
|
||||
freeze: True
|
||||
layer: "penultimate"
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
|
@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
|
|||
from modules.textual_inversion.preprocess import preprocess
|
||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||
from PIL import PngImagePlugin,Image
|
||||
from modules.sd_models import checkpoints_list, find_checkpoint_config
|
||||
from modules.sd_models import checkpoints_list
|
||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
from typing import List
|
||||
|
@ -387,7 +388,7 @@ class Api:
|
|||
]
|
||||
|
||||
def get_sd_models(self):
|
||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()]
|
||||
return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
|
||||
|
||||
def get_hypernetworks(self):
|
||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||
|
|
|
@ -228,7 +228,7 @@ class SDModelItem(BaseModel):
|
|||
hash: Optional[str] = Field(title="Short hash")
|
||||
sha256: Optional[str] = Field(title="sha256 hash")
|
||||
filename: str = Field(title="Filename")
|
||||
config: str = Field(title="Config file")
|
||||
config: Optional[str] = Field(title="Config file")
|
||||
|
||||
class HypernetworkItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
|
|
|
@ -34,14 +34,18 @@ def get_cuda_device_string():
|
|||
return "cuda"
|
||||
|
||||
|
||||
def get_optimal_device():
|
||||
def get_optimal_device_name():
|
||||
if torch.cuda.is_available():
|
||||
return torch.device(get_cuda_device_string())
|
||||
return get_cuda_device_string()
|
||||
|
||||
if has_mps():
|
||||
return torch.device("mps")
|
||||
return "mps"
|
||||
|
||||
return cpu
|
||||
return "cpu"
|
||||
|
||||
|
||||
def get_optimal_device():
|
||||
return torch.device(get_optimal_device_name())
|
||||
|
||||
|
||||
def get_device_for(task):
|
||||
|
@ -139,6 +143,8 @@ def test_for_nans(x, where):
|
|||
else:
|
||||
message = "A tensor with all NaNs was produced."
|
||||
|
||||
message += " Use --disable-nan-check commandline argument to disable this check."
|
||||
|
||||
raise NansException(message)
|
||||
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from skimage import exposure
|
|||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import modules.sd_hijack
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx
|
||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
|
@ -172,7 +172,7 @@ class StableDiffusionProcessing:
|
|||
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
|
||||
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
|
||||
|
||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image))
|
||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image))
|
||||
conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
|
||||
conditioning = torch.nn.functional.interpolate(
|
||||
self.sd_model.depth_model(midas_in),
|
||||
|
@ -185,7 +185,12 @@ class StableDiffusionProcessing:
|
|||
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
|
||||
return conditioning
|
||||
|
||||
def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
|
||||
def edit_image_conditioning(self, source_image):
|
||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
|
||||
|
||||
return conditioning_image
|
||||
|
||||
def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
|
||||
self.is_using_inpainting_conditioning = True
|
||||
|
||||
# Handle the different mask inputs
|
||||
|
@ -212,7 +217,7 @@ class StableDiffusionProcessing:
|
|||
)
|
||||
|
||||
# Encode the new masked image using first stage of network.
|
||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image))
|
||||
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image))
|
||||
|
||||
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
||||
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
|
||||
|
@ -228,6 +233,9 @@ class StableDiffusionProcessing:
|
|||
if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
|
||||
return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)
|
||||
|
||||
if self.sd_model.cond_stage_key == "edit":
|
||||
return self.edit_image_conditioning(source_image)
|
||||
|
||||
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
|
||||
return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)
|
||||
|
||||
|
@ -409,7 +417,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
|||
|
||||
def decode_first_stage(model, x):
|
||||
with devices.autocast(disable=x.dtype == devices.dtype_vae):
|
||||
x = model.decode_first_stage(x)
|
||||
x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x)
|
||||
|
||||
return x
|
||||
|
||||
|
@ -650,6 +658,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||
|
||||
image = Image.fromarray(x_sample)
|
||||
|
||||
if p.scripts is not None:
|
||||
pp = scripts.PostprocessImageArgs(image)
|
||||
p.scripts.postprocess_image(p, pp)
|
||||
image = pp.image
|
||||
|
||||
if p.color_corrections is not None and i < len(p.color_corrections):
|
||||
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
|
||||
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
|
||||
|
@ -993,7 +1006,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||
|
||||
image = torch.from_numpy(batch_images)
|
||||
image = 2. * image - 1.
|
||||
image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None)
|
||||
image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None)
|
||||
|
||||
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
||||
|
||||
|
|
|
@ -6,12 +6,16 @@ from collections import namedtuple
|
|||
|
||||
import gradio as gr
|
||||
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
|
||||
|
||||
AlwaysVisible = object()
|
||||
|
||||
|
||||
class PostprocessImageArgs:
|
||||
def __init__(self, image):
|
||||
self.image = image
|
||||
|
||||
|
||||
class Script:
|
||||
filename = None
|
||||
args_from = None
|
||||
|
@ -65,7 +69,7 @@ class Script:
|
|||
args contains all values returned by components from ui()
|
||||
"""
|
||||
|
||||
raise NotImplementedError()
|
||||
pass
|
||||
|
||||
def process(self, p, *args):
|
||||
"""
|
||||
|
@ -100,6 +104,13 @@ class Script:
|
|||
|
||||
pass
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
||||
"""
|
||||
Called for every image after it has been generated.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
"""
|
||||
This function is called after processing ends for AlwaysVisible scripts.
|
||||
|
@ -247,11 +258,15 @@ class ScriptRunner:
|
|||
self.infotext_fields = []
|
||||
|
||||
def initialize_scripts(self, is_img2img):
|
||||
from modules import scripts_auto_postprocessing
|
||||
|
||||
self.scripts.clear()
|
||||
self.alwayson_scripts.clear()
|
||||
self.selectable_scripts.clear()
|
||||
|
||||
for script_class, path, basedir, script_module in scripts_data:
|
||||
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
||||
|
||||
for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
|
||||
script = script_class()
|
||||
script.filename = path
|
||||
script.is_txt2img = not is_img2img
|
||||
|
@ -332,7 +347,7 @@ class ScriptRunner:
|
|||
|
||||
return inputs
|
||||
|
||||
def run(self, p: StableDiffusionProcessing, *args):
|
||||
def run(self, p, *args):
|
||||
script_index = args[0]
|
||||
|
||||
if script_index == 0:
|
||||
|
@ -386,6 +401,15 @@ class ScriptRunner:
|
|||
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||
for script in self.alwayson_scripts:
|
||||
try:
|
||||
script_args = p.script_args[script.args_from:script.args_to]
|
||||
script.postprocess_image(p, pp, *script_args)
|
||||
except Exception:
|
||||
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
def before_component(self, component, **kwargs):
|
||||
for script in self.scripts:
|
||||
try:
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
from modules import scripts, scripts_postprocessing, shared
|
||||
|
||||
|
||||
class ScriptPostprocessingForMainUI(scripts.Script):
|
||||
def __init__(self, script_postproc):
|
||||
self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc
|
||||
self.postprocessing_controls = None
|
||||
|
||||
def title(self):
|
||||
return self.script.name
|
||||
|
||||
def show(self, is_img2img):
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
self.postprocessing_controls = self.script.ui()
|
||||
return self.postprocessing_controls.values()
|
||||
|
||||
def postprocess_image(self, p, script_pp, *args):
|
||||
args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
|
||||
|
||||
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
|
||||
pp.info = {}
|
||||
self.script.process(pp, **args_dict)
|
||||
p.extra_generation_params.update(pp.info)
|
||||
script_pp.image = pp.image
|
||||
|
||||
|
||||
def create_auto_preprocessing_script_data():
|
||||
from modules import scripts
|
||||
|
||||
res = []
|
||||
|
||||
for name in shared.opts.postprocessing_enable_in_main_ui:
|
||||
script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)
|
||||
if script is None:
|
||||
continue
|
||||
|
||||
constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())
|
||||
res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))
|
||||
|
||||
return res
|
|
@ -46,6 +46,8 @@ class ScriptPostprocessing:
|
|||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
|
||||
try:
|
||||
res = func(*args, **kwargs)
|
||||
|
@ -68,6 +70,9 @@ class ScriptPostprocessingRunner:
|
|||
script: ScriptPostprocessing = script_class()
|
||||
script.filename = path
|
||||
|
||||
if script.name == "Simple Upscale":
|
||||
continue
|
||||
|
||||
self.scripts.append(script)
|
||||
|
||||
def create_script_ui(self, script, inputs):
|
||||
|
@ -87,12 +92,11 @@ class ScriptPostprocessingRunner:
|
|||
import modules.scripts
|
||||
self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
|
||||
|
||||
scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]
|
||||
scripts_order = shared.opts.postprocessing_operation_order
|
||||
|
||||
def script_score(name):
|
||||
name = name.lower()
|
||||
for i, possible_match in enumerate(scripts_order):
|
||||
if possible_match in name:
|
||||
if possible_match == name:
|
||||
return i
|
||||
|
||||
return len(self.scripts)
|
||||
|
@ -145,3 +149,4 @@ class ScriptPostprocessingRunner:
|
|||
def image_changed(self):
|
||||
for script in self.scripts_in_preferred_order():
|
||||
script.image_changed()
|
||||
|
||||
|
|
|
@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
|||
return x_prev, pred_x0, e_t
|
||||
|
||||
|
||||
def should_hijack_inpainting(checkpoint_info):
|
||||
from modules import sd_models
|
||||
|
||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||
cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()
|
||||
|
||||
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
|
||||
|
||||
|
||||
def do_inpainting_hijack():
|
||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ class CondFunc:
|
|||
self = super(CondFunc, cls).__new__(cls)
|
||||
if isinstance(orig_func, str):
|
||||
func_path = orig_func.split('.')
|
||||
for i in range(len(func_path)-2, -1, -1):
|
||||
for i in range(len(func_path)-1, -1, -1):
|
||||
try:
|
||||
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
|
||||
break
|
||||
|
|
|
@ -2,8 +2,6 @@ import collections
|
|||
import os.path
|
||||
import sys
|
||||
import gc
|
||||
import time
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
import re
|
||||
import safetensors.torch
|
||||
|
@ -14,10 +12,10 @@ import ldm.modules.midas as midas
|
|||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
|
||||
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
||||
from modules.paths import models_path
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
|
||||
from modules.sd_hijack_ip2p import should_hijack_ip2p
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||
from modules.timer import Timer
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(models_path, model_dir))
|
||||
|
@ -99,17 +97,6 @@ def checkpoint_tiles():
|
|||
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
|
||||
|
||||
|
||||
def find_checkpoint_config(info):
|
||||
if info is None:
|
||||
return shared.cmd_opts.config
|
||||
|
||||
config = os.path.splitext(info.filename)[0] + ".yaml"
|
||||
if os.path.exists(config):
|
||||
return config
|
||||
|
||||
return shared.cmd_opts.config
|
||||
|
||||
|
||||
def list_models():
|
||||
checkpoints_list.clear()
|
||||
checkpoint_alisases.clear()
|
||||
|
@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd):
|
|||
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
|
||||
_, extension = os.path.splitext(checkpoint_file)
|
||||
if extension.lower() == ".safetensors":
|
||||
device = map_location or shared.weight_load_location
|
||||
if device is None:
|
||||
device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
|
||||
device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
|
||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
|
||||
else:
|
||||
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
|
||||
|
@ -229,60 +214,74 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
|
|||
return sd
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo):
|
||||
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
if checkpoint_info in checkpoints_loaded:
|
||||
# use checkpoint cache
|
||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||
return checkpoints_loaded[checkpoint_info]
|
||||
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||
res = read_state_dict(checkpoint_info.filename)
|
||||
timer.record("load weights from disk")
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
|
||||
title = checkpoint_info.title
|
||||
sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
timer.record("calculate hash")
|
||||
|
||||
if checkpoint_info.title != title:
|
||||
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
|
||||
|
||||
cache_enabled = shared.opts.sd_checkpoint_cache > 0
|
||||
if state_dict is None:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
if cache_enabled and checkpoint_info in checkpoints_loaded:
|
||||
# use checkpoint cache
|
||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||
model.load_state_dict(checkpoints_loaded[checkpoint_info])
|
||||
else:
|
||||
# load from file
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
del state_dict
|
||||
timer.record("apply weights to model")
|
||||
|
||||
sd = read_state_dict(checkpoint_info.filename)
|
||||
model.load_state_dict(sd, strict=False)
|
||||
del sd
|
||||
|
||||
if cache_enabled:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
timer.record("apply channels_last")
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
vae = model.first_stage_model
|
||||
depth_model = getattr(model, 'depth_model', None)
|
||||
if not shared.cmd_opts.no_half:
|
||||
vae = model.first_stage_model
|
||||
depth_model = getattr(model, 'depth_model', None)
|
||||
|
||||
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
|
||||
if shared.cmd_opts.no_half_vae:
|
||||
model.first_stage_model = None
|
||||
# with --upcast-sampling, don't convert the depth model weights to float16
|
||||
if shared.cmd_opts.upcast_sampling and depth_model:
|
||||
model.depth_model = None
|
||||
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
|
||||
if shared.cmd_opts.no_half_vae:
|
||||
model.first_stage_model = None
|
||||
# with --upcast-sampling, don't convert the depth model weights to float16
|
||||
if shared.cmd_opts.upcast_sampling and depth_model:
|
||||
model.depth_model = None
|
||||
|
||||
model.half()
|
||||
model.first_stage_model = vae
|
||||
if depth_model:
|
||||
model.depth_model = depth_model
|
||||
model.half()
|
||||
model.first_stage_model = vae
|
||||
if depth_model:
|
||||
model.depth_model = depth_model
|
||||
|
||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
||||
devices.dtype_unet = model.model.diffusion_model.dtype
|
||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||
timer.record("apply half()")
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
|
||||
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
|
||||
devices.dtype_unet = model.model.diffusion_model.dtype
|
||||
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
|
||||
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
timer.record("apply dtype to VAE")
|
||||
|
||||
# clean up cache if limit is reached
|
||||
if cache_enabled:
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
|
||||
checkpoints_loaded.popitem(last=False) # LRU
|
||||
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
|
||||
checkpoints_loaded.popitem(last=False)
|
||||
|
||||
model.sd_model_hash = sd_model_hash
|
||||
model.sd_model_checkpoint = checkpoint_info.filename
|
||||
|
@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
|
|||
sd_vae.clear_loaded_vae()
|
||||
vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
|
||||
sd_vae.load_vae(model, vae_file, vae_source)
|
||||
timer.record("load VAE")
|
||||
|
||||
|
||||
def enable_midas_autodownload():
|
||||
|
@ -340,24 +340,20 @@ def enable_midas_autodownload():
|
|||
midas.api.load_model = load_model_wrapper
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.start = time.time()
|
||||
def repair_config(sd_config):
|
||||
|
||||
def elapsed(self):
|
||||
end = time.time()
|
||||
res = end - self.start
|
||||
self.start = end
|
||||
return res
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
sd_config.model.params.use_ema = False
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
elif shared.cmd_opts.upcast_sampling:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None):
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
checkpoint_config = find_checkpoint_config(checkpoint_info)
|
||||
|
||||
if checkpoint_config != shared.cmd_opts.config:
|
||||
print(f"Loading config from: {checkpoint_config}")
|
||||
|
||||
if shared.sd_model:
|
||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
||||
|
@ -365,38 +361,27 @@ def load_model(checkpoint_info=None):
|
|||
gc.collect()
|
||||
devices.torch_gc()
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_config)
|
||||
|
||||
if should_hijack_inpainting(checkpoint_info):
|
||||
# Hardcoded config for now...
|
||||
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
|
||||
sd_config.model.params.conditioning_key = "hybrid"
|
||||
sd_config.model.params.unet_config.params.in_channels = 9
|
||||
sd_config.model.params.finetune_keys = None
|
||||
|
||||
if should_hijack_ip2p(checkpoint_info):
|
||||
sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion"
|
||||
sd_config.model.params.conditioning_key = "hybrid"
|
||||
sd_config.model.params.first_stage_key = "edited"
|
||||
sd_config.model.params.cond_stage_key = "edit"
|
||||
sd_config.model.params.image_size = 16
|
||||
sd_config.model.params.unet_config.params.in_channels = 8
|
||||
sd_config.model.params.unet_config.params.out_channels = 4
|
||||
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
sd_config.model.params.use_ema = False
|
||||
|
||||
do_inpainting_hijack()
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
elif shared.cmd_opts.upcast_sampling:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||
|
||||
timer = Timer()
|
||||
|
||||
sd_model = None
|
||||
if already_loaded_state_dict is not None:
|
||||
state_dict = already_loaded_state_dict
|
||||
else:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
|
||||
timer.record("find config")
|
||||
|
||||
sd_config = OmegaConf.load(checkpoint_config)
|
||||
repair_config(sd_config)
|
||||
|
||||
timer.record("load config")
|
||||
|
||||
print(f"Creating model from config: {checkpoint_config}")
|
||||
|
||||
sd_model = None
|
||||
try:
|
||||
with sd_disable_initialization.DisableInitialization():
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
@ -407,29 +392,35 @@ def load_model(checkpoint_info=None):
|
|||
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
|
||||
elapsed_create = timer.elapsed()
|
||||
sd_model.used_config = checkpoint_config
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
timer.record("create model")
|
||||
|
||||
elapsed_load_weights = timer.elapsed()
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
||||
else:
|
||||
sd_model.to(shared.device)
|
||||
|
||||
timer.record("move model to device")
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
|
||||
timer.record("hijack")
|
||||
|
||||
sd_model.eval()
|
||||
shared.sd_model = sd_model
|
||||
|
||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||
|
||||
timer.record("load textual inversion embeddings")
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
|
||||
elapsed_the_rest = timer.elapsed()
|
||||
timer.record("scripts callbacks")
|
||||
|
||||
print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
|
||||
print(f"Model loaded in {timer.summary()}.")
|
||||
|
||||
return sd_model
|
||||
|
||||
|
@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None):
|
|||
|
||||
if not sd_model:
|
||||
sd_model = shared.sd_model
|
||||
|
||||
if sd_model is None: # previous model load failed
|
||||
current_checkpoint_info = None
|
||||
else:
|
||||
|
@ -447,38 +439,44 @@ def reload_model_weights(sd_model=None, info=None):
|
|||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||
return
|
||||
|
||||
checkpoint_config = find_checkpoint_config(current_checkpoint_info)
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info):
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
load_model(checkpoint_info)
|
||||
return shared.sd_model
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||
|
||||
timer = Timer()
|
||||
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
|
||||
timer.record("find config")
|
||||
|
||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||
del sd_model
|
||||
checkpoints_loaded.clear()
|
||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
|
||||
return shared.sd_model
|
||||
|
||||
try:
|
||||
load_model_weights(sd_model, checkpoint_info)
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
except Exception as e:
|
||||
print("Failed to load checkpoint, restoring previous")
|
||||
load_model_weights(sd_model, current_checkpoint_info)
|
||||
load_model_weights(sd_model, current_checkpoint_info, None, timer)
|
||||
raise
|
||||
finally:
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
timer.record("hijack")
|
||||
|
||||
script_callbacks.model_loaded_callback(sd_model)
|
||||
timer.record("script callbacks")
|
||||
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
timer.record("move model to device")
|
||||
|
||||
elapsed = timer.elapsed()
|
||||
|
||||
print(f"Weights loaded in {elapsed:.1f}s.")
|
||||
print(f"Weights loaded in {timer.summary()}.")
|
||||
|
||||
return sd_model
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
import re
|
||||
import os
|
||||
|
||||
from modules import shared, paths
|
||||
|
||||
sd_configs_path = shared.sd_configs_path
|
||||
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||
|
||||
|
||||
config_default = shared.sd_default_config
|
||||
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
||||
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||
|
||||
re_parametrization_v = re.compile(r'-v\b')
|
||||
|
||||
|
||||
def guess_model_config_from_state_dict(sd, filename):
|
||||
fn = os.path.basename(filename)
|
||||
|
||||
sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
|
||||
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
||||
|
||||
if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||
return config_depth_model
|
||||
|
||||
if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:
|
||||
if re.search(re_parametrization_v, fn) or "v2-1_768" in fn:
|
||||
return config_sd2v
|
||||
else:
|
||||
return config_sd2
|
||||
|
||||
if diffusion_model_input is not None:
|
||||
if diffusion_model_input.shape[1] == 9:
|
||||
return config_inpainting
|
||||
if diffusion_model_input.shape[1] == 8:
|
||||
return config_instruct_pix2pix
|
||||
|
||||
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
||||
return config_alt_diffusion
|
||||
|
||||
return config_default
|
||||
|
||||
|
||||
def find_checkpoint_config(state_dict, info):
|
||||
if info is None:
|
||||
return guess_model_config_from_state_dict(state_dict, "")
|
||||
|
||||
config = find_checkpoint_config_near_filename(info)
|
||||
if config is not None:
|
||||
return config
|
||||
|
||||
return guess_model_config_from_state_dict(state_dict, info.filename)
|
||||
|
||||
|
||||
def find_checkpoint_config_near_filename(info):
|
||||
if info is None:
|
||||
return None
|
||||
|
||||
config = os.path.splitext(info.filename)[0] + ".yaml"
|
||||
if os.path.exists(config):
|
||||
return config
|
||||
|
||||
return None
|
||||
|
|
@ -454,7 +454,7 @@ class KDiffusionSampler:
|
|||
def initialize(self, p):
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap.step = 0
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.eta = p.eta or opts.eta_ancestral
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
|
|
@ -13,13 +13,14 @@ import modules.interrogate
|
|||
import modules.memmon
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import localization, sd_vae, extensions, script_loading, errors, ui_components
|
||||
from modules.paths import models_path, script_path, sd_path
|
||||
from modules import localization, extensions, script_loading, errors, ui_components, shared_items
|
||||
from modules.paths import models_path, script_path
|
||||
|
||||
|
||||
demo = None
|
||||
|
||||
sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
|
||||
sd_configs_path = os.path.join(script_path, "configs")
|
||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
|
||||
|
@ -264,12 +265,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
|
|||
|
||||
face_restorers = []
|
||||
|
||||
|
||||
def realesrgan_models_names():
|
||||
import modules.realesrgan_model
|
||||
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
|
||||
|
||||
|
||||
class OptionInfo:
|
||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
|
||||
self.default = default
|
||||
|
@ -360,7 +355,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
|
|||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||
}))
|
||||
|
||||
|
@ -397,7 +392,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
|
||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
|
||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
||||
|
@ -483,7 +478,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||
}))
|
||||
|
||||
options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
||||
'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"),
|
||||
'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||
'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
|
||||
'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
}))
|
||||
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
|
||||
|
||||
def realesrgan_models_names():
|
||||
import modules.realesrgan_model
|
||||
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
|
||||
|
||||
|
||||
def postprocessing_scripts():
|
||||
import modules.scripts
|
||||
|
||||
return modules.scripts.scripts_postproc.scripts
|
||||
|
||||
|
||||
def sd_vae_items():
|
||||
import modules.sd_vae
|
||||
|
||||
return ["Automatic", "None"] + list(modules.sd_vae.vae_dict)
|
||||
|
||||
|
||||
def refresh_vae_list():
|
||||
import modules.sd_vae
|
||||
|
||||
return modules.sd_vae.refresh_vae_list
|
|
@ -194,7 +194,7 @@ class EmbeddingDatabase:
|
|||
if not os.path.isdir(embdir.path):
|
||||
return
|
||||
|
||||
for root, dirs, fns in os.walk(embdir.path):
|
||||
for root, dirs, fns in os.walk(embdir.path, followlinks=True):
|
||||
for fn in fns:
|
||||
try:
|
||||
fullfn = os.path.join(root, fn)
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
import time
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.start = time.time()
|
||||
self.records = {}
|
||||
self.total = 0
|
||||
|
||||
def elapsed(self):
|
||||
end = time.time()
|
||||
res = end - self.start
|
||||
self.start = end
|
||||
return res
|
||||
|
||||
def record(self, category, extra_time=0):
|
||||
e = self.elapsed()
|
||||
if category not in self.records:
|
||||
self.records[category] = 0
|
||||
|
||||
self.records[category] += e + extra_time
|
||||
self.total += e + extra_time
|
||||
|
||||
def summary(self):
|
||||
res = f"{self.total:.1f}s"
|
||||
|
||||
additions = [x for x in self.records.items() if x[1] >= 0.1]
|
||||
if not additions:
|
||||
return res
|
||||
|
||||
res += " ("
|
||||
res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions])
|
||||
res += ")"
|
||||
|
||||
return res
|
|
@ -48,3 +48,11 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
|
|||
def get_block_name(self):
|
||||
return "colorpicker"
|
||||
|
||||
|
||||
class DropdownMulti(gr.Dropdown):
|
||||
"""Same as gr.Dropdown but always multiselect"""
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(multiselect=True, **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "dropdown"
|
||||
|
|
|
@ -104,3 +104,28 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
|
|||
|
||||
def image_changed(self):
|
||||
upscale_cache.clear()
|
||||
|
||||
|
||||
class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
|
||||
name = "Simple Upscale"
|
||||
order = 900
|
||||
|
||||
def ui(self):
|
||||
with FormRow():
|
||||
upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
||||
upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2)
|
||||
|
||||
return {
|
||||
"upscale_by": upscale_by,
|
||||
"upscaler_name": upscaler_name,
|
||||
}
|
||||
|
||||
def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
|
||||
if upscaler_name is None or upscaler_name == "None":
|
||||
return
|
||||
|
||||
upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None)
|
||||
assert upscaler1, f'could not find upscaler named {upscaler_name}'
|
||||
|
||||
pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False)
|
||||
pp.info[f"Postprocess upscaler"] = upscaler1.name
|
||||
|
|
|
@ -164,7 +164,7 @@
|
|||
min-height: 3.2em;
|
||||
}
|
||||
|
||||
#txt2img_styles ul, #img2img_styles ul{
|
||||
ul.list-none{
|
||||
max-height: 35em;
|
||||
z-index: 2000;
|
||||
}
|
||||
|
@ -714,9 +714,6 @@ footer {
|
|||
white-space: nowrap;
|
||||
min-width: auto;
|
||||
}
|
||||
#txt2img_hires_fix{
|
||||
margin-left: -0.8em;
|
||||
}
|
||||
|
||||
#img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{
|
||||
margin-left: 0em;
|
||||
|
@ -744,7 +741,6 @@ footer {
|
|||
|
||||
.dark .gr-compact{
|
||||
background-color: rgb(31 41 55 / var(--tw-bg-opacity));
|
||||
margin-left: 0.8em;
|
||||
}
|
||||
|
||||
.gr-compact{
|
||||
|
|
|
@ -10,7 +10,7 @@ then
|
|||
fi
|
||||
|
||||
export install_dir="$HOME"
|
||||
export COMMANDLINE_ARGS="--skip-torch-cuda-test --no-half --use-cpu interrogate"
|
||||
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --use-cpu interrogate"
|
||||
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
|
||||
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
|
||||
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
|
||||
|
|
Loading…
Reference in New Issue