274 lines
8.9 KiB
Python
274 lines
8.9 KiB
Python
import time
|
|
import os
|
|
|
|
from datetime import timedelta
|
|
from loguru import logger
|
|
from pathlib import Path
|
|
from typing import Optional, List
|
|
|
|
from huggingface_hub import file_download, hf_api, HfApi, hf_hub_download
|
|
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
|
from huggingface_hub.utils import (
|
|
LocalEntryNotFoundError,
|
|
EntryNotFoundError,
|
|
RevisionNotFoundError, # noqa # Import here to ease try/except in other part of the lib
|
|
)
|
|
|
|
# Directory that, when set, weight_files() resolves weight file names against
# instead of the Hugging Face cache. None when the variable is unset.
WEIGHTS_CACHE_OVERRIDE = os.environ.get("WEIGHTS_CACHE_OVERRIDE")
# Offline switch: case-insensitive "true", "1" or "yes" enables it; anything else
# (including the "0" default) leaves it disabled.
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ("true", "1", "yes")
|
|
|
|
|
|
def _cached_adapter_weight_files(
    adapter_id: str, revision: Optional[str], extension: str
) -> List[str]:
    """List adapter weight files found in the cached snapshot of ``adapter_id``.

    Returns an empty list when no cached snapshot exists for ``revision``.
    """
    snapshot_dir = _get_cached_revision_directory(adapter_id, revision)
    return (
        _adapter_weight_files_from_dir(snapshot_dir, extension) if snapshot_dir else []
    )
|
|
|
|
|
|
def _cached_weight_files(
    model_id: str, revision: Optional[str], extension: str
) -> List[str]:
    """List model weight files found in the cached snapshot of ``model_id``.

    Returns an empty list when no cached snapshot exists for ``revision``.
    """
    snapshot_dir = _get_cached_revision_directory(model_id, revision)
    return _weight_files_from_dir(snapshot_dir, extension) if snapshot_dir else []
|
|
|
|
|
|
def _weight_hub_files_from_model_info(
    info: hf_api.ModelInfo, extension: str
) -> List[str]:
    """Select top-level weight file names with ``extension`` from a Hub listing.

    Args:
        info: Model metadata returned by ``HfApi.model_info``.
        extension: File suffix to keep, e.g. ``".safetensors"``.

    Returns:
        Repository-relative names (``rfilename``) that end with ``extension``,
        live at the repository root, and do not contain "arguments", "args"
        or "training". An empty list when the listing has no file entries.
    """
    # NOTE(review): recent huggingface_hub versions type ModelInfo.siblings as
    # Optional; treat a missing file list as empty so callers raise their usual
    # "no weights found" error instead of a TypeError from iterating None.
    siblings = info.siblings or []
    return [
        s.rfilename
        for s in siblings
        if s.rfilename.endswith(extension)
        and "/" not in s.rfilename  # root of the repository only, no subfolders
        and "arguments" not in s.rfilename
        and "args" not in s.rfilename
        and "training" not in s.rfilename
    ]
