change t5xxl checkpoint to fp8
This commit is contained in:
parent
58dc35a64a
commit
d4b814aed6
|
@ -29,7 +29,7 @@ CLIPL_CONFIG = {
|
||||||
"num_hidden_layers": 12,
|
"num_hidden_layers": 12,
|
||||||
}
|
}
|
||||||
|
|
||||||
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
|
T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp8_e4m3fn.safetensors"
|
||||||
T5_CONFIG = {
|
T5_CONFIG = {
|
||||||
"d_ff": 10240,
|
"d_ff": 10240,
|
||||||
"d_model": 4096,
|
"d_model": 4096,
|
||||||
|
@ -101,7 +101,7 @@ class SD3Cond(torch.nn.Module):
|
||||||
with safetensors.safe_open(clip_l_file, framework="pt") as file:
|
with safetensors.safe_open(clip_l_file, framework="pt") as file:
|
||||||
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||||
|
|
||||||
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
|
t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp8_e4m3fn.safetensors")
|
||||||
with safetensors.safe_open(t5_file, framework="pt") as file:
|
with safetensors.safe_open(t5_file, framework="pt") as file:
|
||||||
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue