assume epsilon for compatibility with old diffusers converted files
This commit is contained in:
parent
8b13c7ed1f
commit
1665f07e61
|
@ -19,7 +19,7 @@ import logging
|
||||||
|
|
||||||
def get_attn_yaml(ckpt_path):
|
def get_attn_yaml(ckpt_path):
|
||||||
"""
|
"""
|
||||||
Patch the UNet to use updated attention heads for xformers support in FP32
|
Analyze the checkpoint to determine the attention head type and yaml to use for inference
|
||||||
"""
|
"""
|
||||||
unet_cfg_path = os.path.join(ckpt_path, "unet", "config.json")
|
unet_cfg_path = os.path.join(ckpt_path, "unet", "config.json")
|
||||||
with open(unet_cfg_path, "r") as f:
|
with open(unet_cfg_path, "r") as f:
|
||||||
|
@ -32,7 +32,11 @@ def get_attn_yaml(ckpt_path):
|
||||||
is_sd1attn = unet_cfg["attention_head_dim"] == [8, 8, 8, 8]
|
is_sd1attn = unet_cfg["attention_head_dim"] == [8, 8, 8, 8]
|
||||||
is_sd1attn = unet_cfg["attention_head_dim"] == 8 or is_sd1attn
|
is_sd1attn = unet_cfg["attention_head_dim"] == 8 or is_sd1attn
|
||||||
|
|
||||||
prediction_type = scheduler_cfg["prediction_type"]
|
if 'prediction_type' not in scheduler_cfg:
|
||||||
|
logging.warn(f"Model has no prediction_type, assuming epsilon")
|
||||||
|
prediction_type = "epsilon"
|
||||||
|
else:
|
||||||
|
prediction_type = scheduler_cfg["prediction_type"]
|
||||||
|
|
||||||
logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}")
|
logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue