diff --git a/modules/models/sd3/sd3_cond.py b/modules/models/sd3/sd3_cond.py index e25ba1b63..bade90ba1 100644 --- a/modules/models/sd3/sd3_cond.py +++ b/modules/models/sd3/sd3_cond.py @@ -174,15 +174,10 @@ class SD3Cond(torch.nn.Module): self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g) self.model_t5 = Sd3T5(self.t5xxl) - self.weights_loaded = False - def forward(self, prompts: list[str]): with devices.without_autocast(): lg_out, vector_out = self.model_lg(prompts) - - token_count = lg_out.shape[1] - - t5_out = self.model_t5(prompts, token_count=token_count) + t5_out = self.model_t5(prompts, token_count=lg_out.shape[1]) lgt_out = torch.cat([lg_out, t5_out], dim=-2) return { @@ -190,27 +185,24 @@ class SD3Cond(torch.nn.Module): 'vector': vector_out, } - def load_weights(self): - if self.weights_loaded: - return - + def before_load_weights(self, state_dict): clip_path = os.path.join(shared.models_path, "CLIP") - clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors") - with safetensors.safe_open(clip_g_file, framework="pt") as file: - self.clip_g.transformer.load_state_dict(SafetensorsMapping(file)) + if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict: + clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors") + with safetensors.safe_open(clip_g_file, framework="pt") as file: + self.clip_g.transformer.load_state_dict(SafetensorsMapping(file)) - clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors") - with safetensors.safe_open(clip_l_file, framework="pt") as file: - self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False) + if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict: + clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors") + with safetensors.safe_open(clip_l_file, framework="pt") as file: + self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False) - if self.t5xxl: + if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict: t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors") with safetensors.safe_open(t5_file, framework="pt") as file: self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False) - self.weights_loaded = True - def encode_embedding_init_text(self, init_text, nvpt): return torch.tensor([[0]], device=devices.device) # XXX diff --git a/modules/models/sd3/sd3_model.py b/modules/models/sd3/sd3_model.py index 2d66b80f1..98470cdab 100644 --- a/modules/models/sd3/sd3_model.py +++ b/modules/models/sd3/sd3_model.py @@ -31,7 +31,7 @@ class SD3Inferencer(torch.nn.Module): self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1) - self.cond_stage_model = SD3Cond() + self.text_encoders = SD3Cond() self.cond_stage_key = 'txt' self.parameterization = "eps" @@ -40,8 +40,12 @@ class SD3Inferencer(torch.nn.Module): self.latent_format = SD3LatentFormat() self.latent_channels = 16 - def after_load_weights(self): - self.cond_stage_model.load_weights() + @property + def cond_stage_model(self): + return self.text_encoders + + def before_load_weights(self, state_dict): + self.cond_stage_model.before_load_weights(state_dict) def ema_scope(self): return contextlib.nullcontext() diff --git a/modules/sd_models.py b/modules/sd_models.py index 681030442..55bd9ca5e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -434,9 +434,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer # cache newly loaded model checkpoints_loaded[checkpoint_info] = state_dict.copy() + if hasattr(model, "before_load_weights"): + model.before_load_weights(state_dict) + model.load_state_dict(state_dict, strict=False) timer.record("apply weights to model") + if hasattr(model, "after_load_weights"): + model.after_load_weights(state_dict) + del state_dict # Set is_sdxl_inpaint flag. @@ -838,9 +844,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion): load_model_weights(sd_model, checkpoint_info, state_dict, timer) - if hasattr(sd_model, "after_load_weights"): - sd_model.after_load_weights() - timer.record("load weights from state dict") send_model_to_device(sd_model)