|
|
|
|
|
|
def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
|
|
# os.walk: do not iterate, just scan for depth 1, not recursively
|
|
# see _weight_hub_files_from_model_info, that's also what is
|
|
# done there with the len(s.rfilename.split("/")) == 1 condition
|
|
root, _, files = next(os.walk(str(d)))
|
|
filenames = [
|
|
os.path.join(root, f)
|
|
for f in files
|
|
if f.endswith(extension)
|
|
and "arguments" not in f
|
|
and "args" not in f
|
|
and "adapter" not in f
|
|
and "training" not in f
|
|
]
|
|
return filenames
|
|
|
|
|
|
def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
|
|
# os.walk: do not iterate, just scan for depth 1, not recursively
|
|
# see _weight_files_from_dir, that's also what is done there
|
|
root, _, files = next(os.walk(str(d)))
|
|
filenames = [
|
|
os.path.join(root, f)
|
|
for f in files
|
|
if f.endswith(extension)
|
|
and "arguments" not in f
|
|
and "args" not in f
|
|
and "training" not in f
|
|
]
|
|
return filenames
|
|
|
|
|
|
def _adapter_config_files_from_dir(d: Path) -> List[str]:
|
|
# os.walk: do not iterate, just scan for depth 1, not recursively
|
|
# see _weight_files_from_dir, that's also what is done there
|
|
root, _, files = next(os.walk(str(d)))
|
|
filenames = [
|
|
os.path.join(root, f)
|
|
for f in files
|
|
if f.endswith(".json") and "arguments" not in f and "args" not in f
|
|
]
|
|
return filenames
|
|
|
|
|
|
def _get_cached_revision_directory(
    model_id: str, revision: Optional[str]
) -> Optional[Path]:
    """Locate the cached snapshot directory of ``model_id`` at ``revision``.

    Args:
        model_id: Hub repository id (looked up as a "model" repo in the cache).
        revision: Branch, tag or commit sha; ``None`` means "main". Names found
            under the repo's ``refs/`` folder are resolved to the commit sha
            stored in that file.

    Returns:
        ``<cache>/<repo folder>/snapshots/<sha>`` if that snapshot is present,
        otherwise ``None``. Another cached revision is never returned.
    """
    if revision is None:
        revision = "main"

    repo_cache = Path(HUGGINGFACE_HUB_CACHE) / file_download.repo_folder_name(
        repo_id=model_id, repo_type="model"
    )
    if not repo_cache.is_dir():
        # No cache for this model
        return None

    refs_dir = repo_cache / "refs"
    snapshots_dir = repo_cache / "snapshots"

    # Resolve refs (for instance to convert main to the associated commit sha).
    # is_file() rather than exists(): a revision that names a directory (e.g. ""
    # resolves to refs_dir itself) must not be opened as a file.
    if refs_dir.is_dir():
        revision_file = refs_dir / revision
        if revision_file.is_file():
            with revision_file.open() as f:
                # Strip surrounding whitespace (e.g. a trailing newline) so the
                # sha matches the snapshot folder name exactly.
                revision = f.read().strip()

    # is_dir() rather than exists(): os.listdir would fail on a non-directory.
    if not snapshots_dir.is_dir():
        return None
    if revision not in os.listdir(snapshots_dir):
        # No cache for this revision and we won't try to return a random revision
        return None

    return snapshots_dir / revision
|
|
|
|
|
|
def weight_hub_files(
    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]:
    """Get the weights filenames on the hub.

    Online, names come from ``HfApi.model_info`` and are repository-relative
    (top-level files only). When ``HF_HUB_OFFLINE`` is enabled, the cached
    snapshot directory is scanned instead.
    NOTE: the offline path currently yields paths joined onto the snapshot
    directory (``_weight_files_from_dir`` uses ``os.path.join``), not bare
    repository-relative names.

    Raises:
        EntryNotFoundError: if no file with ``extension`` is found.
    """
    if HF_HUB_OFFLINE:
        filenames = _cached_weight_files(model_id, revision, extension)
    else:
        # Online case, fetch model info from the Hub. The API client is only
        # needed here, so it is not constructed at all in offline mode.
        info = HfApi().model_info(model_id, revision=revision)
        filenames = _weight_hub_files_from_model_info(info, extension)

    if not filenames:
        raise EntryNotFoundError(
            f"No {extension} weights found for model {model_id} and revision {revision}.",
            None,
        )

    return filenames
|
|
|
|
|
|
def try_to_load_from_cache(
    model_id: str, revision: Optional[str], filename: str
) -> Optional[Path]:
    """Look up ``filename`` in the locally cached snapshot of ``model_id``.

    Returns the path inside the snapshot directory when the snapshot exists and
    contains a regular file (or a link to one) with that name; otherwise ``None``.
    """
    snapshot = _get_cached_revision_directory(model_id, revision)
    if snapshot is None:
        return None

    candidate = snapshot / filename
    if not candidate.is_file():
        return None
    return candidate
|
|
|
|
|
|
def _pt_to_safetensors_name(pt_filename: str) -> str:
    """Map a PyTorch weight file name to the safetensors name looked up in its place.

    Drops any directory and the extension, then removes a literal leading
    ``pytorch_`` prefix, e.g. ``pytorch_model-00001-of-00002.bin`` ->
    ``model-00001-of-00002.safetensors``.
    """
    stem = Path(pt_filename).stem
    prefix = "pytorch_"
    # Remove the exact prefix. str.lstrip("pytorch_") strips any leading run of
    # the characters p/y/t/o/r/c/h/_ instead, e.g. "consolidated.00" would
    # become "nsolidated.00".
    if stem.startswith(prefix):
        stem = stem[len(prefix):]
    return f"{stem}.safetensors"


