support loading clip/t5 from the main model checkpoint
commit 7e4b06fcd0 (parent d67348a0a5)
@@ -174,15 +174,10 @@ class SD3Cond(torch.nn.Module):
         self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
         self.model_t5 = Sd3T5(self.t5xxl)
 
-        self.weights_loaded = False
-
     def forward(self, prompts: list[str]):
         with devices.without_autocast():
             lg_out, vector_out = self.model_lg(prompts)
-
-            token_count = lg_out.shape[1]
-
-            t5_out = self.model_t5(prompts, token_count=token_count)
+            t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
             lgt_out = torch.cat([lg_out, t5_out], dim=-2)
 
             return {
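The removed `weights_loaded` flag is no longer needed: whether anything has to be loaded is now decided per checkpoint in `before_load_weights` below. The `forward` change simply inlines the intermediate `token_count` variable into the `model_t5` call.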
@@ -190,27 +185,24 @@ class SD3Cond(torch.nn.Module):
                 'vector': vector_out,
             }
 
-    def load_weights(self):
-        if self.weights_loaded:
-            return
-
+    def before_load_weights(self, state_dict):
         clip_path = os.path.join(shared.models_path, "CLIP")
 
-        clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
-        with safetensors.safe_open(clip_g_file, framework="pt") as file:
-            self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
+        if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
+            clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
+            with safetensors.safe_open(clip_g_file, framework="pt") as file:
+                self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
 
-        clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
-        with safetensors.safe_open(clip_l_file, framework="pt") as file:
-            self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
+        if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
+            clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
+            with safetensors.safe_open(clip_l_file, framework="pt") as file:
+                self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
 
-        if self.t5xxl:
+        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
             t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
             with safetensors.safe_open(t5_file, framework="pt") as file:
                 self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
 
-        self.weights_loaded = True
-
     def encode_embedding_init_text(self, init_text, nvpt):
         return torch.tensor([[0]], device=devices.device)  # XXX
 
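The replacement for `load_weights` is driven by the checkpoint itself: for each text encoder, `before_load_weights` looks for a marker key under the `text_encoders.` prefix in the incoming `state_dict` and only downloads the standalone `.safetensors` file when that key is absent. A minimal sketch of the pattern (the `download_and_load` callback is hypothetical, standing in for the modelloader/safetensors branch):

    # Hedged sketch of the key-presence check in before_load_weights.
    # "download_and_load" is a hypothetical callback, not part of the webui code.
    def maybe_fetch_encoder(state_dict, marker_key, download_and_load):
        if marker_key in state_dict:
            # the main checkpoint bundles this encoder; its tensors are applied
            # later by model.load_state_dict(state_dict, strict=False)
            return
        download_and_load()

    # usage with one of the marker keys from the diff
    maybe_fetch_encoder(
        state_dict={},  # stands in for a checkpoint that bundles no text encoders
        marker_key='text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight',
        download_and_load=lambda: print("would download clip_g.safetensors"),
    )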
@@ -31,7 +31,7 @@ class SD3Inferencer(torch.nn.Module):
 
         self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
 
-        self.cond_stage_model = SD3Cond()
+        self.text_encoders = SD3Cond()
         self.cond_stage_key = 'txt'
 
         self.parameterization = "eps"
@@ -40,8 +40,12 @@ class SD3Inferencer(torch.nn.Module):
         self.latent_format = SD3LatentFormat()
         self.latent_channels = 16
 
-    def after_load_weights(self):
-        self.cond_stage_model.load_weights()
+    @property
+    def cond_stage_model(self):
+        return self.text_encoders
+
+    def before_load_weights(self, state_dict):
+        self.cond_stage_model.before_load_weights(state_dict)
 
     def ema_scope(self):
         return contextlib.nullcontext()
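Renaming the submodule from `cond_stage_model` to `text_encoders` appears to be what lets the bundled encoders load straight from the main checkpoint: `torch.nn.Module` prefixes state_dict keys with attribute names, so keys like `text_encoders.clip_g....` line up with the submodule path when `model.load_state_dict(state_dict, strict=False)` runs, while the read-only property keeps old `cond_stage_model` call sites working. A standalone illustration with toy modules (not the webui classes):

    # Toy illustration of why the attribute name matters: state_dict keys are
    # prefixed with submodule attribute names, so an attribute called
    # "text_encoders" is what makes keys like "text_encoders.clip_g...." match.
    import torch

    class Encoders(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.clip_g = torch.nn.Linear(4, 4)

    class Wrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.text_encoders = Encoders()

        @property
        def cond_stage_model(self):  # old name kept as a read-only alias
            return self.text_encoders

    w = Wrapper()
    print(list(w.state_dict())[0])            # text_encoders.clip_g.weight
    print(w.cond_stage_model is w.text_encoders)  # True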
@@ -434,9 +434,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         # cache newly loaded model
         checkpoints_loaded[checkpoint_info] = state_dict.copy()
 
+    if hasattr(model, "before_load_weights"):
+        model.before_load_weights(state_dict)
+
     model.load_state_dict(state_dict, strict=False)
     timer.record("apply weights to model")
 
+    if hasattr(model, "after_load_weights"):
+        model.after_load_weights(state_dict)
+
     del state_dict
 
     # Set is_sdxl_inpaint flag.
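`load_model_weights` now exposes two optional, duck-typed hooks: `before_load_weights(state_dict)` runs before the tensors are applied (SD3 uses it to download encoders the checkpoint does not bundle), and `after_load_weights(state_dict)` runs right after, while `state_dict` still exists. A model class opts in simply by defining the methods; an illustrative sketch:

    # Illustrative only: a model opts into the new hooks just by defining them;
    # load_model_weights discovers them via hasattr().
    import torch

    class MyInferencer(torch.nn.Module):
        def before_load_weights(self, state_dict):
            # runs before model.load_state_dict(); a good place to fetch
            # components that the checkpoint does not bundle
            has_t5 = 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' in state_dict
            print("checkpoint bundles t5xxl:", has_t5)

        def after_load_weights(self, state_dict):
            # runs after the tensors are applied, before state_dict is deleted
            pass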
@@ -838,9 +844,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
 
-    if hasattr(sd_model, "after_load_weights"):
-        sd_model.after_load_weights()
-
     timer.record("load weights from state dict")
 
     send_model_to_device(sd_model)
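With the hook moved into `load_model_weights` (and now receiving `state_dict`), the argument-less `after_load_weights()` call here in `load_model` becomes redundant and is removed, matching the removal of `SD3Inferencer.after_load_weights` above.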