add support for switching model checkpoints at runtime
This commit is contained in:
parent
b8be33dad1
commit
247f58a5e7
|
@ -274,7 +274,7 @@ def apply_filename_pattern(x, p, seed, prompt):
|
|||
x = x.replace("[height]", str(p.height))
|
||||
x = x.replace("[sampler]", sd_samplers.samplers[p.sampler_index].name)
|
||||
|
||||
x = x.replace("[model_hash]", shared.sd_model_hash)
|
||||
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
|
||||
x = x.replace("[date]", datetime.date.today().isoformat())
|
||||
|
||||
if cmd_opts.hide_ui_dir_config:
|
||||
|
|
|
@ -227,7 +227,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||
"Seed": all_seeds[index],
|
||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": (None if not opts.add_model_hash_to_info or not shared.sd_model_hash else shared.sd_model_hash),
|
||||
"Model hash": (None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||
|
|
|
@ -0,0 +1,148 @@
|
|||
import glob
|
||||
import os.path
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import shared
|
||||
|
||||
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash'])
|
||||
checkpoints_list = {}
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def list_models():
|
||||
checkpoints_list.clear()
|
||||
|
||||
model_dir = os.path.abspath(shared.cmd_opts.ckpt_dir)
|
||||
|
||||
def modeltitle(path, h):
|
||||
abspath = os.path.abspath(path)
|
||||
|
||||
if abspath.startswith(model_dir):
|
||||
name = abspath.replace(model_dir, '')
|
||||
else:
|
||||
name = os.path.basename(path)
|
||||
|
||||
if name.startswith("\\") or name.startswith("/"):
|
||||
name = name[1:]
|
||||
|
||||
return f'{name} [{h}]'
|
||||
|
||||
cmd_ckpt = shared.cmd_opts.ckpt
|
||||
if os.path.exists(cmd_ckpt):
|
||||
h = model_hash(cmd_ckpt)
|
||||
title = modeltitle(cmd_ckpt, h)
|
||||
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h)
|
||||
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
|
||||
print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)
|
||||
|
||||
if os.path.exists(model_dir):
|
||||
for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
|
||||
h = model_hash(filename)
|
||||
title = modeltitle(filename, h)
|
||||
checkpoints_list[title] = CheckpointInfo(filename, title, h)
|
||||
|
||||
|
||||
def model_hash(filename):
|
||||
try:
|
||||
with open(filename, "rb") as file:
|
||||
import hashlib
|
||||
m = hashlib.sha256()
|
||||
|
||||
file.seek(0x100000)
|
||||
m.update(file.read(0x10000))
|
||||
return m.hexdigest()[0:8]
|
||||
except FileNotFoundError:
|
||||
return 'NOFILE'
|
||||
|
||||
|
||||
def select_checkpoint():
|
||||
model_checkpoint = shared.opts.sd_model_checkpoint
|
||||
checkpoint_info = checkpoints_list.get(model_checkpoint, None)
|
||||
if checkpoint_info is not None:
|
||||
return checkpoint_info
|
||||
|
||||
if len(checkpoints_list) == 0:
|
||||
print(f"Checkpoint {model_checkpoint} not found and no other checkpoints found", file=sys.stderr)
|
||||
return None
|
||||
|
||||
checkpoint_info = next(iter(checkpoints_list.values()))
|
||||
if model_checkpoint is not None:
|
||||
print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)
|
||||
|
||||
return checkpoint_info
|
||||
|
||||
|
||||
def load_model_weights(model, checkpoint_file, sd_model_hash):
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
|
||||
|
||||
pl_sd = torch.load(checkpoint_file, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
|
||||
model.load_state_dict(sd, strict=False)
|
||||
|
||||
if shared.cmd_opts.opt_channelslast:
|
||||
model.to(memory_format=torch.channels_last)
|
||||
|
||||
if not shared.cmd_opts.no_half:
|
||||
model.half()
|
||||
|
||||
model.sd_model_hash = sd_model_hash
|
||||
model.sd_model_checkpint = checkpoint_file
|
||||
|
||||
|
||||
def load_model():
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = select_checkpoint()
|
||||
|
||||
sd_config = OmegaConf.load(shared.cmd_opts.config)
|
||||
sd_model = instantiate_from_config(sd_config.model)
|
||||
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
||||
else:
|
||||
sd_model.to(shared.device)
|
||||
|
||||
sd_hijack.model_hijack.hijack(sd_model)
|
||||
|
||||
sd_model.eval()
|
||||
|
||||
print(f"Model loaded.")
|
||||
return sd_model
|
||||
|
||||
|
||||
def reload_model_weights(sd_model):
|
||||
from modules import lowvram, devices
|
||||
checkpoint_info = select_checkpoint()
|
||||
|
||||
if sd_model.sd_model_checkpint == checkpoint_info.filename:
|
||||
return
|
||||
|
||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||
lowvram.send_everything_to_cpu()
|
||||
else:
|
||||
sd_model.to(devices.cpu)
|
||||
|
||||
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
|
||||
|
||||
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
|
||||
sd_model.to(devices.device)
|
||||
|
||||
print(f"Weights loaded.")
|
||||
return sd_model
|
|
@ -13,14 +13,15 @@ from modules.devices import get_optimal_device
|
|||
import modules.styles
|
||||
import modules.interrogate
|
||||
import modules.memmon
|
||||
import modules.sd_models
|
||||
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
if not os.path.exists(sd_model_file):
|
||||
sd_model_file = "models/ldm/stable-diffusion-v1/model.ckpt"
|
||||
default_sd_model_file = sd_model_file
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=os.path.join(sd_path, sd_model_file), help="path to checkpoint of model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default='GFPGANv1.3.pth')
|
||||
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
|
||||
|
@ -88,13 +89,17 @@ interrogator = modules.interrogate.InterrogateModels("interrogate")
|
|||
|
||||
face_restorers = []
|
||||
|
||||
modules.sd_models.list_models()
|
||||
|
||||
|
||||
class Options:
|
||||
class OptionInfo:
|
||||
def __init__(self, default=None, label="", component=None, component_args=None):
|
||||
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None):
|
||||
self.default = default
|
||||
self.label = label
|
||||
self.component = component
|
||||
self.component_args = component_args
|
||||
self.onchange = onchange
|
||||
|
||||
data = None
|
||||
hide_dirs = {"visible": False} if cmd_opts.hide_ui_dir_config else None
|
||||
|
@ -150,6 +155,7 @@ class Options:
|
|||
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
|
||||
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
|
||||
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": [x.title for x in modules.sd_models.checkpoints_list.values()]}),
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
|
@ -180,6 +186,10 @@ class Options:
|
|||
with open(filename, "r", encoding="utf8") as file:
|
||||
self.data = json.load(file)
|
||||
|
||||
def onchange(self, key, func):
|
||||
item = self.data_labels.get(key)
|
||||
item.onchange = func
|
||||
|
||||
|
||||
opts = Options()
|
||||
if os.path.exists(config_filename):
|
||||
|
@ -188,7 +198,6 @@ if os.path.exists(config_filename):
|
|||
sd_upscalers = []
|
||||
|
||||
sd_model = None
|
||||
sd_model_hash = ''
|
||||
|
||||
progress_print_out = sys.stdout
|
||||
|
||||
|
|
|
@ -758,7 +758,12 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
|||
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
|
||||
continue
|
||||
|
||||
oldval = opts.data.get(key, None)
|
||||
opts.data[key] = value
|
||||
|
||||
if oldval != value and opts.data_labels[key].onchange is not None:
|
||||
opts.data_labels[key].onchange()
|
||||
|
||||
up.append(comp.update(value=value))
|
||||
|
||||
opts.save(shared.config_filename)
|
||||
|
|
67
webui.py
67
webui.py
|
@ -3,13 +3,8 @@ import threading
|
|||
|
||||
from modules.paths import script_path
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import signal
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
import modules.ui
|
||||
|
@ -24,6 +19,7 @@ import modules.extras
|
|||
import modules.lowvram
|
||||
import modules.txt2img
|
||||
import modules.img2img
|
||||
import modules.sd_models
|
||||
|
||||
|
||||
modules.codeformer_model.setup_codeformer()
|
||||
|
@ -33,31 +29,19 @@ shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
|||
esrgan.load_models(cmd_opts.esrgan_models_path)
|
||||
realesrgan.setup_realesrgan()
|
||||
|
||||
|
||||
def load_model_from_config(config, ckpt, verbose=False):
|
||||
print(f"Loading model [{shared.sd_model_hash}] from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0 and verbose:
|
||||
print("missing keys:")
|
||||
print(m)
|
||||
if len(u) > 0 and verbose:
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
if cmd_opts.opt_channelslast:
|
||||
model = model.to(memory_format=torch.channels_last)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
queue_lock = threading.Lock()
|
||||
|
||||
|
||||
def wrap_queued_call(func):
|
||||
def f(*args, **kwargs):
|
||||
with queue_lock:
|
||||
res = func(*args, **kwargs)
|
||||
|
||||
return res
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def wrap_gradio_gpu_call(func):
|
||||
def f(*args, **kwargs):
|
||||
shared.state.sampling_step = 0
|
||||
|
@ -80,33 +64,8 @@ def wrap_gradio_gpu_call(func):
|
|||
|
||||
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with open(cmd_opts.ckpt, "rb") as file:
|
||||
import hashlib
|
||||
m = hashlib.sha256()
|
||||
|
||||
file.seek(0x100000)
|
||||
m.update(file.read(0x10000))
|
||||
shared.sd_model_hash = m.hexdigest()[0:8]
|
||||
|
||||
sd_config = OmegaConf.load(cmd_opts.config)
|
||||
shared.sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
|
||||
shared.sd_model = (shared.sd_model if cmd_opts.no_half else shared.sd_model.half())
|
||||
|
||||
if cmd_opts.lowvram or cmd_opts.medvram:
|
||||
modules.lowvram.setup_for_low_vram(shared.sd_model, cmd_opts.medvram)
|
||||
else:
|
||||
shared.sd_model = shared.sd_model.to(shared.device)
|
||||
|
||||
modules.sd_hijack.model_hijack.hijack(shared.sd_model)
|
||||
shared.sd_model = modules.sd_models.load_model()
|
||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
|
||||
|
||||
|
||||
def webui():
|
||||
|
|
Loading…
Reference in New Issue