check for local files before downloading from HF

Damian Stewart 2023-01-23 21:10:04 +01:00
parent 6ee717c807
commit 067ea506a2
1 changed file with 16 additions and 7 deletions


@@ -63,9 +63,12 @@ def clean_filename(filename):
     """
     return "".join([c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip()
 
 
+def get_hf_ckpt_cache_path(ckpt_path):
+    return os.path.join("ckpt_cache", os.path.basename(ckpt_path))
+
 def convert_to_hf(ckpt_path):
-    hf_cache = os.path.join("ckpt_cache", os.path.basename(ckpt_path))
+    hf_cache = get_hf_ckpt_cache_path(ckpt_path)
     from utils.patch_unet import patch_unet
 
     if os.path.isfile(ckpt_path):
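
The new get_hf_ckpt_cache_path() helper simply maps a checkpoint path to its location under ckpt_cache/. A minimal sketch of its behaviour (the checkpoint filename below is a hypothetical example, not from the repo):

    get_hf_ckpt_cache_path("some/dir/sd_v1-5.ckpt")
    # -> "ckpt_cache/sd_v1-5.ckpt": os.path.basename() drops the directory part,
    # so checkpoints with the same filename share a single cache entry
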
@@ -458,13 +461,19 @@ def main(args):
         del images
 
     try:
-        # first try to download from HF using resume_ckpt as a repo id
-        hf_repo_subfolder = args.hf_repo_subfolder if hasattr(args, 'hf_repo_subfolder') else None
-        model_root_folder, is_sd1attn, yaml = try_download_model_from_hf(repo_id=args.resume_ckpt,
-                                                                         subfolder=hf_repo_subfolder)
-        # if that doesn't work, try to load resume_ckpt as a local file or folder
-        if model_root_folder is None:
+        # check for a local file
+        hf_cache_path = get_hf_ckpt_cache_path(args.resume_ckpt)
+        if os.path.exists(hf_cache_path) or os.path.exists(args.resume_ckpt):
             model_root_folder, is_sd1attn, yaml = convert_to_hf(args.resume_ckpt)
+        else:
+            # try to download from HF using resume_ckpt as a repo id
+            print(f"local file/folder not found for {args.resume_ckpt}, will try to download from huggingface.co")
+            hf_repo_subfolder = args.hf_repo_subfolder if hasattr(args, 'hf_repo_subfolder') else None
+            model_root_folder, is_sd1attn, yaml = try_download_model_from_hf(repo_id=args.resume_ckpt,
+                                                                             subfolder=hf_repo_subfolder)
+            if model_root_folder is None:
+                raise ValueError(f"No local file/folder for {args.resume_ckpt}, and no matching huggingface.co repo could be downloaded")
 
         text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
         vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
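
Taken together, resume_ckpt resolution after this change prefers local data and only falls back to huggingface.co. A sketch of the resulting flow, assuming the helpers shown in the diff (resolve_resume_ckpt is a hypothetical wrapper name used here for illustration, not part of the commit):

    import os

    def resolve_resume_ckpt(resume_ckpt, hf_repo_subfolder=None):
        # 1) a previously converted copy under ckpt_cache/, or a local file/folder, wins
        if os.path.exists(get_hf_ckpt_cache_path(resume_ckpt)) or os.path.exists(resume_ckpt):
            return convert_to_hf(resume_ckpt)
        # 2) otherwise treat resume_ckpt as a huggingface.co repo id and try to download it
        model_root_folder, is_sd1attn, yaml = try_download_model_from_hf(
            repo_id=resume_ckpt, subfolder=hf_repo_subfolder)
        if model_root_folder is None:
            raise ValueError(f"No local file/folder for {resume_ckpt}, "
                             "and no matching huggingface.co repo could be downloaded")
        return model_root_folder, is_sd1attn, yaml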