support loading clip/t5 from the main model checkpoint
This commit is contained in:
parent
d67348a0a5
commit
7e4b06fcd0
|
@ -174,15 +174,10 @@ class SD3Cond(torch.nn.Module):
|
|||
self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
|
||||
self.model_t5 = Sd3T5(self.t5xxl)
|
||||
|
||||
self.weights_loaded = False
|
||||
|
||||
def forward(self, prompts: list[str]):
|
||||
with devices.without_autocast():
|
||||
lg_out, vector_out = self.model_lg(prompts)
|
||||
|
||||
token_count = lg_out.shape[1]
|
||||
|
||||
t5_out = self.model_t5(prompts, token_count=token_count)
|
||||
t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
|
||||
lgt_out = torch.cat([lg_out, t5_out], dim=-2)
|
||||
|
||||
return {
|
||||
|
@ -190,27 +185,24 @@ class SD3Cond(torch.nn.Module):
|
|||
'vector': vector_out,
|
||||
}
|
||||
|
||||
def load_weights(self):
|
||||
if self.weights_loaded:
|
||||
return
|
||||
|
||||
def before_load_weights(self, state_dict):
|
||||
clip_path = os.path.join(shared.models_path, "CLIP")
|
||||
|
||||
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
|
||||
with safetensors.safe_open(clip_g_file, framework="pt") as file:
|
||||
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
|
||||
if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
|
||||
clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
|
||||
with safetensors.safe_open(clip_g_file, framework="pt") as file:
|
||||
self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
|
||||
|
||||
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
|
||||
with safetensors.safe_open(clip_l_file, framework="pt") as file:
|
||||
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
|
||||
clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
|
||||
with safetensors.safe_open(clip_l_file, framework="pt") as file:
|
||||
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
if self.t5xxl:
|
||||
if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
|
||||
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
|
||||
with safetensors.safe_open(t5_file, framework="pt") as file:
|
||||
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||
|
||||
self.weights_loaded = True
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
return torch.tensor([[0]], device=devices.device) # XXX
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ class SD3Inferencer(torch.nn.Module):
|
|||
|
||||
self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
|
||||
|
||||
self.cond_stage_model = SD3Cond()
|
||||
self.text_encoders = SD3Cond()
|
||||
self.cond_stage_key = 'txt'
|
||||
|
||||
self.parameterization = "eps"
|
||||
|
@ -40,8 +40,12 @@ class SD3Inferencer(torch.nn.Module):
|
|||
self.latent_format = SD3LatentFormat()
|
||||
self.latent_channels = 16
|
||||
|
||||
def after_load_weights(self):
|
||||
self.cond_stage_model.load_weights()
|
||||
@property
|
||||
def cond_stage_model(self):
|
||||
return self.text_encoders
|
||||
|
||||
def before_load_weights(self, state_dict):
|
||||
self.cond_stage_model.before_load_weights(state_dict)
|
||||
|
||||
def ema_scope(self):
|
||||
return contextlib.nullcontext()
|
||||
|
|
|
@ -434,9 +434,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
|||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = state_dict.copy()
|
||||
|
||||
if hasattr(model, "before_load_weights"):
|
||||
model.before_load_weights(state_dict)
|
||||
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
timer.record("apply weights to model")
|
||||
|
||||
if hasattr(model, "after_load_weights"):
|
||||
model.after_load_weights(state_dict)
|
||||
|
||||
del state_dict
|
||||
|
||||
# Set is_sdxl_inpaint flag.
|
||||
|
@ -838,9 +844,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|||
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
|
||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||
|
||||
if hasattr(sd_model, "after_load_weights"):
|
||||
sd_model.after_load_weights()
|
||||
|
||||
timer.record("load weights from state dict")
|
||||
|
||||
send_model_to_device(sd_model)
|
||||
|
|
Loading…
Reference in New Issue