diff --git a/router/src/validation.rs b/router/src/validation.rs index 90dc3741..64f25c82 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -556,6 +556,7 @@ mod tests { inputs: "Hello".to_string(), parameters: GenerateParameters { top_p: Some(1.0), + max_new_tokens: Some(5), ..default_parameters() }, }) @@ -570,6 +571,7 @@ mod tests { inputs: "Hello".to_string(), parameters: GenerateParameters { top_p: Some(0.99), + max_new_tokens: Some(5), ..default_parameters() }, }) @@ -584,6 +586,7 @@ mod tests { inputs: "Hello".to_string(), parameters: GenerateParameters { top_p: None, + max_new_tokens: Some(5), ..default_parameters() }, }) @@ -616,6 +619,7 @@ mod tests { inputs: "Hello".to_string(), parameters: GenerateParameters { top_n_tokens: Some(5), + max_new_tokens: Some(5), ..default_parameters() }, }) @@ -630,6 +634,7 @@ mod tests { inputs: "Hello".to_string(), parameters: GenerateParameters { top_n_tokens: Some(4), + max_new_tokens: Some(5), ..default_parameters() }, }) @@ -641,6 +646,7 @@ mod tests { inputs: "Hello".to_string(), parameters: GenerateParameters { top_n_tokens: Some(0), + max_new_tokens: Some(5), ..default_parameters() }, }) @@ -652,6 +658,7 @@ mod tests { inputs: "Hello".to_string(), parameters: GenerateParameters { top_n_tokens: None, + max_new_tokens: Some(5), ..default_parameters() }, }) diff --git a/server/tests/utils/test_hub.py b/server/tests/utils/test_hub.py index fac9a64d..5438c153 100644 --- a/server/tests/utils/test_hub.py +++ b/server/tests/utils/test_hub.py @@ -1,5 +1,13 @@ +import os +import requests +import tempfile + import pytest +import huggingface_hub.constants +from huggingface_hub import hf_api + +import text_generation_server.utils.hub from text_generation_server.utils.hub import ( weight_hub_files, download_weights, @@ -10,6 +18,52 @@ from text_generation_server.utils.hub import ( ) +@pytest.fixture() +def offline(): + current_value = text_generation_server.utils.hub.HF_HUB_OFFLINE + text_generation_server.utils.hub.HF_HUB_OFFLINE = True + yield "offline" + text_generation_server.utils.hub.HF_HUB_OFFLINE = current_value + + +@pytest.fixture() +def fresh_cache(): + with tempfile.TemporaryDirectory() as d: + current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE + huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d + text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d + os.environ['HUGGINGFACE_HUB_CACHE'] = d + yield + huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value + os.environ['HUGGINGFACE_HUB_CACHE'] = current_value + text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value + + +@pytest.fixture() +def prefetched(): + model_id = "bert-base-uncased" + huggingface_hub.snapshot_download( + repo_id=model_id, + revision="main", + local_files_only=False, + repo_type="model", + allow_patterns=["*.safetensors"] + ) + yield model_id + + +def test_weight_hub_files_offline_error(offline, fresh_cache): + # If the model is not prefetched then it will raise an error + with pytest.raises(EntryNotFoundError): + weight_hub_files("gpt2") + + +def test_weight_hub_files_offline_ok(prefetched, offline): + # If the model is prefetched then we should be able to get the weight files from local cache + filenames = weight_hub_files(prefetched) + assert filenames == ['model.safetensors'] + + def test_weight_hub_files(): filenames = weight_hub_files("bigscience/bloom-560m") assert filenames == ["model.safetensors"] @@ -33,8 +87,11 @@ def test_download_weights(): assert files == local_files -def test_weight_files_error(): +def test_weight_files_revision_error(): with pytest.raises(RevisionNotFoundError): weight_files("bigscience/bloom-560m", revision="error") + + +def test_weight_files_not_cached_error(fresh_cache): with pytest.raises(LocalEntryNotFoundError): weight_files("bert-base-uncased") diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index 23743c9b..019d4855 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -6,24 +6,29 @@ from loguru import logger from pathlib import Path from typing import Optional, List -from huggingface_hub import HfApi, hf_hub_download +from huggingface_hub import file_download, hf_api, HfApi, hf_hub_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from huggingface_hub.utils import ( LocalEntryNotFoundError, EntryNotFoundError, - RevisionNotFoundError, # Import here to ease try/except in other part of the lib + RevisionNotFoundError, # noqa # Import here to ease try/except in other part of the lib ) WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None) +HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"] -def weight_hub_files( - model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" -) -> List[str]: - """Get the weights filenames on the hub""" - api = HfApi() - info = api.model_info(model_id, revision=revision) - filenames = [ +def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]: + """Guess weight files from the cached revision snapshot directory""" + d = _get_cached_revision_directory(model_id, revision) + if not d: + return [] + filenames = _weight_files_from_dir(d, extension) + return filenames + + +def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]: + return [ s.rfilename for s in info.siblings if s.rfilename.endswith(extension) @@ -33,24 +38,26 @@ def weight_hub_files( and "training" not in s.rfilename ] - if not filenames: - raise EntryNotFoundError( - f"No {extension} weights found for model {model_id} and revision {revision}.", - None, - ) +def _weight_files_from_dir(d: Path, extension: str) -> List[str]: + # os.walk: do not iterate, just scan for depth 1, not recursively + # see _weight_hub_files_from_model_info, that's also what is + # done there with the len(s.rfilename.split("/")) == 1 condition + root, _, files = next(os.walk(str(d))) + filenames = [f for f in files + if f.endswith(extension) + and "arguments" not in f + and "args" not in f + and "training" not in f] return filenames -def try_to_load_from_cache( - model_id: str, revision: Optional[str], filename: str -) -> Optional[Path]: - """Try to load a file from the Hugging Face cache""" +def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]: if revision is None: revision = "main" - object_id = model_id.replace("/", "--") - repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}" + repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path( + file_download.repo_folder_name(repo_id=model_id, repo_type="model")) if not repo_cache.is_dir(): # No cache for this model @@ -74,8 +81,42 @@ def try_to_load_from_cache( # No cache for this revision and we won't try to return a random revision return None + return snapshots_dir / revision + + +def weight_hub_files( + model_id: str, revision: Optional[str] = None, extension: str = ".safetensors" +) -> List[str]: + """Get the weights filenames on the hub""" + api = HfApi() + + if HF_HUB_OFFLINE: + filenames = _cached_weight_files(model_id, revision, extension) + else: + # Online case, fetch model info from the Hub + info = api.model_info(model_id, revision=revision) + filenames = _weight_hub_files_from_model_info(info, extension) + + if not filenames: + raise EntryNotFoundError( + f"No {extension} weights found for model {model_id} and revision {revision}.", + None, + ) + + return filenames + + +def try_to_load_from_cache( + model_id: str, revision: Optional[str], filename: str +) -> Optional[Path]: + """Try to load a file from the Hugging Face cache""" + + d = _get_cached_revision_directory(model_id, revision) + if not d: + return None + # Check if file exists in cache - cached_file = snapshots_dir / revision / filename + cached_file = d / filename return cached_file if cached_file.is_file() else None @@ -84,13 +125,14 @@ def weight_files( ) -> List[Path]: """Get the local files""" # Local model - if Path(model_id).exists() and Path(model_id).is_dir(): - local_files = list(Path(model_id).glob(f"*{extension}")) + d = Path(model_id) + if d.exists() and d.is_dir(): + local_files = _weight_files_from_dir(d, extension) if not local_files: raise FileNotFoundError( f"No local weights found in {model_id} with extension {extension}" ) - return local_files + return [Path(f) for f in local_files] try: filenames = weight_hub_files(model_id, revision, extension) @@ -138,33 +180,33 @@ def download_weights( ) -> List[Path]: """Download the safetensors files from the hub""" - def download_file(filename, tries=5, backoff: int = 5): - local_file = try_to_load_from_cache(model_id, revision, filename) + def download_file(fname, tries=5, backoff: int = 5): + local_file = try_to_load_from_cache(model_id, revision, fname) if local_file is not None: - logger.info(f"File {filename} already present in cache.") + logger.info(f"File {fname} already present in cache.") return Path(local_file) - for i in range(tries): + for idx in range(tries): try: - logger.info(f"Download file: {filename}") - start_time = time.time() + logger.info(f"Download file: {fname}") + stime = time.time() local_file = hf_hub_download( - filename=filename, + filename=fname, repo_id=model_id, revision=revision, - local_files_only=False, + local_files_only=HF_HUB_OFFLINE, ) logger.info( - f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}." + f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - stime))}." ) return Path(local_file) except Exception as e: - if i + 1 == tries: + if idx + 1 == tries: raise e logger.error(e) logger.info(f"Retrying in {backoff} seconds") time.sleep(backoff) - logger.info(f"Retry {i + 1}/{tries - 1}") + logger.info(f"Retry {idx + 1}/{tries - 1}") # We do this instead of using tqdm because we want to parse the logs with the launcher start_time = time.time()