2023-01-13 12:46:14 -07:00
|
|
|
import logging
|
2023-01-12 14:32:37 -07:00
|
|
|
import os
|
2023-01-23 11:19:22 -07:00
|
|
|
from typing import Optional, Tuple
|
2023-01-12 14:32:37 -07:00
|
|
|
|
|
|
|
import huggingface_hub
|
2023-02-20 14:00:25 -07:00
|
|
|
from diffusers import StableDiffusionPipeline
|
|
|
|
|
2023-05-26 23:03:01 -06:00
|
|
|
from utils.unet_utils import get_attn_yaml
|
2023-01-23 11:19:22 -07:00
|
|
|
|
2023-01-12 14:32:37 -07:00
|
|
|
|
2023-02-20 14:00:25 -07:00
|
|
|
def try_download_model_from_hf(repo_id: str) -> Optional[Tuple[StableDiffusionPipeline, str, bool, str]]:
    """
    Attempts to download a diffusers-format Stable Diffusion model from the Hugging Face Hub
    and load it as a `StableDiffusionPipeline`.

    The files fetched are whatever `StableDiffusionPipeline.download` selects for the repo
    (typically the "text_encoder", "vae", "unet", "scheduler" and "tokenizer" subfolders).

    If the environment variable `HF_API_TOKEN` is set, it is used to log in to the Hub before
    fetching; otherwise whatever token huggingface_hub already has saved is used.

    :param repo_id: The repository id of the model on huggingface, such as 'stabilityai/stable-diffusion-2-1',
        which corresponds to `https://huggingface.co/stabilityai/stable-diffusion-2-1`.
    :return: A tuple `(pipe, cache_folder, is_sd1_attn, yaml_path)`: the loaded pipeline, the root folder
        on disk of the downloaded files, and the two values returned by `get_attn_yaml(cache_folder)`.
        Returns None if `repo_id` is not a valid repository id or the repository cannot be found on the Hub.
    """
    # Function-scope import: these error types are only needed to translate lookup failures.
    from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError

    access_token = os.environ.get('HF_API_TOKEN')
    if access_token is not None:
        huggingface_hub.login(access_token)

    # Check that the model exists. `model_info` raises (it never returns None) when the repo id is
    # malformed or the repo is missing/inaccessible, so map those cases to the documented None return.
    try:
        huggingface_hub.model_info(repo_id)
    except (HFValidationError, RepositoryNotFoundError) as e:
        logging.warning(f"* HuggingFace could not find model repo {repo_id}: {e}")
        return None

    # Download (or reuse the cached copy of) the repo files, then load the pipeline from that local
    # folder; passing a directory to from_pretrained skips a second round of hub lookups.
    cache_folder = StableDiffusionPipeline.download(repo_id)
    pipe = StableDiffusionPipeline.from_pretrained(cache_folder)

    is_sd1_attn, yaml_path = get_attn_yaml(cache_folder)
    print(f"* HuggingFace Downloaded model from {repo_id} to {cache_folder}.")
    print(f"** Using attention yaml file: {yaml_path}, is_sd1_attn: {is_sd1_attn}.")

    return pipe, cache_folder, is_sd1_attn, yaml_path
|