diff --git a/modules/lowvram.py b/modules/lowvram.py index 3f8306643..96f52b7b4 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -15,6 +15,9 @@ def send_everything_to_cpu(): def setup_for_low_vram(sd_model, use_medvram): + if getattr(sd_model, 'lowvram', False): + return + sd_model.lowvram = True parents = {} diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 609fd56c4..9ad981998 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -5,7 +5,7 @@ from types import MethodType from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr +from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting import ldm.modules.attention import ldm.modules.diffusionmodules.model @@ -29,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention # silence new console spam from SD2 -ldm.modules.attention.print = lambda *args: None -ldm.modules.diffusionmodules.model.print = lambda *args: None +ldm.modules.attention.print = shared.ldm_print +ldm.modules.diffusionmodules.model.print = shared.ldm_print +ldm.util.print = shared.ldm_print +ldm.models.diffusion.ddpm.print = shared.ldm_print + +sd_hijack_inpainting.do_inpainting_hijack() optimizers = [] current_optimizer: sd_hijack_optimizations.SdOptimization = None diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index c1977b194..2d44b8566 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -92,6 +92,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F def do_inpainting_hijack(): - # p_sample_plms is needed because PLMS can't work with dicts as conditionings - ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms diff --git a/modules/sd_models.py b/modules/sd_models.py index 1608b37f5..f60516046 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -15,7 +15,6 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache -from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd @@ -434,6 +433,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight' class SdModelData: def __init__(self): self.sd_model = None + self.loaded_sd_models = [] self.was_loaded_at_least_once = False self.lock = threading.Lock() @@ -448,6 +448,7 @@ class SdModelData: try: load_model() + except Exception as e: errors.display(e, "loading stable diffusion model", full_traceback=True) print("", file=sys.stderr) @@ -459,11 +460,24 @@ class SdModelData: def set_sd_model(self, v): self.sd_model = v + try: + self.loaded_sd_models.remove(v) + except ValueError: + pass + + if v is not None: + self.loaded_sd_models.insert(0, v) + model_data = SdModelData() def get_empty_cond(sd_model): + from modules import extra_networks, processing + + p = processing.StableDiffusionProcessingTxt2Img() + extra_networks.activate(p, {}) + if hasattr(sd_model, 'conditioner'): d = sd_model.get_learned_conditioning([""]) return d['crossattn'] @@ -471,19 +485,43 @@ def get_empty_cond(sd_model): return sd_model.cond_stage_model([""]) +def send_model_to_cpu(m): + from modules import lowvram + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + else: + m.to(devices.cpu) + + devices.torch_gc() + + +def send_model_to_device(m): + from modules import lowvram + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram) + else: + m.to(shared.device) + + +def send_model_to_trash(m): + m.to(device="meta") + devices.torch_gc() + + def load_model(checkpoint_info=None, already_loaded_state_dict=None): - from modules import lowvram, sd_hijack + from modules import sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() + timer = Timer() + if model_data.sd_model: - sd_hijack.model_hijack.undo_hijack(model_data.sd_model) + send_model_to_trash(model_data.sd_model) model_data.sd_model = None - gc.collect() devices.torch_gc() - do_inpainting_hijack() - - timer = Timer() + timer.record("unload existing model") if already_loaded_state_dict is not None: state_dict = already_loaded_state_dict @@ -523,12 +561,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): load_model_weights(sd_model, checkpoint_info, state_dict, timer) + timer.record("load weights from state dict") - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) - else: - sd_model.to(shared.device) - + send_model_to_device(sd_model) timer.record("move model to device") sd_hijack.model_hijack.hijack(sd_model) @@ -536,7 +571,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("hijack") sd_model.eval() - model_data.sd_model = sd_model + model_data.set_sd_model(sd_model) model_data.was_loaded_at_least_once = True sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model @@ -557,10 +592,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): return sd_model +def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): + """ + Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models. + If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary). + If not, returns the model that can be used to load weights from checkpoint_info's file. + If no such model exists, returns None. + Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit). + """ + + already_loaded = None + for i in reversed(range(len(model_data.loaded_sd_models))): + loaded_model = model_data.loaded_sd_models[i] + if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename: + already_loaded = loaded_model + continue + + if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0: + print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}") + model_data.loaded_sd_models.pop() + send_model_to_trash(loaded_model) + timer.record("send model to trash") + + if shared.opts.sd_checkpoints_keep_in_cpu: + send_model_to_cpu(sd_model) + timer.record("send model to cpu") + + if already_loaded is not None: + send_model_to_device(already_loaded) + timer.record("send model to device") + + model_data.set_sd_model(already_loaded) + print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}") + return model_data.sd_model + elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit: + print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})") + + model_data.sd_model = None + load_model(checkpoint_info) + return model_data.sd_model + elif len(model_data.loaded_sd_models) > 0: + sd_model = model_data.loaded_sd_models.pop() + model_data.sd_model = sd_model + + print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}") + return sd_model + else: + return None + + def reload_model_weights(sd_model=None, info=None): - from modules import lowvram, devices, sd_hijack + from modules import devices, sd_hijack checkpoint_info = info or select_checkpoint() + timer = Timer() + if not sd_model: sd_model = model_data.sd_model @@ -569,19 +655,17 @@ def reload_model_weights(sd_model=None, info=None): else: current_checkpoint_info = sd_model.sd_checkpoint_info if sd_model.sd_model_checkpoint == checkpoint_info.filename: - return + return sd_model + sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) + if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: + return sd_model + + if sd_model is not None: sd_unet.apply_unet("None") - - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.send_everything_to_cpu() - else: - sd_model.to(devices.cpu) - + send_model_to_cpu(sd_model) sd_hijack.model_hijack.undo_hijack(sd_model) - timer = Timer() - state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) @@ -590,9 +674,8 @@ def reload_model_weights(sd_model=None, info=None): if sd_model is None or checkpoint_config != sd_model.used_config: if sd_model is not None: - sd_model.to(device="meta") + send_model_to_trash(sd_model) - devices.torch_gc() load_model(checkpoint_info, already_loaded_state_dict=state_dict) return model_data.sd_model @@ -615,6 +698,8 @@ def reload_model_weights(sd_model=None, info=None): print(f"Weights loaded in {timer.summary()}.") + model_data.set_sd_model(sd_model) + return sd_model diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index bc2195087..011233216 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -98,10 +98,10 @@ def extend_sdxl(model): model.conditioner.wrapped = torch.nn.Module() -sgm.modules.attention.print = lambda *args: None -sgm.modules.diffusionmodules.model.print = lambda *args: None -sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None -sgm.modules.encoders.modules.print = lambda *args: None +sgm.modules.attention.print = shared.ldm_print +sgm.modules.diffusionmodules.model.print = shared.ldm_print +sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print +sgm.modules.encoders.modules.print = shared.ldm_print # this gets the code to load the vanilla attention that we override sgm.modules.attention.SDP_IS_AVAILABLE = True diff --git a/modules/shared.py b/modules/shared.py index 86ffb48fd..8245250a5 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -400,6 +400,7 @@ options_templates.update(options_section(('system', "System"), { "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), + "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), })) options_templates.update(options_section(('training', "Training"), { @@ -419,7 +420,9 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), - "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), + "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"), + "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), @@ -906,3 +909,10 @@ def walk_files(path, allowed_extensions=None): continue yield os.path.join(root, filename) + + +def ldm_print(*args, **kwargs): + if opts.hide_ldm_prints: + return + + print(*args, **kwargs)