Merge pull request #83 from damian0815/fix-huggingface-download-safetensors-models
use StableDiffusionPipeline.from_pretrained() to download HF models
This commit is contained in:
commit
30c7b2f96d
20
train.py
20
train.py
|
@ -427,21 +427,24 @@ def main(args):
|
|||
hf_cache_path = get_hf_ckpt_cache_path(args.resume_ckpt)
|
||||
if os.path.exists(hf_cache_path) or os.path.exists(args.resume_ckpt):
|
||||
model_root_folder, is_sd1attn, yaml = convert_to_hf(args.resume_ckpt)
|
||||
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
|
||||
else:
|
||||
# try to download from HF using resume_ckpt as a repo id
|
||||
print(f"local file/folder not found for {args.resume_ckpt}, will try to download from huggingface.co")
|
||||
hf_repo_subfolder = args.hf_repo_subfolder if hasattr(args, 'hf_repo_subfolder') else None
|
||||
model_root_folder, is_sd1attn, yaml = try_download_model_from_hf(repo_id=args.resume_ckpt,
|
||||
subfolder=hf_repo_subfolder)
|
||||
if model_root_folder is None:
|
||||
downloaded = try_download_model_from_hf(repo_id=args.resume_ckpt)
|
||||
if downloaded is None:
|
||||
raise ValueError(f"No local file/folder for {args.resume_ckpt}, and no matching huggingface.co repo could be downloaded")
|
||||
pipe, model_root_folder, is_sd1attn, yaml = downloaded
|
||||
text_encoder = pipe.text_encoder
|
||||
vae = pipe.vae
|
||||
unet = pipe.unet
|
||||
del pipe
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
|
||||
reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logging.error(" * Failed to load checkpoint *")
|
||||
|
@ -910,7 +913,6 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
|
||||
argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
|
||||
argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
|
||||
argparser.add_argument("--hf_repo_subfolder", type=str, default=None, help="Subfolder inside the huggingface repo to download, if the model is not in the root of the repo.")
|
||||
argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
|
||||
argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default!")
|
||||
argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu")
|
||||
|
|
|
@ -3,11 +3,12 @@ import os
|
|||
from typing import Optional, Tuple
|
||||
|
||||
import huggingface_hub
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
from utils.analyze_unet import get_attn_yaml
|
||||
|
||||
|
||||
def try_download_model_from_hf(repo_id: str,
|
||||
subfolder: Optional[str]=None) -> Tuple[Optional[str], Optional[bool], Optional[str]]:
|
||||
def try_download_model_from_hf(repo_id: str) -> Tuple[StableDiffusionPipeline, str, bool, str] | None:
|
||||
"""
|
||||
Attempts to download files from the following subfolders under the given repo id:
|
||||
"text_encoder", "vae", "unet", "scheduler", "tokenizer".
|
||||
|
@ -17,26 +18,17 @@ def try_download_model_from_hf(repo_id: str,
|
|||
:return: Root folder on disk to the downloaded files, or None if download failed.
|
||||
"""
|
||||
|
||||
try:
|
||||
access_token = os.environ['HF_API_TOKEN']
|
||||
if access_token is not None:
|
||||
huggingface_hub.login(access_token)
|
||||
except:
|
||||
logging.info("no HF_API_TOKEN env var found, will attempt to download without authenticating")
|
||||
access_token = os.environ.get('HF_API_TOKEN', None)
|
||||
if access_token is not None:
|
||||
huggingface_hub.login(access_token)
|
||||
|
||||
# check if the model exists
|
||||
model_info = huggingface_hub.model_info(repo_id)
|
||||
if model_info is None:
|
||||
return None, None, None
|
||||
return None
|
||||
|
||||
model_subfolders = ["text_encoder", "vae", "unet", "scheduler", "tokenizer"]
|
||||
allow_patterns = ["model_index.json"] + [os.path.join(subfolder or '', f, "*") for f in model_subfolders]
|
||||
# prefer *.bin files for now
|
||||
# TODO: look for *.safetensors files and download them instead, if they exist
|
||||
ignore_patterns = "*.safetensors"
|
||||
downloaded_folder = huggingface_hub.snapshot_download(repo_id=repo_id,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns)
|
||||
print(f"model with repo id {repo_id} downloaded to {downloaded_folder}")
|
||||
is_sd1_attn, yaml_path = get_attn_yaml(downloaded_folder)
|
||||
return downloaded_folder, is_sd1_attn, yaml_path
|
||||
# load it to download it
|
||||
pipe, cache_folder = StableDiffusionPipeline.from_pretrained(repo_id, return_cached_folder=True)
|
||||
|
||||
is_sd1_attn, yaml_path = get_attn_yaml(cache_folder)
|
||||
return pipe, cache_folder, is_sd1_attn, yaml_path
|
||||
|
|
Loading…
Reference in New Issue