@oOraph --------- Signed-off-by: Raphael Glon <oOraph@users.noreply.github.com> Co-authored-by: Raphael Glon <oOraph@users.noreply.github.com>
This commit is contained in:
parent
1b1bfa49b0
commit
8428ed1011
|
@ -556,6 +556,7 @@ mod tests {
|
||||||
inputs: "Hello".to_string(),
|
inputs: "Hello".to_string(),
|
||||||
parameters: GenerateParameters {
|
parameters: GenerateParameters {
|
||||||
top_p: Some(1.0),
|
top_p: Some(1.0),
|
||||||
|
max_new_tokens: Some(5),
|
||||||
..default_parameters()
|
..default_parameters()
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
@ -570,6 +571,7 @@ mod tests {
|
||||||
inputs: "Hello".to_string(),
|
inputs: "Hello".to_string(),
|
||||||
parameters: GenerateParameters {
|
parameters: GenerateParameters {
|
||||||
top_p: Some(0.99),
|
top_p: Some(0.99),
|
||||||
|
max_new_tokens: Some(5),
|
||||||
..default_parameters()
|
..default_parameters()
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
@ -584,6 +586,7 @@ mod tests {
|
||||||
inputs: "Hello".to_string(),
|
inputs: "Hello".to_string(),
|
||||||
parameters: GenerateParameters {
|
parameters: GenerateParameters {
|
||||||
top_p: None,
|
top_p: None,
|
||||||
|
max_new_tokens: Some(5),
|
||||||
..default_parameters()
|
..default_parameters()
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
@ -616,6 +619,7 @@ mod tests {
|
||||||
inputs: "Hello".to_string(),
|
inputs: "Hello".to_string(),
|
||||||
parameters: GenerateParameters {
|
parameters: GenerateParameters {
|
||||||
top_n_tokens: Some(5),
|
top_n_tokens: Some(5),
|
||||||
|
max_new_tokens: Some(5),
|
||||||
..default_parameters()
|
..default_parameters()
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
@ -630,6 +634,7 @@ mod tests {
|
||||||
inputs: "Hello".to_string(),
|
inputs: "Hello".to_string(),
|
||||||
parameters: GenerateParameters {
|
parameters: GenerateParameters {
|
||||||
top_n_tokens: Some(4),
|
top_n_tokens: Some(4),
|
||||||
|
max_new_tokens: Some(5),
|
||||||
..default_parameters()
|
..default_parameters()
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
@ -641,6 +646,7 @@ mod tests {
|
||||||
inputs: "Hello".to_string(),
|
inputs: "Hello".to_string(),
|
||||||
parameters: GenerateParameters {
|
parameters: GenerateParameters {
|
||||||
top_n_tokens: Some(0),
|
top_n_tokens: Some(0),
|
||||||
|
max_new_tokens: Some(5),
|
||||||
..default_parameters()
|
..default_parameters()
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
@ -652,6 +658,7 @@ mod tests {
|
||||||
inputs: "Hello".to_string(),
|
inputs: "Hello".to_string(),
|
||||||
parameters: GenerateParameters {
|
parameters: GenerateParameters {
|
||||||
top_n_tokens: None,
|
top_n_tokens: None,
|
||||||
|
max_new_tokens: Some(5),
|
||||||
..default_parameters()
|
..default_parameters()
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,5 +1,13 @@
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
import tempfile
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
import huggingface_hub.constants
|
||||||
|
from huggingface_hub import hf_api
|
||||||
|
|
||||||
|
import text_generation_server.utils.hub
|
||||||
from text_generation_server.utils.hub import (
|
from text_generation_server.utils.hub import (
|
||||||
weight_hub_files,
|
weight_hub_files,
|
||||||
download_weights,
|
download_weights,
|
||||||
|
@ -10,6 +18,52 @@ from text_generation_server.utils.hub import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def offline():
|
||||||
|
current_value = text_generation_server.utils.hub.HF_HUB_OFFLINE
|
||||||
|
text_generation_server.utils.hub.HF_HUB_OFFLINE = True
|
||||||
|
yield "offline"
|
||||||
|
text_generation_server.utils.hub.HF_HUB_OFFLINE = current_value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def fresh_cache():
|
||||||
|
with tempfile.TemporaryDirectory() as d:
|
||||||
|
current_value = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
|
||||||
|
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
|
||||||
|
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = d
|
||||||
|
os.environ['HUGGINGFACE_HUB_CACHE'] = d
|
||||||
|
yield
|
||||||
|
huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = current_value
|
||||||
|
os.environ['HUGGINGFACE_HUB_CACHE'] = current_value
|
||||||
|
text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE = current_value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def prefetched():
|
||||||
|
model_id = "bert-base-uncased"
|
||||||
|
huggingface_hub.snapshot_download(
|
||||||
|
repo_id=model_id,
|
||||||
|
revision="main",
|
||||||
|
local_files_only=False,
|
||||||
|
repo_type="model",
|
||||||
|
allow_patterns=["*.safetensors"]
|
||||||
|
)
|
||||||
|
yield model_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_weight_hub_files_offline_error(offline, fresh_cache):
|
||||||
|
# If the model is not prefetched then it will raise an error
|
||||||
|
with pytest.raises(EntryNotFoundError):
|
||||||
|
weight_hub_files("gpt2")
|
||||||
|
|
||||||
|
|
||||||
|
def test_weight_hub_files_offline_ok(prefetched, offline):
|
||||||
|
# If the model is prefetched then we should be able to get the weight files from local cache
|
||||||
|
filenames = weight_hub_files(prefetched)
|
||||||
|
assert filenames == ['model.safetensors']
|
||||||
|
|
||||||
|
|
||||||
def test_weight_hub_files():
|
def test_weight_hub_files():
|
||||||
filenames = weight_hub_files("bigscience/bloom-560m")
|
filenames = weight_hub_files("bigscience/bloom-560m")
|
||||||
assert filenames == ["model.safetensors"]
|
assert filenames == ["model.safetensors"]
|
||||||
|
@ -33,8 +87,11 @@ def test_download_weights():
|
||||||
assert files == local_files
|
assert files == local_files
|
||||||
|
|
||||||
|
|
||||||
def test_weight_files_error():
|
def test_weight_files_revision_error():
|
||||||
with pytest.raises(RevisionNotFoundError):
|
with pytest.raises(RevisionNotFoundError):
|
||||||
weight_files("bigscience/bloom-560m", revision="error")
|
weight_files("bigscience/bloom-560m", revision="error")
|
||||||
|
|
||||||
|
|
||||||
|
def test_weight_files_not_cached_error(fresh_cache):
|
||||||
with pytest.raises(LocalEntryNotFoundError):
|
with pytest.raises(LocalEntryNotFoundError):
|
||||||
weight_files("bert-base-uncased")
|
weight_files("bert-base-uncased")
|
||||||
|
|
|
@ -6,24 +6,29 @@ from loguru import logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from huggingface_hub import HfApi, hf_hub_download
|
from huggingface_hub import file_download, hf_api, HfApi, hf_hub_download
|
||||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
||||||
from huggingface_hub.utils import (
|
from huggingface_hub.utils import (
|
||||||
LocalEntryNotFoundError,
|
LocalEntryNotFoundError,
|
||||||
EntryNotFoundError,
|
EntryNotFoundError,
|
||||||
RevisionNotFoundError, # Import here to ease try/except in other part of the lib
|
RevisionNotFoundError, # noqa # Import here to ease try/except in other part of the lib
|
||||||
)
|
)
|
||||||
|
|
||||||
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
|
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
|
||||||
|
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
|
||||||
|
|
||||||
|
|
||||||
def weight_hub_files(
|
def _cached_weight_files(model_id: str, revision: Optional[str], extension: str) -> List[str]:
|
||||||
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
|
"""Guess weight files from the cached revision snapshot directory"""
|
||||||
) -> List[str]:
|
d = _get_cached_revision_directory(model_id, revision)
|
||||||
"""Get the weights filenames on the hub"""
|
if not d:
|
||||||
api = HfApi()
|
return []
|
||||||
info = api.model_info(model_id, revision=revision)
|
filenames = _weight_files_from_dir(d, extension)
|
||||||
filenames = [
|
return filenames
|
||||||
|
|
||||||
|
|
||||||
|
def _weight_hub_files_from_model_info(info: hf_api.ModelInfo, extension: str) -> List[str]:
|
||||||
|
return [
|
||||||
s.rfilename
|
s.rfilename
|
||||||
for s in info.siblings
|
for s in info.siblings
|
||||||
if s.rfilename.endswith(extension)
|
if s.rfilename.endswith(extension)
|
||||||
|
@ -33,24 +38,26 @@ def weight_hub_files(
|
||||||
and "training" not in s.rfilename
|
and "training" not in s.rfilename
|
||||||
]
|
]
|
||||||
|
|
||||||
if not filenames:
|
|
||||||
raise EntryNotFoundError(
|
|
||||||
f"No {extension} weights found for model {model_id} and revision {revision}.",
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
|
||||||
|
# os.walk: do not iterate, just scan for depth 1, not recursively
|
||||||
|
# see _weight_hub_files_from_model_info, that's also what is
|
||||||
|
# done there with the len(s.rfilename.split("/")) == 1 condition
|
||||||
|
root, _, files = next(os.walk(str(d)))
|
||||||
|
filenames = [f for f in files
|
||||||
|
if f.endswith(extension)
|
||||||
|
and "arguments" not in f
|
||||||
|
and "args" not in f
|
||||||
|
and "training" not in f]
|
||||||
return filenames
|
return filenames
|
||||||
|
|
||||||
|
|
||||||
def try_to_load_from_cache(
|
def _get_cached_revision_directory(model_id: str, revision: Optional[str]) -> Optional[Path]:
|
||||||
model_id: str, revision: Optional[str], filename: str
|
|
||||||
) -> Optional[Path]:
|
|
||||||
"""Try to load a file from the Hugging Face cache"""
|
|
||||||
if revision is None:
|
if revision is None:
|
||||||
revision = "main"
|
revision = "main"
|
||||||
|
|
||||||
object_id = model_id.replace("/", "--")
|
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / Path(
|
||||||
repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"
|
file_download.repo_folder_name(repo_id=model_id, repo_type="model"))
|
||||||
|
|
||||||
if not repo_cache.is_dir():
|
if not repo_cache.is_dir():
|
||||||
# No cache for this model
|
# No cache for this model
|
||||||
|
@ -74,8 +81,42 @@ def try_to_load_from_cache(
|
||||||
# No cache for this revision and we won't try to return a random revision
|
# No cache for this revision and we won't try to return a random revision
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
return snapshots_dir / revision
|
||||||
|
|
||||||
|
|
||||||
|
def weight_hub_files(
|
||||||
|
model_id: str, revision: Optional[str] = None, extension: str = ".safetensors"
|
||||||
|
) -> List[str]:
|
||||||
|
"""Get the weights filenames on the hub"""
|
||||||
|
api = HfApi()
|
||||||
|
|
||||||
|
if HF_HUB_OFFLINE:
|
||||||
|
filenames = _cached_weight_files(model_id, revision, extension)
|
||||||
|
else:
|
||||||
|
# Online case, fetch model info from the Hub
|
||||||
|
info = api.model_info(model_id, revision=revision)
|
||||||
|
filenames = _weight_hub_files_from_model_info(info, extension)
|
||||||
|
|
||||||
|
if not filenames:
|
||||||
|
raise EntryNotFoundError(
|
||||||
|
f"No {extension} weights found for model {model_id} and revision {revision}.",
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return filenames
|
||||||
|
|
||||||
|
|
||||||
|
def try_to_load_from_cache(
|
||||||
|
model_id: str, revision: Optional[str], filename: str
|
||||||
|
) -> Optional[Path]:
|
||||||
|
"""Try to load a file from the Hugging Face cache"""
|
||||||
|
|
||||||
|
d = _get_cached_revision_directory(model_id, revision)
|
||||||
|
if not d:
|
||||||
|
return None
|
||||||
|
|
||||||
# Check if file exists in cache
|
# Check if file exists in cache
|
||||||
cached_file = snapshots_dir / revision / filename
|
cached_file = d / filename
|
||||||
return cached_file if cached_file.is_file() else None
|
return cached_file if cached_file.is_file() else None
|
||||||
|
|
||||||
|
|
||||||
|
@ -84,13 +125,14 @@ def weight_files(
|
||||||
) -> List[Path]:
|
) -> List[Path]:
|
||||||
"""Get the local files"""
|
"""Get the local files"""
|
||||||
# Local model
|
# Local model
|
||||||
if Path(model_id).exists() and Path(model_id).is_dir():
|
d = Path(model_id)
|
||||||
local_files = list(Path(model_id).glob(f"*{extension}"))
|
if d.exists() and d.is_dir():
|
||||||
|
local_files = _weight_files_from_dir(d, extension)
|
||||||
if not local_files:
|
if not local_files:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
f"No local weights found in {model_id} with extension {extension}"
|
f"No local weights found in {model_id} with extension {extension}"
|
||||||
)
|
)
|
||||||
return local_files
|
return [Path(f) for f in local_files]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
filenames = weight_hub_files(model_id, revision, extension)
|
filenames = weight_hub_files(model_id, revision, extension)
|
||||||
|
@ -138,33 +180,33 @@ def download_weights(
|
||||||
) -> List[Path]:
|
) -> List[Path]:
|
||||||
"""Download the safetensors files from the hub"""
|
"""Download the safetensors files from the hub"""
|
||||||
|
|
||||||
def download_file(filename, tries=5, backoff: int = 5):
|
def download_file(fname, tries=5, backoff: int = 5):
|
||||||
local_file = try_to_load_from_cache(model_id, revision, filename)
|
local_file = try_to_load_from_cache(model_id, revision, fname)
|
||||||
if local_file is not None:
|
if local_file is not None:
|
||||||
logger.info(f"File {filename} already present in cache.")
|
logger.info(f"File {fname} already present in cache.")
|
||||||
return Path(local_file)
|
return Path(local_file)
|
||||||
|
|
||||||
for i in range(tries):
|
for idx in range(tries):
|
||||||
try:
|
try:
|
||||||
logger.info(f"Download file: {filename}")
|
logger.info(f"Download file: {fname}")
|
||||||
start_time = time.time()
|
stime = time.time()
|
||||||
local_file = hf_hub_download(
|
local_file = hf_hub_download(
|
||||||
filename=filename,
|
filename=fname,
|
||||||
repo_id=model_id,
|
repo_id=model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
local_files_only=False,
|
local_files_only=HF_HUB_OFFLINE,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
|
f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - stime))}."
|
||||||
)
|
)
|
||||||
return Path(local_file)
|
return Path(local_file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if i + 1 == tries:
|
if idx + 1 == tries:
|
||||||
raise e
|
raise e
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
logger.info(f"Retrying in {backoff} seconds")
|
logger.info(f"Retrying in {backoff} seconds")
|
||||||
time.sleep(backoff)
|
time.sleep(backoff)
|
||||||
logger.info(f"Retry {i + 1}/{tries - 1}")
|
logger.info(f"Retry {idx + 1}/{tries - 1}")
|
||||||
|
|
||||||
# We do this instead of using tqdm because we want to parse the logs with the launcher
|
# We do this instead of using tqdm because we want to parse the logs with the launcher
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
Loading…
Reference in New Issue