Merge pull request #12227 from AUTOMATIC1111/multiple_loaded_models

option to keep multiple models in memory
This commit is contained in:
AUTOMATIC1111 2023-08-05 07:52:50 +03:00 committed by GitHub
commit c613416af3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 135 additions and 35 deletions

View File

@ -15,6 +15,9 @@ def send_everything_to_cpu():
def setup_for_low_vram(sd_model, use_medvram): def setup_for_low_vram(sd_model, use_medvram):
if getattr(sd_model, 'lowvram', False):
return
sd_model.lowvram = True sd_model.lowvram = True
parents = {} parents = {}

View File

@ -5,7 +5,7 @@ from types import MethodType
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
import ldm.modules.attention import ldm.modules.attention
import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.model
@ -29,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
# silence new console spam from SD2 # silence new console spam from SD2
ldm.modules.attention.print = lambda *args: None ldm.modules.attention.print = shared.ldm_print
ldm.modules.diffusionmodules.model.print = lambda *args: None ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print
sd_hijack_inpainting.do_inpainting_hijack()
optimizers = [] optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None current_optimizer: sd_hijack_optimizations.SdOptimization = None

View File

@ -92,6 +92,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
def do_inpainting_hijack(): def do_inpainting_hijack():
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms

View File

@ -15,7 +15,6 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer from modules.timer import Timer
import tomesd import tomesd
@ -434,6 +433,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
class SdModelData: class SdModelData:
def __init__(self): def __init__(self):
self.sd_model = None self.sd_model = None
self.loaded_sd_models = []
self.was_loaded_at_least_once = False self.was_loaded_at_least_once = False
self.lock = threading.Lock() self.lock = threading.Lock()
@ -448,6 +448,7 @@ class SdModelData:
try: try:
load_model() load_model()
except Exception as e: except Exception as e:
errors.display(e, "loading stable diffusion model", full_traceback=True) errors.display(e, "loading stable diffusion model", full_traceback=True)
print("", file=sys.stderr) print("", file=sys.stderr)
@ -459,11 +460,24 @@ class SdModelData:
def set_sd_model(self, v): def set_sd_model(self, v):
self.sd_model = v self.sd_model = v
try:
self.loaded_sd_models.remove(v)
except ValueError:
pass
if v is not None:
self.loaded_sd_models.insert(0, v)
model_data = SdModelData() model_data = SdModelData()
def get_empty_cond(sd_model): def get_empty_cond(sd_model):
from modules import extra_networks, processing
p = processing.StableDiffusionProcessingTxt2Img()
extra_networks.activate(p, {})
if hasattr(sd_model, 'conditioner'): if hasattr(sd_model, 'conditioner'):
d = sd_model.get_learned_conditioning([""]) d = sd_model.get_learned_conditioning([""])
return d['crossattn'] return d['crossattn']
@ -471,19 +485,43 @@ def get_empty_cond(sd_model):
return sd_model.cond_stage_model([""]) return sd_model.cond_stage_model([""])
def send_model_to_cpu(m):
from modules import lowvram
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
m.to(devices.cpu)
devices.torch_gc()
def send_model_to_device(m):
from modules import lowvram
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
else:
m.to(shared.device)
def send_model_to_trash(m):
m.to(device="meta")
devices.torch_gc()
def load_model(checkpoint_info=None, already_loaded_state_dict=None): def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack from modules import sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint() checkpoint_info = checkpoint_info or select_checkpoint()
timer = Timer()
if model_data.sd_model: if model_data.sd_model:
sd_hijack.model_hijack.undo_hijack(model_data.sd_model) send_model_to_trash(model_data.sd_model)
model_data.sd_model = None model_data.sd_model = None
gc.collect()
devices.torch_gc() devices.torch_gc()
do_inpainting_hijack() timer.record("unload existing model")
timer = Timer()
if already_loaded_state_dict is not None: if already_loaded_state_dict is not None:
state_dict = already_loaded_state_dict state_dict = already_loaded_state_dict
@ -523,12 +561,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
load_model_weights(sd_model, checkpoint_info, state_dict, timer) load_model_weights(sd_model, checkpoint_info, state_dict, timer)
timer.record("load weights from state dict")
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: send_model_to_device(sd_model)
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
else:
sd_model.to(shared.device)
timer.record("move model to device") timer.record("move model to device")
sd_hijack.model_hijack.hijack(sd_model) sd_hijack.model_hijack.hijack(sd_model)
@ -536,7 +571,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("hijack") timer.record("hijack")
sd_model.eval() sd_model.eval()
model_data.sd_model = sd_model model_data.set_sd_model(sd_model)
model_data.was_loaded_at_least_once = True model_data.was_loaded_at_least_once = True
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
@ -557,10 +592,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
return sd_model return sd_model
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
"""
Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
If not, returns the model that can be used to load weights from checkpoint_info's file.
If no such model exists, returns None.
Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
"""
already_loaded = None
for i in reversed(range(len(model_data.loaded_sd_models))):
loaded_model = model_data.loaded_sd_models[i]
if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
already_loaded = loaded_model
continue
if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
model_data.loaded_sd_models.pop()
send_model_to_trash(loaded_model)
timer.record("send model to trash")
if shared.opts.sd_checkpoints_keep_in_cpu:
send_model_to_cpu(sd_model)
timer.record("send model to cpu")
if already_loaded is not None:
send_model_to_device(already_loaded)
timer.record("send model to device")
model_data.set_sd_model(already_loaded)
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
return model_data.sd_model
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
model_data.sd_model = None
load_model(checkpoint_info)
return model_data.sd_model
elif len(model_data.loaded_sd_models) > 0:
sd_model = model_data.loaded_sd_models.pop()
model_data.sd_model = sd_model
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
return sd_model
else:
return None
def reload_model_weights(sd_model=None, info=None): def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack from modules import devices, sd_hijack
checkpoint_info = info or select_checkpoint() checkpoint_info = info or select_checkpoint()
timer = Timer()
if not sd_model: if not sd_model:
sd_model = model_data.sd_model sd_model = model_data.sd_model
@ -569,19 +655,17 @@ def reload_model_weights(sd_model=None, info=None):
else: else:
current_checkpoint_info = sd_model.sd_checkpoint_info current_checkpoint_info = sd_model.sd_checkpoint_info
if sd_model.sd_model_checkpoint == checkpoint_info.filename: if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return return sd_model
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
return sd_model
if sd_model is not None:
sd_unet.apply_unet("None") sd_unet.apply_unet("None")
send_model_to_cpu(sd_model)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(sd_model) sd_hijack.model_hijack.undo_hijack(sd_model)
timer = Timer()
state_dict = get_checkpoint_state_dict(checkpoint_info, timer) state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
@ -590,9 +674,8 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model is None or checkpoint_config != sd_model.used_config: if sd_model is None or checkpoint_config != sd_model.used_config:
if sd_model is not None: if sd_model is not None:
sd_model.to(device="meta") send_model_to_trash(sd_model)
devices.torch_gc()
load_model(checkpoint_info, already_loaded_state_dict=state_dict) load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model return model_data.sd_model
@ -615,6 +698,8 @@ def reload_model_weights(sd_model=None, info=None):
print(f"Weights loaded in {timer.summary()}.") print(f"Weights loaded in {timer.summary()}.")
model_data.set_sd_model(sd_model)
return sd_model return sd_model

