@oOraph --------- Signed-off-by: Raphael Glon <oOraph@users.noreply.github.com> Co-authored-by: Raphael Glon <oOraph@users.noreply.github.com>
This commit is contained in:
parent
1b1bfa49b0
commit
8428ed1011
|
@ -556,6 +556,7 @@ mod tests {
|
|||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_p: Some(1.0),
|
||||
max_new_tokens: Some(5),
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
|
@ -570,6 +571,7 @@ mod tests {
|
|||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_p: Some(0.99),
|
||||
max_new_tokens: Some(5),
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
|
@ -584,6 +586,7 @@ mod tests {
|
|||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_p: None,
|
||||
max_new_tokens: Some(5),
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
|
@ -616,6 +619,7 @@ mod tests {
|
|||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_n_tokens: Some(5),
|
||||
max_new_tokens: Some(5),
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
|
@ -630,6 +634,7 @@ mod tests {
|
|||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_n_tokens: Some(4),
|
||||
max_new_tokens: Some(5),
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
|
@ -641,6 +646,7 @@ mod tests {
|
|||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_n_tokens: Some(0),
|
||||
max_new_tokens: Some(5),
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
|
@ -652,6 +658,7 @@ mod tests {
|
|||
inputs: "Hello".to_string(),
|
||||
parameters: GenerateParameters {
|
||||
top_n_tokens: None,
|
||||
max_new_tokens: Some(5),
|
||||
..default_parameters()
|
||||
},
|
||||
})
|
||||
|
|
|
@ -1,5 +1,13 @@
|
|||
import os
|
||||
import requests
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
import huggingface_hub.constants
|
||||
from huggingface_hub import hf_api
|
||||
|
||||
import text_generation_server.utils.hub
|
||||
from text_generation_server.utils.hub import (
|
||||
weight_hub_files,
|
||||
download_weights,
|
||||
|
@ -10,6 +18,52 @@ from text_generation_server.utils.hub import (
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def offline():
|
||||
current_value = text_generation_server.utils.hub.HF_HUB_OFFLINE
|
||||
text_generation_server.utils.hub.HF_HUB_OFFLINE = True
|
||||
yield "offline"
|
||||
text_generation_server.utils.hub.HF_HUB_OFFLINE = current_value
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def fresh_cache():
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
|
||||
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
|
||||
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
|
||||
os.environ['HUGGINGFACE_HUB_CACHE'] = d
|
||||
yield
|
||||
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
|
||||
os.environ['HUGGINGFACE_HUB_CACHE'] = current_value
|
||||
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def prefetched():
|
||||
model_id = "bert-base-uncased"
|
||||
huggingface_hub.snapshot_download(
|
||||
repo_id=model_id,
|
||||
revision="main",
|
||||
local_files_only=False,
|
||||
repo_type="model",
|
||||
allow_patterns=["*.safetensors"]
|
||||
)
|
||||
yield model_id
|
||||
|
||||
|
||||
def test_weight_hub_files_offline_error(offline, fresh_cache):
|
||||
# If the model is not prefetched then it will raise an error
|
||||
with pytest.raises(EntryNotFoundError):
|
||||
weight_hub_files("gpt2")
|
||||
|
||||
|
||||
def test_weight_hub_files_offline_ok(prefetched, offline):
|
||||
# If the model is prefetched then we should be able to get the weight files from local cache
|
||||
filenames = weight_hub_files(prefetched)
|
||||
assert filenames == ['model.safetensors']
|
||||
|
||||
|
||||
def test_weight_hub_files():
|
||||
filenames = weight_hub_files("bigscience/bloom-560m")
|
||||
assert filenames == ["model.safetensors"]
|
||||
|
@ -33,8 +87,11 @@ def test_download_weights():
|
|||
assert files == local_files
|
||||
|
||||
|
||||
def test_weight_files_error():
|
||||
def test_weight_files_revision_error():
|
||||
with pytest.raises(RevisionNotFoundError):
|
||||
weight_files("bigscience/bloom-560m", revision="error")
|
||||
|
||||
|
||||
def test_weight_files_not_cached_error(fresh_cache):
|
||||
with pytest.raises(LocalEntryNotFoundError):
|
||||
weight_files("bert-base-uncased")
|
||||
|
|
|
@ -6,24 +6,29 @@ from loguru import logger
|
|||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from huggingface_hub import file_download, hf_api, HfApi, hf_hub_download
|
||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
||||
from huggingface_hub.utils import (
|
||||
LocalEntryNotFoundError,
|
||||
EntryNotFoundError,
|
||||
RevisionNotFoundError, # Import here to ease try/except in other part of the lib
|
||||
RevisionNotFoundError, # noqa # Import here to ease try/except in other part of the lib
|
||||
)
|
||||
|
||||
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
|
||||
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
|
||||
|
||||
|
||||
def weight_hub_files(
|
||||
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
|
||||
) -> List[str]:
|
||||
"""Get the weights filenames on the hub"""
|
||||
api = HfApi()
|
||||
info = api.model_info(model_id, revision=revision)
|
||||
filenames = [
|
||||
def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]:
|
||||
"""Guess weight files from the cached revision snapshot directory"""
|
||||
d = _get_cached_revision_directory(model_id, revision)
|
||||
if not d:
|
||||
return []
|
||||
filenames = _weight_files_from_dir(d, extension)
|
||||
return filenames
|
||||
|
||||
|
||||
def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]:
|
||||
return [
|
||||
s.rfilename
|
||||
for s in info.siblings
|
||||
if s.rfilename.endswith(extension)
|
||||
|
@ -33,24 +38,26 @@ def weight_hub_files(
|
|||
and "training" not in s.rfilename
|
||||
]
|
||||
|
||||
if not filenames:
|
||||
raise EntryNotFoundError(
|
||||
f"No {extension} weights found for model {model_id} and revision {revision}.",
|
||||
None,
|
||||
)
|
||||
|
||||
def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
|
||||
# os.walk: do not iterate, just scan for depth 1, not recursively
|
||||
# see _weight_hub_files_from_model_info, that's also what is
|
||||
# done there with the len(s.rfilename.split("/")) == 1 condition
|
||||
root, _, files = next(os.walk(str(d)))
|
||||
filenames = [f for f in files
|
||||
if f.endswith(extension)
|
||||
and "arguments" not in f
|
||||
and "args" not in f
|
||||
and "training" not in f]
|
||||
return filenames
|
||||
|
||||
|
||||
def try_to_load_from_cache(
|
||||
model_id: str, revision: Optional[str], filename: str
|
||||
) -> Optional[Path]:
|
||||
"""Try to load a file from the Hugging Face cache"""
|
||||
def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]:
|
||||
if revision is None:
|
||||
revision = "main"
|
||||
|
||||
object_id = model_id.replace("/", "--")
|
||||
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
|
||||
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
|
||||
file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
|
||||
|
||||
if not repo_cache.is_dir():
|
||||
# No cache for this model
|
||||
|
@ -74,8 +81,42 @@ def try_to_load_from_cache(
|
|||
# No cache for this revision and we won't try to return a random revision
|
||||
return None
|
||||
|
||||
return snapshots_dir / revision
|
||||
|
||||
|
||||
def weight_hub_files(
|
||||
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
|
||||
) -> List[str]:
|
||||
"""Get the weights filenames on the hub"""
|
||||
api = HfApi()
|
||||
|
||||
if HF_HUB_OFFLINE:
|
||||
filenames = _cached_weight_files(model_id, revision, extension)
|
||||
else:
|
||||
# Online case, fetch model info from the Hub
|
||||
info = api.model_info(model_id, revision=revision)
|
||||
filenames = _weight_hub_files_from_model_info(info, extension)
|
||||
|
||||
if not filenames:
|
||||
raise EntryNotFoundError(
|
||||
f"No {extension} weights found for model {model_id} and revision {revision}.",
|
||||
None,
|
||||
)
|
||||
|
||||
return filenames
|
||||
|
||||
|
||||
def try_to_load_from_cache(
|
||||
model_id: str, revision: Optional[str], filename: str
|
||||
) -> Optional[Path]:
|
||||
"""Try to load a file from the Hugging Face cache"""
|
||||
|
||||
d = _get_cached_revision_directory(model_id, revision)
|
||||
if not d:
|
||||
return None
|
||||
|
||||
# Check if file exists in cache
|
||||
cached_file = snapshots_dir / revision / filename
|
||||
cached_file = d / filename
|
||||
return cached_file if cached_file.is_file() else None
|
||||
|
||||
|
||||
|
@ -84,13 +125,14 @@ def weight_files(
|
|||
) -> List[Path]:
|
||||
"""Get the local files"""
|
||||
# Local model
|
||||
if Path(model_id).exists() and Path(model_id).is_dir():
|
||||
local_files = list(Path(model_id).glob(f"*{extension}"))
|
||||
d = Path(model_id)
|
||||
if d.exists() and d.is_dir():
|
||||
local_files = _weight_files_from_dir(d, extension)
|
||||
if not local_files:
|
||||
raise FileNotFoundError(
|
||||
f"No local weights found in {model_id} with extension {extension}"
|
||||
)
|
||||
return local_files
|
||||
return [Path(f) for f in local_files]
|
||||
|
||||
try:
|
||||
filenames = weight_hub_files(model_id, revision, extension)
|
||||
|
@ -138,33 +180,33 @@ def download_weights(
|
|||
) -> List[Path]:
|
||||
"""Download the safetensors files from the hub"""
|
||||
|
||||
def download_file(filename, tries=5, backoff: int = 5):
|
||||
local_file = try_to_load_from_cache(model_id, revision, filename)
|
||||
def download_file(fname, tries=5, backoff: int = 5):
|
||||
local_file = try_to_load_from_cache(model_id, revision, fname)
|
||||
if local_file is not None:
|
||||
logger.info(f"File {filename} already present in cache.")
|
||||
logger.info(f"File {fname} already present in cache.")
|
||||
return Path(local_file)
|
||||
|
||||
for i in range(tries):
|
||||
for idx in range(tries):
|
||||
try:
|
||||
logger.info(f"Download file: {filename}")
|
||||
start_time = time.time()
|
||||
logger.info(f"Download file: {fname}")
|
||||
stime = time.time()
|
||||
local_file = hf_hub_download(
|
||||
filename=filename,
|
||||
filename=fname,
|
||||
repo_id=model_id,
|
||||
revision=revision,
|
||||
local_files_only=False,
|
||||
local_files_only=HF_HUB_OFFLINE,
|
||||
)
|
||||
logger.info(
|
||||
f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
|
||||
f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - stime))}."
|
||||
)
|
||||
return Path(local_file)
|
||||
except Exception as e:
|
||||
if i + 1 == tries:
|
||||
if idx + 1 == tries:
|
||||
raise e
|
||||
logger.error(e)
|
||||
logger.info(f"Retrying in {backoff} seconds")
|
||||
time.sleep(backoff)
|
||||
logger.info(f"Retry {i + 1}/{tries - 1}")
|
||||
logger.info(f"Retry {idx + 1}/{tries - 1}")
|
||||
|
||||
# We do this instead of using tqdm because we want to parse the logs with the launcher
|
||||
start_time = time.time()
|
||||
|
|
Loading…
Reference in New Issue