diff --git a/modules/sd_models.py b/modules/sd_models.py index 685585b1c..2c976561e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -462,6 +462,7 @@ class SdModelData: def __init__(self): self.sd_model = None self.loaded_sd_models = [] + self.loaded_vae_states = {} self.was_loaded_at_least_once = False self.lock = threading.Lock() @@ -485,16 +486,27 @@ class SdModelData: return self.sd_model - def set_sd_model(self, v): + def set_sd_model(self, v, already_loaded=False): self.sd_model = v + if already_loaded: + sd_vae_state = self.loaded_vae_states.get(v.sd_model_hash, {}) + sd_vae.base_vae = sd_vae_state.get("base_vae", None) + sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None) + sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None) try: self.loaded_sd_models.remove(v) + self.loaded_vae_states.pop(v.sd_model_hash, {}).clear() except ValueError: pass if v is not None: self.loaded_sd_models.insert(0, v) + self.loaded_vae_states[v.sd_model_hash] = dict( + base_vae=sd_vae.base_vae, + loaded_vae_file=sd_vae.loaded_vae_file, + checkpoint_info=sd_vae.checkpoint_info, + ) model_data = SdModelData() @@ -649,6 +661,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0: print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}") model_data.loaded_sd_models.pop() + model_data.loaded_vae_states.pop(loaded_model.sd_model_hash, {}).clear() send_model_to_trash(loaded_model) timer.record("send model to trash") @@ -660,7 +673,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): send_model_to_device(already_loaded) timer.record("send model to device") - model_data.set_sd_model(already_loaded) + model_data.set_sd_model(already_loaded, already_loaded=True) if not SkipWritingToConfig.skip: shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title @@ -678,6 +691,11 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): sd_model = model_data.loaded_sd_models.pop() model_data.sd_model = sd_model + sd_vae_state = model_data.loaded_vae_states.pop(sd_model.sd_model_hash, {}) + sd_vae.base_vae = sd_vae_state.get("base_vae", None) + sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None) + sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None) + print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}") return sd_model else: