diff --git a/modules/models/sd3/sd3_model.py b/modules/models/sd3/sd3_model.py index 8b8285244..d60b04e4e 100644 --- a/modules/models/sd3/sd3_model.py +++ b/modules/models/sd3/sd3_model.py @@ -61,9 +61,9 @@ class SD3Cond(torch.nn.Module): self.tokenizer = SD3Tokenizer() with torch.no_grad(): - self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32) - self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=torch.float32, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) - self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32) + self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype) + self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG) + self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype) self.weights_loaded = False diff --git a/modules/sd_models.py b/modules/sd_models.py index da083f71d..61fb881ba 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -406,6 +406,7 @@ def set_model_fields(model): if not hasattr(model, 'latent_channels'): model.latent_channels = 4 + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash")