Merge pull request #12227 from AUTOMATIC1111/multiple_loaded_models
option to keep multiple models in memory
This commit is contained in:
commit
c613416af3
|
@ -15,6 +15,9 @@ def send_everything_to_cpu():
|
||||||
|
|
||||||
|
|
||||||
def setup_for_low_vram(sd_model, use_medvram):
|
def setup_for_low_vram(sd_model, use_medvram):
|
||||||
|
if getattr(sd_model, 'lowvram', False):
|
||||||
|
return
|
||||||
|
|
||||||
sd_model.lowvram = True
|
sd_model.lowvram = True
|
||||||
|
|
||||||
parents = {}
|
parents = {}
|
||||||
|
|
|
@ -5,7 +5,7 @@ from types import MethodType
|
||||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
|
||||||
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
import ldm.modules.diffusionmodules.model
|
import ldm.modules.diffusionmodules.model
|
||||||
|
@ -29,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
|
||||||
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||||
|
|
||||||
# silence new console spam from SD2
|
# silence new console spam from SD2
|
||||||
ldm.modules.attention.print = lambda *args: None
|
ldm.modules.attention.print = shared.ldm_print
|
||||||
ldm.modules.diffusionmodules.model.print = lambda *args: None
|
ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||||
|
ldm.util.print = shared.ldm_print
|
||||||
|
ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||||
|
|
||||||
|
sd_hijack_inpainting.do_inpainting_hijack()
|
||||||
|
|
||||||
optimizers = []
|
optimizers = []
|
||||||
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||||
|
|
|
@ -92,6 +92,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
||||||
|
|
||||||
|
|
||||||
def do_inpainting_hijack():
|
def do_inpainting_hijack():
|
||||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
|
||||||
|
|
||||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
||||||
|
|
|
@ -15,7 +15,6 @@ import ldm.modules.midas as midas
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
|
||||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
import tomesd
|
import tomesd
|
||||||
|
|
||||||
|
@ -434,6 +433,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
|
||||||
class SdModelData:
|
class SdModelData:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.sd_model = None
|
self.sd_model = None
|
||||||
|
self.loaded_sd_models = []
|
||||||
self.was_loaded_at_least_once = False
|
self.was_loaded_at_least_once = False
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
|
@ -448,6 +448,7 @@ class SdModelData:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
load_model()
|
load_model()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.display(e, "loading stable diffusion model", full_traceback=True)
|
errors.display(e, "loading stable diffusion model", full_traceback=True)
|
||||||
print("", file=sys.stderr)
|
print("", file=sys.stderr)
|
||||||
|
@ -459,11 +460,24 @@ class SdModelData:
|
||||||
def set_sd_model(self, v):
|
def set_sd_model(self, v):
|
||||||
self.sd_model = v
|
self.sd_model = v
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.loaded_sd_models.remove(v)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if v is not None:
|
||||||
|
self.loaded_sd_models.insert(0, v)
|
||||||
|
|
||||||
|
|
||||||
model_data = SdModelData()
|
model_data = SdModelData()
|
||||||
|
|
||||||
|
|
||||||
def get_empty_cond(sd_model):
|
def get_empty_cond(sd_model):
|
||||||
|
from modules import extra_networks, processing
|
||||||
|
|
||||||
|
p = processing.StableDiffusionProcessingTxt2Img()
|
||||||
|
extra_networks.activate(p, {})
|
||||||
|
|
||||||
if hasattr(sd_model, 'conditioner'):
|
if hasattr(sd_model, 'conditioner'):
|
||||||
d = sd_model.get_learned_conditioning([""])
|
d = sd_model.get_learned_conditioning([""])
|
||||||
return d['crossattn']
|
return d['crossattn']
|
||||||
|
@ -471,19 +485,43 @@ def get_empty_cond(sd_model):
|
||||||
return sd_model.cond_stage_model([""])
|
return sd_model.cond_stage_model([""])
|
||||||
|
|
||||||
|
|
||||||
|
def send_model_to_cpu(m):
|
||||||
|
from modules import lowvram
|
||||||
|
|
||||||
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
|
lowvram.send_everything_to_cpu()
|
||||||
|
else:
|
||||||
|
m.to(devices.cpu)
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
|
||||||
|
def send_model_to_device(m):
|
||||||
|
from modules import lowvram
|
||||||
|
|
||||||
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
|
lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
|
||||||
|
else:
|
||||||
|
m.to(shared.device)
|
||||||
|
|
||||||
|
|
||||||
|
def send_model_to_trash(m):
|
||||||
|
m.to(device="meta")
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
|
||||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||||
from modules import lowvram, sd_hijack
|
from modules import sd_hijack
|
||||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||||
|
|
||||||
|
timer = Timer()
|
||||||
|
|
||||||
if model_data.sd_model:
|
if model_data.sd_model:
|
||||||
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
send_model_to_trash(model_data.sd_model)
|
||||||
model_data.sd_model = None
|
model_data.sd_model = None
|
||||||
gc.collect()
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
do_inpainting_hijack()
|
timer.record("unload existing model")
|
||||||
|
|
||||||
timer = Timer()
|
|
||||||
|
|
||||||
if already_loaded_state_dict is not None:
|
if already_loaded_state_dict is not None:
|
||||||
state_dict = already_loaded_state_dict
|
state_dict = already_loaded_state_dict
|
||||||
|
@ -523,12 +561,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||||
|
|
||||||
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
|
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
|
||||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||||
|
timer.record("load weights from state dict")
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
send_model_to_device(sd_model)
|
||||||
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
|
|
||||||
else:
|
|
||||||
sd_model.to(shared.device)
|
|
||||||
|
|
||||||
timer.record("move model to device")
|
timer.record("move model to device")
|
||||||
|
|
||||||
sd_hijack.model_hijack.hijack(sd_model)
|
sd_hijack.model_hijack.hijack(sd_model)
|
||||||
|
@ -536,7 +571,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||||
timer.record("hijack")
|
timer.record("hijack")
|
||||||
|
|
||||||
sd_model.eval()
|
sd_model.eval()
|
||||||
model_data.sd_model = sd_model
|
model_data.set_sd_model(sd_model)
|
||||||
model_data.was_loaded_at_least_once = True
|
model_data.was_loaded_at_least_once = True
|
||||||
|
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||||
|
@ -557,10 +592,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
|
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
|
||||||
|
"""
|
||||||
|
Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
|
||||||
|
If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
|
||||||
|
If not, returns the model that can be used to load weights from checkpoint_info's file.
|
||||||
|
If no such model exists, returns None.
|
||||||
|
Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
|
||||||
|
"""
|
||||||
|
|
||||||
|
already_loaded = None
|
||||||
|
for i in reversed(range(len(model_data.loaded_sd_models))):
|
||||||
|
loaded_model = model_data.loaded_sd_models[i]
|
||||||
|
if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||||
|
already_loaded = loaded_model
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
|
||||||
|
print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
|
||||||
|
model_data.loaded_sd_models.pop()
|
||||||
|
send_model_to_trash(loaded_model)
|
||||||
|
timer.record("send model to trash")
|
||||||
|
|
||||||
|
if shared.opts.sd_checkpoints_keep_in_cpu:
|
||||||
|
send_model_to_cpu(sd_model)
|
||||||
|
timer.record("send model to cpu")
|
||||||
|
|
||||||
|
if already_loaded is not None:
|
||||||
|
send_model_to_device(already_loaded)
|
||||||
|
timer.record("send model to device")
|
||||||
|
|
||||||
|
model_data.set_sd_model(already_loaded)
|
||||||
|
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
|
||||||
|
return model_data.sd_model
|
||||||
|
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
|
||||||
|
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
|
||||||
|
|
||||||
|
model_data.sd_model = None
|
||||||
|
load_model(checkpoint_info)
|
||||||
|
return model_data.sd_model
|
||||||
|
elif len(model_data.loaded_sd_models) > 0:
|
||||||
|
sd_model = model_data.loaded_sd_models.pop()
|
||||||
|
model_data.sd_model = sd_model
|
||||||
|
|
||||||
|
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
|
||||||
|
return sd_model
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def reload_model_weights(sd_model=None, info=None):
|
def reload_model_weights(sd_model=None, info=None):
|
||||||
from modules import lowvram, devices, sd_hijack
|
from modules import devices, sd_hijack
|
||||||
checkpoint_info = info or select_checkpoint()
|
checkpoint_info = info or select_checkpoint()
|
||||||
|
|
||||||
|
timer = Timer()
|
||||||
|
|
||||||
if not sd_model:
|
if not sd_model:
|
||||||
sd_model = model_data.sd_model
|
sd_model = model_data.sd_model
|
||||||
|
|
||||||
|
@ -569,19 +655,17 @@ def reload_model_weights(sd_model=None, info=None):
|
||||||
else:
|
else:
|
||||||
current_checkpoint_info = sd_model.sd_checkpoint_info
|
current_checkpoint_info = sd_model.sd_checkpoint_info
|
||||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||||
return
|
return sd_model
|
||||||
|
|
||||||
|
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
|
||||||
|
if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||||
|
return sd_model
|
||||||
|
|
||||||
|
if sd_model is not None:
|
||||||
sd_unet.apply_unet("None")
|
sd_unet.apply_unet("None")
|
||||||
|
send_model_to_cpu(sd_model)
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
|
||||||
lowvram.send_everything_to_cpu()
|
|
||||||
else:
|
|
||||||
sd_model.to(devices.cpu)
|
|
||||||
|
|
||||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||||
|
|
||||||
timer = Timer()
|
|
||||||
|
|
||||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||||
|
|
||||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||||
|
@ -590,9 +674,8 @@ def reload_model_weights(sd_model=None, info=None):
|
||||||
|
|
||||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||||
if sd_model is not None:
|
if sd_model is not None:
|
||||||
sd_model.to(device="meta")
|
send_model_to_trash(sd_model)
|
||||||
|
|
||||||
devices.torch_gc()
|
|
||||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||||
return model_data.sd_model
|
return model_data.sd_model
|
||||||
|
|
||||||
|
@ -615,6 +698,8 @@ def reload_model_weights(sd_model=None, info=None):
|
||||||
|
|
||||||
print(f"Weights loaded in {timer.summary()}.")
|
print(f"Weights loaded in {timer.summary()}.")
|
||||||
|
|
||||||
|
model_data.set_sd_model(sd_model)
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -98,10 +98,10 @@ def extend_sdxl(model):
|
||||||
model.conditioner.wrapped = torch.nn.Module()
|
model.conditioner.wrapped = torch.nn.Module()
|
||||||
|
|
||||||
|
|
||||||
sgm.modules.attention.print = lambda *args: None
|
sgm.modules.attention.print = shared.ldm_print
|
||||||
sgm.modules.diffusionmodules.model.print = lambda *args: None
|
sgm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||||
sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
|
sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
|
||||||
sgm.modules.encoders.modules.print = lambda *args: None
|
sgm.modules.encoders.modules.print = shared.ldm_print
|
||||||
|
|
||||||
# this gets the code to load the vanilla attention that we override
|
# this gets the code to load the vanilla attention that we override
|
||||||
sgm.modules.attention.SDP_IS_AVAILABLE = True
|
sgm.modules.attention.SDP_IS_AVAILABLE = True
|
||||||
|
|
|
@ -400,6 +400,7 @@ options_templates.update(options_section(('system', "System"), {
|
||||||
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
|
||||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||||
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||||
|
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('training', "Training"), {
|
options_templates.update(options_section(('training', "Training"), {
|
||||||
|
@ -419,7 +420,9 @@ options_templates.update(options_section(('training', "Training"), {
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
|
||||||
|
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
|
||||||
|
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
|
||||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
||||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||||
|
@ -906,3 +909,10 @@ def walk_files(path, allowed_extensions=None):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
yield os.path.join(root, filename)
|
yield os.path.join(root, filename)
|
||||||
|
|
||||||
|
|
||||||
|
def ldm_print(*args, **kwargs):
|
||||||
|
if opts.hide_ldm_prints:
|
||||||
|
return
|
||||||
|
|
||||||
|
print(*args, **kwargs)
|
||||||
|
|
Loading…
Reference in New Issue