Merge pull request #12 from damian0815/hf_model_download
Enable download models from huggingface
This commit is contained in:
commit
81599bb548
44
train.py
44
train.py
|
@ -23,9 +23,9 @@ import logging
|
||||||
import time
|
import time
|
||||||
import gc
|
import gc
|
||||||
import random
|
import random
|
||||||
|
import traceback
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.cuda.amp import autocast, GradScaler
|
from torch.cuda.amp import autocast, GradScaler
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
|
@ -50,6 +50,7 @@ import wandb
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from data.every_dream import EveryDreamBatch
|
from data.every_dream import EveryDreamBatch
|
||||||
|
from utils.huggingface_downloader import try_download_model_from_hf
|
||||||
from utils.convert_diff_to_ckpt import convert as converter
|
from utils.convert_diff_to_ckpt import convert as converter
|
||||||
from utils.gpu import GPU
|
from utils.gpu import GPU
|
||||||
|
|
||||||
|
@ -62,8 +63,12 @@ def clean_filename(filename):
|
||||||
"""
|
"""
|
||||||
return "".join([c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip()
|
return "".join([c for c in filename if c.isalpha() or c.isdigit() or c==' ']).rstrip()
|
||||||
|
|
||||||
|
def get_hf_ckpt_cache_path(ckpt_path):
    """Return the local ckpt_cache folder path used to cache the converted
    (diffusers-format) version of the given checkpoint path."""
    ckpt_name = os.path.basename(ckpt_path)
    return os.path.join("ckpt_cache", ckpt_name)
|
||||||
|
|
||||||
def convert_to_hf(ckpt_path):
|
def convert_to_hf(ckpt_path):
|
||||||
hf_cache = os.path.join("ckpt_cache", os.path.basename(ckpt_path))
|
|
||||||
|
hf_cache = get_hf_ckpt_cache_path(ckpt_path)
|
||||||
from utils.patch_unet import patch_unet
|
from utils.patch_unet import patch_unet
|
||||||
|
|
||||||
if os.path.isfile(ckpt_path):
|
if os.path.isfile(ckpt_path):
|
||||||
|
@ -455,15 +460,29 @@ def main(args):
|
||||||
del tfimage
|
del tfimage
|
||||||
del images
|
del images
|
||||||
|
|
||||||
try:
|
try:
|
||||||
hf_ckpt_path, is_sd1attn, yaml = convert_to_hf(args.resume_ckpt)
|
|
||||||
text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
|
# check for a local file
|
||||||
vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
|
hf_cache_path = get_hf_ckpt_cache_path(args.resume_ckpt)
|
||||||
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", upcast_attention=not is_sd1attn)
|
if os.path.exists(hf_cache_path) or os.path.exists(args.resume_ckpt):
|
||||||
sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
|
model_root_folder, is_sd1attn, yaml = convert_to_hf(args.resume_ckpt)
|
||||||
noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
|
else:
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
|
# try to download from HF using resume_ckpt as a repo id
|
||||||
except:
|
print(f"local file/folder not found for {args.resume_ckpt}, will try to download from huggingface.co")
|
||||||
|
hf_repo_subfolder = args.hf_repo_subfolder if hasattr(args, 'hf_repo_subfolder') else None
|
||||||
|
model_root_folder, is_sd1attn, yaml = try_download_model_from_hf(repo_id=args.resume_ckpt,
|
||||||
|
subfolder=hf_repo_subfolder)
|
||||||
|
if model_root_folder is None:
|
||||||
|
raise ValueError(f"No local file/folder for {args.resume_ckpt}, and no matching huggingface.co repo could be downloaded")
|
||||||
|
|
||||||
|
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
|
||||||
|
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
|
||||||
|
unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
|
||||||
|
sample_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||||
|
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
logging.ERROR(" * Failed to load checkpoint *")
|
logging.ERROR(" * Failed to load checkpoint *")
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
|
@ -943,6 +962,7 @@ if __name__ == "__main__":
|
||||||
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
|
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
|
||||||
argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
|
argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
|
||||||
argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
|
argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
|
||||||
|
argparser.add_argument("--hf_repo_subfolder", type=str, default=None, help="Subfolder inside the huggingface repo to download, if the model is not in the root of the repo.")
|
||||||
argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
|
argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
|
||||||
argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default!")
|
argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default!")
|
||||||
argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu")
|
argparser.add_argument("--lowvram", action="store_true", default=False, help="automatically overrides various args to support 12GB gpu")
|
||||||
|
@ -954,7 +974,7 @@ if __name__ == "__main__":
|
||||||
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
|
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
|
||||||
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
|
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
|
||||||
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
|
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
|
||||||
argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt")
|
argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")
|
||||||
argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="File with prompts to generate test samples from (def: sample_prompts.txt)")
|
argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="File with prompts to generate test samples from (def: sample_prompts.txt)")
|
||||||
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
|
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
|
||||||
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
|
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import huggingface_hub
|
||||||
|
from utils.patch_unet import patch_unet
|
||||||
|
|
||||||
|
|
||||||
|
def try_download_model_from_hf(repo_id: str,
                               subfolder: Optional[str]=None) -> Tuple[Optional[str], Optional[bool], Optional[str]]:
    """
    Attempts to download files from the following subfolders under the given repo id:
    "text_encoder", "vae", "unet", "scheduler", "tokenizer".

    Authentication uses the HF_API_TOKEN environment variable when it is set;
    otherwise the download is attempted anonymously.

    :param repo_id: The repository id of the model on huggingface, such as 'stabilityai/stable-diffusion-2-1' which
                corresponds to `https://huggingface.co/stabilityai/stable-diffusion-2-1`.
    :param subfolder: Optional subfolder inside the repo that holds the model, if it is not at the repo root.
    :return: Tuple of (root folder on disk of the downloaded files, is_sd1_attn flag, yaml path)
             as produced by patch_unet, or (None, None, None) if the repo could not be found.
    """
    # Use .get so a missing env var does not raise; the original bare
    # `except:` here also swallowed huggingface_hub.login() failures, which
    # would hide a bad token — login errors now propagate to the caller.
    access_token = os.environ.get('HF_API_TOKEN')
    if access_token is not None:
        huggingface_hub.login(access_token)
    else:
        logging.info("no HF_API_TOKEN env var found, will attempt to download without authenticating")

    # check if the model exists
    model_info = huggingface_hub.model_info(repo_id)
    if model_info is None:
        return None, None, None

    # only fetch the subfolders train.py actually loads, plus the index file
    model_subfolders = ["text_encoder", "vae", "unet", "scheduler", "tokenizer"]
    allow_patterns = ["model_index.json"] + [os.path.join(subfolder or '', f, "*") for f in model_subfolders]
    # prefer *.bin files for now
    # TODO: look for *.safetensors files and download them instead, if they exist
    ignore_patterns = "*.safetensors"
    downloaded_folder = huggingface_hub.snapshot_download(repo_id=repo_id,
                                                          allow_patterns=allow_patterns,
                                                          ignore_patterns=ignore_patterns)
    # use logging for consistency with the rest of this module (was print())
    logging.info(f"model with repo id {repo_id} downloaded to {downloaded_folder}")
    is_sd1_attn, yaml_path = patch_unet(downloaded_folder)
    return downloaded_folder, is_sd1_attn, yaml_path
|
Loading…
Reference in New Issue