fix: fix offline (#1341) (#1347)

@oOraph

---------

Signed-off-by: Raphael Glon <oOraph@users.noreply.github.com>
Co-authored-by: Raphael Glon <oOraph@users.noreply.github.com>
This commit is contained in:
OlivierDehaene 2023-12-18 10:20:08 +01:00 committed by GitHub
parent 1b1bfa49b0
commit 8428ed1011
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 142 additions and 36 deletions

View File

@ -556,6 +556,7 @@ mod tests {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_p: Some(1.0),
max_new_tokens: Some(5),
..default_parameters()
},
})
@ -570,6 +571,7 @@ mod tests {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_p: Some(0.99),
max_new_tokens: Some(5),
..default_parameters()
},
})
@ -584,6 +586,7 @@ mod tests {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_p: None,
max_new_tokens: Some(5),
..default_parameters()
},
})
@ -616,6 +619,7 @@ mod tests {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: Some(5),
max_new_tokens: Some(5),
..default_parameters()
},
})
@ -630,6 +634,7 @@ mod tests {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: Some(4),
max_new_tokens: Some(5),
..default_parameters()
},
})
@ -641,6 +646,7 @@ mod tests {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: Some(0),
max_new_tokens: Some(5),
..default_parameters()
},
})
@ -652,6 +658,7 @@ mod tests {
inputs: "Hello".to_string(),
parameters: GenerateParameters {
top_n_tokens: None,
max_new_tokens: Some(5),
..default_parameters()
},
})

View File