View File

@ -98,10 +98,10 @@ def extend_sdxl(model):
model.conditioner.wrapped = torch.nn.Module() model.conditioner.wrapped = torch.nn.Module()
sgm.modules.attention.print = lambda *args: None sgm.modules.attention.print = shared.ldm_print
sgm.modules.diffusionmodules.model.print = lambda *args: None sgm.modules.diffusionmodules.model.print = shared.ldm_print
sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
sgm.modules.encoders.modules.print = lambda *args: None sgm.modules.encoders.modules.print = shared.ldm_print
# this gets the code to load the vanilla attention that we override # this gets the code to load the vanilla attention that we override
sgm.modules.attention.SDP_IS_AVAILABLE = True sgm.modules.attention.SDP_IS_AVAILABLE = True

View File

@ -400,6 +400,7 @@ options_templates.update(options_section(('system', "System"), {
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
})) }))
options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('training', "Training"), {
@ -419,7 +420,9 @@ options_templates.update(options_section(('training', "Training"), {
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
@ -906,3 +909,10 @@ def walk_files(path, allowed_extensions=None):
continue continue
yield os.path.join(root, filename) yield os.path.join(root, filename)
def ldm_print(*args, **kwargs):
if opts.hide_ldm_prints:
return
print(*args, **kwargs)