add support for switching model checkpoints at runtime
This commit is contained in:
parent
b8be33dad1
commit
247f58a5e7
|
@ -274,7 +274,7 @@ def apply_filename_pattern(x, p, seed, prompt):
|
||||||
x = x.replace("[height]", str(p.height))
|
x = x.replace("[height]", str(p.height))
|
||||||
x = x.replace("[sampler]", sd_samplers.samplers[p.sampler_index].name)
|
x = x.replace("[sampler]", sd_samplers.samplers[p.sampler_index].name)
|
||||||
|
|
||||||
x = x.replace("[model_hash]", shared.sd_model_hash)
|
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
|
||||||
x = x.replace("[date]", datetime.date.today().isoformat())
|
x = x.replace("[date]", datetime.date.today().isoformat())
|
||||||
|
|
||||||
if cmd_opts.hide_ui_dir_config:
|
if cmd_opts.hide_ui_dir_config:
|
||||||
|
|
|
@ -227,7 +227,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
"Seed": all_seeds[index],
|
"Seed": all_seeds[index],
|
||||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||||
"Size": f"{p.width}x{p.height}",
|
"Size": f"{p.width}x{p.height}",
|
||||||
"Model hash": (None if not opts.add_model_hash_to_info or not shared.sd_model_hash else shared.sd_model_hash),
|
"Model hash": (None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||||
|
|
|
@ -0,0 +1,148 @@
|
||||||
|
import glob
|
||||||
|
import os.path
|
||||||
|
import sys
|
||||||
|
from collections import namedtuple
|
||||||
|
import torch
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
|
||||||
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash'])
|
||||||
|
checkpoints_list = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||||
|
|
||||||
|
from transformers import logging
|
||||||
|
|
||||||
|
logging.set_verbosity_error()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def list_models():
|
||||||
|
checkpoints_list.clear()
|
||||||
|
|
||||||
|
model_dir = os.path.abspath(shared.cmd_opts.ckpt_dir)
|
||||||
|
|
||||||
|
def modeltitle(path, h):
|
||||||
|
abspath = os.path.abspath(path)
|
||||||
|
|
||||||
|
if abspath.startswith(model_dir):
|
||||||
|
name = abspath.replace(model_dir, '')
|
||||||
|
else:
|
||||||
|
name = os.path.basename(path)
|
||||||
|
|
||||||
|
if name.startswith("\\") or name.startswith("/"):
|
||||||
|
name = name[1:]
|
||||||
|
|
||||||
|
return f'{name} [{h}]'
|
||||||
|
|
||||||
|
cmd_ckpt = shared.cmd_opts.ckpt
|
||||||
|
if os.path.exists(cmd_ckpt):
|
||||||
|
h = model_hash(cmd_ckpt)
|
||||||
|
title = modeltitle(cmd_ckpt, h)
|
||||||
|
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h)
|
||||||
|
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||||
|
print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)
|
||||||
|
|
||||||
|
if os.path.exists(model_dir):
|
||||||
|
for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
|
||||||
|
h = model_hash(filename)
|
||||||
|
title = modeltitle(filename, h)
|
||||||
|
checkpoints_list[title] = CheckpointInfo(filename, title, h)
|
||||||
|
|
||||||
|
|
||||||
|
def model_hash(filename):
|
||||||
|
try:
|
||||||
|
with open(filename, "rb") as file:
|
||||||
|
import hashlib
|
||||||
|
m = hashlib.sha256()
|
||||||
|
|
||||||
|
file.seek(0x100000)
|
||||||
|
m.update(file.read(0x10000))
|
||||||
|
return m.hexdigest()[0:8]
|
||||||
|
except FileNotFoundError:
|
||||||
|
return 'NOFILE'
|
||||||
|
|
||||||
|
|
||||||
|
def select_checkpoint():
|
||||||
|
model_checkpoint = shared.opts.sd_model_checkpoint
|
||||||
|
checkpoint_info = checkpoints_list.get(model_checkpoint, None)
|
||||||
|
if checkpoint_info is not None:
|
||||||
|
return checkpoint_info
|
||||||
|
|
||||||
|
if len(checkpoints_list) == 0:
|
||||||
|
print(f"Checkpoint {model_checkpoint} not found and no other checkpoints found", file=sys.stderr)
|
||||||
|
return None
|
||||||
|
|
||||||
|
checkpoint_info = next(iter(checkpoints_list.values()))
|
||||||
|
if model_checkpoint is not None:
|
||||||
|
print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)
|
||||||
|
|
||||||
|
return checkpoint_info
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_weights(model, checkpoint_file, sd_model_hash):
|
||||||
|
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
||||||
|
|
||||||
|
pl_sd = torch.load(checkpoint_file, map_location="cpu")
|
||||||
|
if "global_step" in pl_sd:
|
||||||
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
|
sd = pl_sd["state_dict"]
|
||||||
|
|
||||||
|
model.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
|
if shared.cmd_opts.opt_channelslast:
|
||||||
|
model.to(memory_format=torch.channels_last)
|
||||||
|
|
||||||
|
if not shared.cmd_opts.no_half:
|
||||||
|
model.half()
|
||||||
|
|
||||||
|
model.sd_model_hash = sd_model_hash
|
||||||
|
model.sd_model_checkpint = checkpoint_file
|
||||||
|
|
||||||
|
|
||||||
|
def load_model():
|
||||||
|
from modules import lowvram, sd_hijack
|
||||||
|
checkpoint_info = select_checkpoint()
|
||||||
|
|
||||||
|
sd_config = OmegaConf.load(shared.cmd_opts.config)
|
||||||
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
|
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
|
||||||
|
|
||||||
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
|
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
||||||
|
else:
|
||||||
|
sd_model.to(shared.device)
|
||||||
|
|
||||||
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
|
|
||||||
|
sd_model.eval()
|
||||||
|
|
||||||
|
print(f"Model loaded.")
|
||||||
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
|
def reload_model_weights(sd_model):
|
||||||
|
from modules import lowvram, devices
|
||||||
|
checkpoint_info = select_checkpoint()
|
||||||
|
|
||||||
|
if sd_model.sd_model_checkpint == checkpoint_info.filename:
|
||||||
|
return
|
||||||
|
|
||||||
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
|
lowvram.send_everything_to_cpu()
|
||||||
|
else:
|
||||||
|
sd_model.to(devices.cpu)
|
||||||
|
|
||||||
|
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
|
||||||
|
|
||||||
|
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||||
|
sd_model.to(devices.device)
|
||||||
|
|
||||||
|
print(f"Weights loaded.")
|
||||||
|
return sd_model
|
|
@ -13,14 +13,15 @@ from modules.devices import get_optimal_device
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.interrogate
|
import modules.interrogate
|
||||||
import modules.memmon
|
import modules.memmon
|
||||||
|
import modules.sd_models
|
||||||
|
|
||||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||||
if not os.path.exists(sd_model_file):
|
default_sd_model_file = sd_model_file
|
||||||
sd_model_file = "models/ldm/stable-diffusion-v1/model.ckpt"
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
||||||
parser.add_argument("--ckpt", type=str, default=os.path.join(sd_path, sd_model_file), help="path to checkpoint of model",)
|
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
|
||||||
|
parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
|
||||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default='GFPGANv1.3.pth')
|
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default='GFPGANv1.3.pth')
|
||||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||||
|
@ -88,13 +89,17 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
|
||||||
|
|
||||||
face_restorers = []
|
face_restorers = []
|
||||||
|
|
||||||
|
modules.sd_models.list_models()
|
||||||
|
|
||||||
|
|
||||||
class Options:
|
class Options:
|
||||||
class OptionInfo:
|
class OptionInfo:
|
||||||
def __init__(self, default=None, label="", component=None, component_args=None):
|
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None):
|
||||||
self.default = default
|
self.default = default
|
||||||
self.label = label
|
self.label = label
|
||||||
self.component = component
|
self.component = component
|
||||||
self.component_args = component_args
|
self.component_args = component_args
|
||||||
|
self.onchange = onchange
|
||||||
|
|
||||||
data = None
|
data = None
|
||||||
hide_dirs = {"visible": False} if cmd_opts.hide_ui_dir_config else None
|
hide_dirs = {"visible": False} if cmd_opts.hide_ui_dir_config else None
|
||||||
|
@ -150,6 +155,7 @@ class Options:
|
||||||
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||||
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||||
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
|
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
|
||||||
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": [x.title for x in modules.sd_models.checkpoints_list.values()]}),
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -180,6 +186,10 @@ class Options:
|
||||||
with open(filename, "r", encoding="utf8") as file:
|
with open(filename, "r", encoding="utf8") as file:
|
||||||
self.data = json.load(file)
|
self.data = json.load(file)
|
||||||
|
|
||||||
|
def onchange(self, key, func):
|
||||||
|
item = self.data_labels.get(key)
|
||||||
|
item.onchange = func
|
||||||
|
|
||||||
|
|
||||||
opts = Options()
|
opts = Options()
|
||||||
if os.path.exists(config_filename):
|
if os.path.exists(config_filename):
|
||||||
|
@ -188,7 +198,6 @@ if os.path.exists(config_filename):
|
||||||
sd_upscalers = []
|
sd_upscalers = []
|
||||||
|
|
||||||
sd_model = None
|
sd_model = None
|
||||||
sd_model_hash = ''
|
|
||||||
|
|
||||||
progress_print_out = sys.stdout
|
progress_print_out = sys.stdout
|
||||||
|
|
||||||
|
|
|
@ -758,7 +758,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
||||||
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
oldval = opts.data.get(key, None)
|
||||||
opts.data[key] = value
|
opts.data[key] = value
|
||||||
|
|
||||||
|
if oldval != value and opts.data_labels[key].onchange is not None:
|
||||||
|
opts.data_labels[key].onchange()
|
||||||
|
|
||||||
up.append(comp.update(value=value))
|
up.append(comp.update(value=value))
|
||||||
|
|
||||||
opts.save(shared.config_filename)
|
opts.save(shared.config_filename)
|
||||||
|
|
67
webui.py
67
webui.py
|
@ -3,13 +3,8 @@ import threading
|
||||||
|
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
|
|
||||||
import torch
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
import signal
|
import signal
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
|
||||||
|
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.ui
|
import modules.ui
|
||||||
|
@ -24,6 +19,7 @@ import modules.extras
|
||||||
import modules.lowvram
|
import modules.lowvram
|
||||||
import modules.txt2img
|
import modules.txt2img
|
||||||
import modules.img2img
|
import modules.img2img
|
||||||
|
import modules.sd_models
|
||||||
|
|
||||||
|
|
||||||
modules.codeformer_model.setup_codeformer()
|
modules.codeformer_model.setup_codeformer()
|
||||||
|
@ -33,31 +29,19 @@ shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
||||||
esrgan.load_models(cmd_opts.esrgan_models_path)
|
esrgan.load_models(cmd_opts.esrgan_models_path)
|
||||||
realesrgan.setup_realesrgan()
|
realesrgan.setup_realesrgan()
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_config(config, ckpt, verbose=False):
|
|
||||||
print(f"Loading model [{shared.sd_model_hash}] from {ckpt}")
|
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
|
||||||
if "global_step" in pl_sd:
|
|
||||||
print(f"Global Step: {pl_sd['global_step']}")
|
|
||||||
sd = pl_sd["state_dict"]
|
|
||||||
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
m, u = model.load_state_dict(sd, strict=False)
|
|
||||||
if len(m) > 0 and verbose:
|
|
||||||
print("missing keys:")
|
|
||||||
print(m)
|
|
||||||
if len(u) > 0 and verbose:
|
|
||||||
print("unexpected keys:")
|
|
||||||
print(u)
|
|
||||||
if cmd_opts.opt_channelslast:
|
|
||||||
model = model.to(memory_format=torch.channels_last)
|
|
||||||
model.eval()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
queue_lock = threading.Lock()
|
queue_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_queued_call(func):
|
||||||
|
def f(*args, **kwargs):
|
||||||
|
with queue_lock:
|
||||||
|
res = func(*args, **kwargs)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
def wrap_gradio_gpu_call(func):
|
def wrap_gradio_gpu_call(func):
|
||||||
def f(*args, **kwargs):
|
def f(*args, **kwargs):
|
||||||
shared.state.sampling_step = 0
|
shared.state.sampling_step = 0
|
||||||
|
@ -80,33 +64,8 @@ def wrap_gradio_gpu_call(func):
|
||||||
|
|
||||||
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
|
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
|
||||||
|
|
||||||
try:
|
shared.sd_model = modules.sd_models.load_model()
|
||||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
|
||||||
|
|
||||||
from transformers import logging
|
|
||||||
|
|
||||||
logging.set_verbosity_error()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
with open(cmd_opts.ckpt, "rb") as file:
|
|
||||||
import hashlib
|
|
||||||
m = hashlib.sha256()
|
|
||||||
|
|
||||||
file.seek(0x100000)
|
|
||||||
m.update(file.read(0x10000))
|
|
||||||
shared.sd_model_hash = m.hexdigest()[0:8]
|
|
||||||
|
|
||||||
sd_config = OmegaConf.load(cmd_opts.config)
|
|
||||||
shared.sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
|
|
||||||
shared.sd_model = (shared.sd_model if cmd_opts.no_half else shared.sd_model.half())
|
|
||||||
|
|
||||||
if cmd_opts.lowvram or cmd_opts.medvram:
|
|
||||||
modules.lowvram.setup_for_low_vram(shared.sd_model, cmd_opts.medvram)
|
|
||||||
else:
|
|
||||||
shared.sd_model = shared.sd_model.to(shared.device)
|
|
||||||
|
|
||||||
modules.sd_hijack.model_hijack.hijack(shared.sd_model)
|
|
||||||
|
|
||||||
|
|
||||||
def webui():
|
def webui():
|
||||||
|
|
Loading…
Reference in New Issue