# hf_text-generation-inference/server/tests/utils/test_hub.py

import os
import requests
import tempfile
2023-02-14 05:02:16 -07:00
import pytest
import huggingface_hub.constants
from huggingface_hub import hf_api
import text_generation_server.utils.hub
2023-03-07 10:52:22 -07:00
from text_generation_server.utils.hub import (
2023-02-14 05:02:16 -07:00
weight_hub_files,
download_weights,
weight_files,
EntryNotFoundError,
LocalEntryNotFoundError,
RevisionNotFoundError,
)
@pytest.fixture()
def offline():
    """Switch the hub utilities module into offline mode for one test.

    Sets ``text_generation_server.utils.hub.HF_HUB_OFFLINE`` to ``True``,
    yields the string ``"offline"``, and puts the previous flag value back
    once the test is done.
    """
    hub_module = text_generation_server.utils.hub
    previous_flag = hub_module.HF_HUB_OFFLINE
    hub_module.HF_HUB_OFFLINE = True
    yield "offline"
    # Teardown: undo the override so other tests see the original setting.
    hub_module.HF_HUB_OFFLINE = previous_flag
@pytest.fixture()
def fresh_cache():
    """Point every hub cache setting at an empty temporary directory.

    Three places are redirected for the duration of the test:
    ``huggingface_hub.constants.HUGGINGFACE_HUB_CACHE``,
    ``text_generation_server.utils.hub.HUGGINGFACE_HUB_CACHE`` and the
    ``HUGGINGFACE_HUB_CACHE`` environment variable.

    Each location is saved and restored independently, so the environment
    variable is removed again if it was not set before the test instead of
    being left pointing at the library default.

    Yields:
        None. The fixture is used only for its side effects.
    """
    env_key = "HUGGINGFACE_HUB_CACHE"
    hub_module = text_generation_server.utils.hub

    saved_constant = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE
    # Fall back to the library constant if the module does not define the
    # attribute yet; this matches what the previous teardown wrote there.
    saved_module_value = getattr(hub_module, env_key, saved_constant)
    # None means "variable was not set"; it must be removed on teardown.
    saved_env = os.environ.get(env_key)

    with tempfile.TemporaryDirectory() as d:
        huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = d
        hub_module.HUGGINGFACE_HUB_CACHE = d
        os.environ[env_key] = d
        try:
            yield
        finally:
            # Restore even if the generator is closed early, so the
            # temporary path never leaks into subsequent tests.
            huggingface_hub.constants.HUGGINGFACE_HUB_CACHE = saved_constant
            hub_module.HUGGINGFACE_HUB_CACHE = saved_module_value
            if saved_env is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = saved_env
@pytest.fixture()
def prefetched():
    """Download the safetensors files of a small model and yield its id.

    Calls ``huggingface_hub.snapshot_download`` for the ``main`` revision of
    ``bert-base-uncased`` restricted to ``*.safetensors`` files, then yields
    the model id to the test.
    """
    repo = "bert-base-uncased"
    download_options = {
        "repo_id": repo,
        "revision": "main",
        "local_files_only": False,
        "repo_type": "model",
        "allow_patterns": ["*.safetensors"],
    }
    huggingface_hub.snapshot_download(**download_options)
    yield repo
def test_weight_hub_files_offline_error(offline, fresh_cache):
    """Listing hub weight files offline fails when the model is not cached."""
    # The cache directory is empty and offline mode is on, so the lookup
    # for a model that was never prefetched is expected to raise.
    expected_failure = pytest.raises(EntryNotFoundError)
    with expected_failure:
        weight_hub_files("gpt2")
def test_weight_hub_files_offline_ok(prefetched, offline):
    """Listing hub weight files offline works for a prefetched model."""
    # ``prefetched`` yields the id of a model whose safetensors files were
    # downloaded beforehand, so the listing should come from the local cache.
    expected = ['model.safetensors']
    assert weight_hub_files(prefetched) == expected
2023-02-14 05:02:16 -07:00
def test_weight_hub_files():
    """A single-file model lists exactly one safetensors weight file."""
    expected = ["model.safetensors"]
    assert weight_hub_files("bigscience/bloom-560m") == expected
def test_weight_hub_files_llm():
    """A sharded model lists all 72 numbered safetensors shards in order."""
    shard_total = 72
    expected = [
        f"model_{shard:05d}-of-00072.safetensors"
        for shard in range(1, shard_total + 1)
    ]
    listed = weight_hub_files("bigscience/bloom")
    assert listed == expected
def test_weight_hub_files_empty():
    """Requesting an extension with no matching files raises EntryNotFoundError."""
    expected_failure = pytest.raises(EntryNotFoundError)
    with expected_failure:
        weight_hub_files("bigscience/bloom", extension=".errors")
def test_download_weights():
    """Downloaded weight paths match what ``weight_files`` finds locally."""
    model_id = "bigscience/bloom-560m"
    # List the remote files, download them, then compare with the local view.
    remote_names = weight_hub_files(model_id)
    downloaded = download_weights(remote_names, model_id)
    assert downloaded == weight_files(model_id)
def test_weight_files_revision_error():
    """Asking for a revision named "error" raises RevisionNotFoundError."""
    expected_failure = pytest.raises(RevisionNotFoundError)
    with expected_failure:
        weight_files("bigscience/bloom-560m", revision="error")
def test_weight_files_not_cached_error(fresh_cache):
    """Local weight lookup fails when the cache directory is empty."""
    # ``fresh_cache`` redirects the hub cache to a new temporary directory,
    # so nothing for this model is available locally.
    expected_failure = pytest.raises(LocalEntryNotFoundError)
    with expected_failure:
        weight_files("bert-base-uncased")