check for local files before downloading from HF
This commit is contained in:
parent
6ee717c807
commit
067ea506a2
23
train.py
23
train.py
|
@ -63,9 +63,12 @@ def clean_filename(filename):
|
|||
"""
|
||||
return "".join([c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip()
|
||||
|
||||
def get_hf_ckpt_cache_path(ckpt_path):
|
||||
return os.path.join("ckpt_cache", os.path.basename(ckpt_path))
|
||||
|
||||
def convert_to_hf(ckpt_path):
|
||||
|
||||
hf_cache = os.path.join("ckpt_cache", os.path.basename(ckpt_path))
|
||||
hf_cache = get_hf_ckpt_cache_path(ckpt_path)
|
||||
from utils.patch_unet import patch_unet
|
||||
|
||||
if os.path.isfile(ckpt_path):
|
||||
|
@ -458,13 +461,19 @@ def main(args):
|
|||
del images
|
||||
|
||||
try:
|
||||
# first try to download from HF using resume_ckpt as a repo id
|
||||
hf_repo_subfolder = args.hf_repo_subfolder if hasattr(args, 'hf_repo_subfolder') else None
|
||||
model_root_folder, is_sd1attn, yaml = try_download_model_from_hf(repo_id=args.resume_ckpt,
|
||||
subfolder=hf_repo_subfolder)
|
||||
# if that doesn't work, try to load resume_ckpt as a local file or folder
|
||||
if model_root_folder is None:
|
||||
|
||||
# check for a local file
|
||||
hf_cache_path = get_hf_ckpt_cache_path(args.resume_ckpt)
|
||||
if os.path.exists(hf_cache_path) or os.path.exists(args.resume_ckpt):
|
||||
model_root_folder, is_sd1attn, yaml = convert_to_hf(args.resume_ckpt)
|
||||
else:
|
||||
# try to download from HF using resume_ckpt as a repo id
|
||||
print(f"local file/folder not found for {args.resume_ckpt}, will try to download from huggingface.co")
|
||||
hf_repo_subfolder = args.hf_repo_subfolder if hasattr(args, 'hf_repo_subfolder') else None
|
||||
model_root_folder, is_sd1attn, yaml = try_download_model_from_hf(repo_id=args.resume_ckpt,
|
||||
subfolder=hf_repo_subfolder)
|
||||
if model_root_folder is None:
|
||||
raise ValueError(f"No local file/folder for {args.resume_ckpt}, and no matching huggingface.co repo could be downloaded")
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
|
||||
|
|
Loading…
Reference in New Issue