hf_text-generation-inference/server/tests/test_utils.py

72 lines
2.1 KiB
Python
Raw Normal View History

2022-12-08 10:49:33 -07:00
import pytest
2023-01-31 10:53:56 -07:00
from huggingface_hub.utils import RevisionNotFoundError
2022-12-08 10:49:33 -07:00
from text_generation.utils import (
weight_hub_files,
download_weights,
weight_files,
2022-12-12 10:25:22 -07:00
StopSequenceCriteria,
StoppingCriteria,
2022-12-08 10:49:33 -07:00
LocalEntryNotFoundError,
FinishReason,
2022-12-08 10:49:33 -07:00
)
2022-12-12 10:25:22 -07:00
def test_stop_sequence_criteria():
    """StopSequenceCriteria fires only on the exact stop text "/test;".

    A bare prefix ("/", "/test") does not trigger it, and neither does the
    stop sequence followed by trailing text ("/test; ").
    """
    stop_criteria = StopSequenceCriteria("/test;")

    # (input text, whether the criteria is expected to be truthy), checked in order.
    expectations = [
        ("/", False),
        ("/test", False),
        ("/test;", True),
        ("/test; ", False),
    ]
    for text, should_stop in expectations:
        assert bool(stop_criteria(text)) is should_stop
2022-12-12 10:25:22 -07:00
2022-12-16 08:03:39 -07:00
def test_stopping_criteria():
    """StoppingCriteria reports a stop-sequence finish once "/test;" is completed.

    The text is fed in two calls ("/test", then ";"). Only the second call
    reports a stop, which suggests the criteria keeps text across calls.
    """
    stopping = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)

    keep_going = (False, None)
    stop_sequence_hit = (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)

    # Only part of the stop sequence has been seen: generation continues.
    assert stopping(65827, "/test") == keep_going
    # The ";" completes "/test;", so generation stops with the stop-sequence reason.
    assert stopping(30, ";") == stop_sequence_hit
2022-12-12 10:25:22 -07:00
2022-12-16 08:03:39 -07:00
def test_stopping_criteria_eos():
    """Passing the token id given as StoppingCriteria's first argument ends with an EOS reason."""
    # NOTE(review): the first constructor argument appears to be the EOS token
    # id, since feeding that id yields FINISH_REASON_EOS_TOKEN below.
    eos_id = 0
    stopping = StoppingCriteria(eos_id, [StopSequenceCriteria("/test;")], max_new_tokens=5)

    # Any other token keeps generation going.
    assert stopping(1, "") == (False, None)
    # The EOS token stops generation.
    assert stopping(eos_id, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)
2022-12-12 10:25:22 -07:00
def test_stopping_criteria_max():
    """Generation stops with a length reason on the max_new_tokens-th (5th) call."""
    stopping = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)

    # The first four non-EOS tokens stay under the limit.
    for _ in range(4):
        assert stopping(1, "") == (False, None)

    # The fifth token reaches max_new_tokens and ends generation.
    assert stopping(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
2022-12-12 10:25:22 -07:00
2022-12-08 10:49:33 -07:00
def test_weight_hub_files():
    """weight_hub_files lists the single safetensors file for bigscience/bloom-560m."""
    # NOTE(review): queries the Hugging Face Hub, so this needs network access.
    assert weight_hub_files("bigscience/bloom-560m") == ["model.safetensors"]
def test_weight_hub_files_llm():
    """weight_hub_files lists all 72 sharded safetensors files for bigscience/bloom, in order."""
    shard_count = 72
    # Shard names are 1-based and zero-padded to five digits, e.g. model_00001-of-00072.safetensors.
    expected = [
        f"model_{shard:05d}-of-{shard_count:05d}.safetensors"
        for shard in range(1, shard_count + 1)
    ]
    assert weight_hub_files("bigscience/bloom") == expected
def test_weight_hub_files_empty():
    """An extension that matches no files in the repo yields an empty list."""
    matches = weight_hub_files("bigscience/bloom", extension=".errors")
    assert matches == []
def test_download_weights():
    """download_weights returns the same paths that weight_files reports afterwards."""
    model_id = "bigscience/bloom-560m"
    # Download first, then resolve the local files for comparison.
    downloaded = download_weights(model_id)
    assert downloaded == weight_files(model_id)
def test_weight_files_error():
    """weight_files raises for an unknown revision and for a model with no local weights."""
    # A non-existent revision is reported as RevisionNotFoundError.
    with pytest.raises(RevisionNotFoundError):
        weight_files("bigscience/bloom-560m", revision="error")

    # A model whose weights were never downloaded raises LocalEntryNotFoundError.
    # NOTE(review): this assumes bert-base-uncased is absent from the local cache.
    with pytest.raises(LocalEntryNotFoundError):
        weight_files("bert-base-uncased")