assume epsilon for compatibility with old diffusers converted files

This commit is contained in:
Victor Hall 2023-02-08 15:42:07 -05:00
parent 8b13c7ed1f
commit 1665f07e61
1 changed files with 6 additions and 2 deletions

View File

@ -19,7 +19,7 @@ import logging
def get_attn_yaml(ckpt_path): def get_attn_yaml(ckpt_path):
""" """
Patch the UNet to use updated attention heads for xformers support in FP32 Analyze the checkpoint to determine the attention head type and yaml to use for inference
""" """
unet_cfg_path = os.path.join(ckpt_path, "unet", "config.json") unet_cfg_path = os.path.join(ckpt_path, "unet", "config.json")
with open(unet_cfg_path, "r") as f: with open(unet_cfg_path, "r") as f:
@ -32,7 +32,11 @@ def get_attn_yaml(ckpt_path):
is_sd1attn = unet_cfg["attention_head_dim"] == [8, 8, 8, 8] is_sd1attn = unet_cfg["attention_head_dim"] == [8, 8, 8, 8]
is_sd1attn = unet_cfg["attention_head_dim"] == 8 or is_sd1attn is_sd1attn = unet_cfg["attention_head_dim"] == 8 or is_sd1attn
prediction_type = scheduler_cfg["prediction_type"] if 'prediction_type' not in scheduler_cfg:
logging.warn(f"Model has no prediction_type, assuming epsilon")
prediction_type = "epsilon"
else:
prediction_type = scheduler_cfg["prediction_type"]
logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}") logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}")