def weight_files(
    model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[Path]:
    """Resolve the local paths of a model's weight files.

    If ``model_id`` is an existing directory, its top-level weight files are
    returned. Otherwise the file names are obtained via ``weight_hub_files``
    (the Hub listing, or the cached snapshot when offline) and resolved against
    ``WEIGHTS_CACHE_OVERRIDE`` when it is set, or the Hugging Face cache
    otherwise. When no ``.safetensors`` files are listed, the ``.bin`` names are
    mapped to the ``.safetensors`` names expected locally, because the weights
    may have been converted locally without being pushed to the Hub.

    Raises:
        FileNotFoundError: the local directory has no matching weights, or a
            file is missing from ``WEIGHTS_CACHE_OVERRIDE``.
        EntryNotFoundError: no weights with ``extension`` (nor ``.bin`` weights
            for the safetensors fallback) are listed.
        LocalEntryNotFoundError: a file is not present in the Hugging Face cache.
    """
    # Local model
    d = Path(model_id)
    if d.is_dir():
        local_files = _weight_files_from_dir(d, extension)
        if not local_files:
            raise FileNotFoundError(
                f"No local weights found in {model_id} with extension {extension}"
            )
        return [Path(f) for f in local_files]

    try:
        filenames = weight_hub_files(model_id, revision, extension)
    except EntryNotFoundError:
        if extension != ".safetensors":
            raise
        # Try to see if there are pytorch weights
        pt_filenames = weight_hub_files(model_id, revision, extension=".bin")
        # Change pytorch extension to safetensors extension
        # It is possible that we have safetensors weights locally even though they are not on the
        # hub if we converted weights locally without pushing them
        filenames = [_pt_to_safetensors_name(f) for f in pt_filenames]

    if WEIGHTS_CACHE_OVERRIDE is not None:
        files = []
        for filename in filenames:
            p = Path(WEIGHTS_CACHE_OVERRIDE) / filename
            if not p.exists():
                raise FileNotFoundError(
                    f"File {p} not found in {WEIGHTS_CACHE_OVERRIDE}."
                )
            files.append(p)
        return files

    files = []
    for filename in filenames:
        cache_file = try_to_load_from_cache(
            model_id, revision=revision, filename=filename
        )
        if cache_file is None:
            raise LocalEntryNotFoundError(
                f"File {filename} of model {model_id} not found in "
                f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
                f"Please run `text-generation-server download-weights {model_id}` first."
            )
        files.append(cache_file)

    return files
|
|
|
|
|
|
def download_weights(
    filenames: List[str], model_id: str, revision: Optional[str] = None
) -> List[Path]:
    """Download the given files from the hub, reusing copies already in the cache.

    Args:
        filenames: Repository-relative names of the files to fetch.
        model_id: Hub repository id.
        revision: Branch, tag or commit sha, passed to the cache lookup and to
            ``hf_hub_download``.

    Returns:
        Local paths of the files, in the same order as ``filenames``.

    Raises:
        The exception from ``hf_hub_download`` on the final attempt. Errors that
        retrying cannot fix (see ``_is_permanent``) are re-raised immediately
        instead of after the full retry/backoff cycle.
    """

    def _is_permanent(err: Exception) -> bool:
        """Return True for download errors that a retry cannot resolve."""
        # NOTE(review): per huggingface_hub docs, LocalEntryNotFoundError is raised
        # both when the network is disabled and on connection problems; only the
        # offline case is hopeless, the other may be transient.
        if isinstance(err, LocalEntryNotFoundError):
            return HF_HUB_OFFLINE
        # A missing revision or a missing remote file will not appear by waiting.
        return isinstance(err, (EntryNotFoundError, RevisionNotFoundError))

    def download_file(fname, tries=5, backoff: int = 5):
        local_file = try_to_load_from_cache(model_id, revision, fname)
        if local_file is not None:
            logger.info(f"File {fname} already present in cache.")
            return Path(local_file)

        for idx in range(tries):
            try:
                logger.info(f"Download file: {fname}")
                stime = time.time()
                local_file = hf_hub_download(
                    filename=fname,
                    repo_id=model_id,
                    revision=revision,
                    local_files_only=HF_HUB_OFFLINE,
                )
                logger.info(
                    f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - stime))}."
                )
                return Path(local_file)
            except Exception as e:
                # Give up on the last attempt, or right away when retrying is futile.
                if idx + 1 == tries or _is_permanent(e):
                    raise
                logger.error(e)
                logger.info(f"Retrying in {backoff} seconds")
                time.sleep(backoff)
                logger.info(f"Retry {idx + 1}/{tries - 1}")

    # We do this instead of using tqdm because we want to parse the logs with the launcher,
    # so the "Download: [i/n] -- ETA: ..." format below must stay unchanged.
    start_time = time.time()
    files = []
    for i, filename in enumerate(filenames):
        file = download_file(filename)

        elapsed = timedelta(seconds=int(time.time() - start_time))
        remaining = len(filenames) - (i + 1)
        # Average time per file so far, extrapolated over the files still to go.
        eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0

        logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}")
        files.append(file)

    return files
|