@ -1,5 +1,13 @@
import os
import requests
import tempfile
import pytest
import huggingface_hub.constants
from huggingface_hub import hf_api
import text_generation_server.utils.hub
from text_generation_server.utils.hub import (
weight_hub_files,
download_weights,
@ -10,6 +18,52 @@ from text_generation_server.utils.hub import (
)
@pytest.fixture()
def offline():
current_value = text_generation_server.utils.hub.HF_HUB_OFFLINE
text_generation_server.utils.hub.HF_HUB_OFFLINE = True
yield "offline"
text_generation_server.utils.hub.HF_HUB_OFFLINE = current_value
@pytest.fixture()
def fresh_cache():
with tempfile.TemporaryDirectory() as d:
current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
os.environ['HUGGINGFACE_HUB_CACHE'] = d
yield
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
os.environ['HUGGINGFACE_HUB_CACHE'] = current_value
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value
@pytest.fixture()
def prefetched():
model_id = "bert-base-uncased"
huggingface_hub.snapshot_download(
repo_id=model_id,
revision="main",
local_files_only=False,
repo_type="model",
allow_patterns=["*.safetensors"]
)
yield model_id
def test_weight_hub_files_offline_error(offline, fresh_cache):
# If the model is not prefetched then it will raise an error
with pytest.raises(EntryNotFoundError):
weight_hub_files("gpt2")
def test_weight_hub_files_offline_ok(prefetched, offline):
# If the model is prefetched then we should be able to get the weight files from local cache
filenames = weight_hub_files(prefetched)
assert filenames == ['model.safetensors']
def test_weight_hub_files():
filenames = weight_hub_files("bigscience/bloom-560m")
assert filenames == ["model.safetensors"]
@ -33,8 +87,11 @@ def test_download_weights():
assert files == local_files
def test_weight_files_error():
def test_weight_files_revision_error():
with pytest.raises(RevisionNotFoundError):
weight_files("bigscience/bloom-560m", revision="error")
def test_weight_files_not_cached_error(fresh_cache):
with pytest.raises(LocalEntryNotFoundError):
weight_files("bert-base-uncased")

View File

@ -6,24 +6,29 @@ from loguru import logger
from pathlib import Path
from typing import Optional, List
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub import file_download, hf_api, HfApi, hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import (
LocalEntryNotFoundError,
EntryNotFoundError,
RevisionNotFoundError, # Import here to ease try/except in other part of the lib
RevisionNotFoundError, # noqa # Import here to ease try/except in other part of the lib
)
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
def weight_hub_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]:
"""Get the weights filenames on the hub"""
api = HfApi()
info = api.model_info(model_id, revision=revision)
filenames = [
def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]:
"""Guess weight files from the cached revision snapshot directory"""
d = _get_cached_revision_directory(model_id, revision)
if not d:
return []
filenames = _weight_files_from_dir(d, extension)
return filenames
def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]:
return [
s.rfilename
for s in info.siblings
if s.rfilename.endswith(extension)
@ -33,24 +38,26 @@ def weight_hub_files(
and "training" not in s.rfilename
]
if not filenames:
raise EntryNotFoundError(
f"No {extension} weights found for model {model_id} and revision {revision}.",
None,
)
def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
# os.walk: do not iterate, just scan for depth 1, not recursively
# see _weight_hub_files_from_model_info, that's also what is
# done there with the len(s.rfilename.split("/")) == 1 condition
root, _, files = next(os.walk(str(d)))
filenames = [f for f in files
if f.endswith(extension)
and "arguments" not in f
and "args" not in f
and "training" not in f]
return filenames
def try_to_load_from_cache(
model_id: str, revision: Optional[str], filename: str
) -> Optional[Path]:
"""Try to load a file from the Hugging Face cache"""
def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]:
if revision is None:
revision = "main"
object_id = model_id.replace("/", "--")
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
if not repo_cache.is_dir():
# No cache for this model
@ -74,8 +81,42 @@ def try_to_load_from_cache(
# No cache for this revision and we won't try to return a random revision
return None
return snapshots_dir / revision
def weight_hub_files(
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
) -> List[str]:
"""Get the weights filenames on the hub"""
api = HfApi()
if HF_HUB_OFFLINE:
filenames = _cached_weight_files(model_id, revision, extension)
else:
# Online case, fetch model info from the Hub
info = api.model_info(model_id, revision=revision)
filenames = _weight_hub_files_from_model_info(info, extension)
if not filenames:
raise EntryNotFoundError(
f"No {extension} weights found for model {model_id} and revision {revision}.",
None,
)
return filenames
def try_to_load_from_cache(
model_id: str, revision: Optional[str], filename: str
) -> Optional[Path]:
"""Try to load a file from the Hugging Face cache"""
d = _get_cached_revision_directory(model_id, revision)
if not d:
return None
# Check if file exists in cache
cached_file = snapshots_dir / revision / filename
cached_file = d / filename
return cached_file if cached_file.is_file() else None
@ -84,13 +125,14 @@ def weight_files(
) -> List[Path]:
"""Get the local files"""
# Local model
if Path(model_id).exists() and Path(model_id).is_dir():
local_files = list(Path(model_id).glob(f"*{extension}"))
d = Path(model_id)
if d.exists() and d.is_dir():
local_files = _weight_files_from_dir(d, extension)
if not local_files:
raise FileNotFoundError(
f"No local weights found in {model_id} with extension {extension}"
)
return local_files
return [Path(f) for f in local_files]
try:
filenames = weight_hub_files(model_id, revision, extension)
@ -138,33 +180,33 @@ def download_weights(
) -> List[Path]:
"""Download the safetensors files from the hub"""
def download_file(filename, tries=5, backoff: int = 5):
local_file = try_to_load_from_cache(model_id, revision, filename)
def download_file(fname, tries=5, backoff: int = 5):
local_file = try_to_load_from_cache(model_id, revision, fname)
if local_file is not None:
logger.info(f"File {filename} already present in cache.")
logger.info(f"File {fname} already present in cache.")
return Path(local_file)
for i in range(tries):
for idx in range(tries):
try:
logger.info(f"Download file: {filename}")
start_time = time.time()
logger.info(f"Download file: {fname}")
stime = time.time()
local_file = hf_hub_download(
filename=filename,
filename=fname,
repo_id=model_id,
revision=revision,
local_files_only=False,
local_files_only=HF_HUB_OFFLINE,
)
logger.info(
f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - stime))}."
)
return Path(local_file)
except Exception as e:
if i + 1 == tries:
if idx + 1 == tries:
raise e
logger.error(e)
logger.info(f"Retrying in {backoff} seconds")
time.sleep(backoff)
logger.info(f"Retry {i + 1}/{tries - 1}")
logger.info(f"Retry {idx + 1}/{tries - 1}")
# We do this instead of using tqdm because we want to parse the logs with the launcher
start_time = time.time()