fix: bump clients test base url to llama (#1751)

This PR bumps the client tests from `google/flan-t5-xxl` to
`meta-llama/Llama-2-7b-chat-hf` to resolve issues when calling the
endpoint and `google/flan-t5-xxl` is not available

run with
```bash
make python-client-tests

clients/python/tests/test_client.py ..............     [ 43%]
clients/python/tests/test_errors.py ..........         [ 75%]
clients/python/tests/test_inference_api.py ......      [ 93%]
clients/python/tests/test_types.py ..                  [100%]
```

**note `google/flan-t5-xxl` function is currently unused but still
included in the `conftest.py`
This commit is contained in:
drbh 2024-04-16 16:56:47 -04:00 committed by GitHub
parent 00f365353e
commit e4d31a40db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 45 additions and 32 deletions

View File

@ -9,6 +9,11 @@ def flan_t5_xxl():
return "google/flan-t5-xxl" return "google/flan-t5-xxl"
@pytest.fixture
def llama_7b():
return "meta-llama/Llama-2-7b-chat-hf"
@pytest.fixture @pytest.fixture
def fake_model(): def fake_model():
return "fake/model" return "fake/model"
@ -34,6 +39,11 @@ def flan_t5_xxl_url(base_url, flan_t5_xxl):
return f"{base_url}/{flan_t5_xxl}" return f"{base_url}/{flan_t5_xxl}"
@pytest.fixture
def llama_7b_url(base_url, llama_7b):
return f"{base_url}/{llama_7b}"
@pytest.fixture @pytest.fixture
def fake_url(base_url, fake_model): def fake_url(base_url, fake_model):
return f"{base_url}/{fake_model}" return f"{base_url}/{fake_model}"

View File

@ -5,24 +5,24 @@ from text_generation.errors import NotFoundError, ValidationError
from text_generation.types import FinishReason, InputToken from text_generation.types import FinishReason, InputToken
def test_generate(flan_t5_xxl_url, hf_headers): def test_generate(llama_7b_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers) client = Client(llama_7b_url, hf_headers)
response = client.generate("test", max_new_tokens=1, decoder_input_details=True) response = client.generate("test", max_new_tokens=1, decoder_input_details=True)
assert response.generated_text == "" assert response.generated_text == "_"
assert response.details.finish_reason == FinishReason.Length assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1 assert response.details.generated_tokens == 1
assert response.details.seed is None assert response.details.seed is None
assert len(response.details.prefill) == 1 assert len(response.details.prefill) == 2
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None) assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
assert len(response.details.tokens) == 1 assert len(response.details.tokens) == 1
assert response.details.tokens[0].id == 3 assert response.details.tokens[0].id == 29918
assert response.details.tokens[0].text == " " assert response.details.tokens[0].text == "_"
assert not response.details.tokens[0].special assert not response.details.tokens[0].special
def test_generate_best_of(flan_t5_xxl_url, hf_headers): def test_generate_best_of(llama_7b_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers) client = Client(llama_7b_url, hf_headers)
response = client.generate( response = client.generate(
"test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
) )
@ -39,14 +39,14 @@ def test_generate_not_found(fake_url, hf_headers):
client.generate("test") client.generate("test")
def test_generate_validation_error(flan_t5_xxl_url, hf_headers): def test_generate_validation_error(llama_7b_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers) client = Client(llama_7b_url, hf_headers)
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
client.generate("test", max_new_tokens=10_000) client.generate("test", max_new_tokens=10_000)
def test_generate_stream(flan_t5_xxl_url, hf_headers): def test_generate_stream(llama_7b_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers) client = Client(llama_7b_url, hf_headers)
responses = [ responses = [
response for response in client.generate_stream("test", max_new_tokens=1) response for response in client.generate_stream("test", max_new_tokens=1)
] ]
@ -54,7 +54,7 @@ def test_generate_stream(flan_t5_xxl_url, hf_headers):
assert len(responses) == 1 assert len(responses) == 1
response = responses[0] response = responses[0]
assert response.generated_text == "" assert response.generated_text == "_"
assert response.details.finish_reason == FinishReason.Length assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1 assert response.details.generated_tokens == 1
assert response.details.seed is None assert response.details.seed is None
@ -66,34 +66,37 @@ def test_generate_stream_not_found(fake_url, hf_headers):
list(client.generate_stream("test")) list(client.generate_stream("test"))
def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers): def test_generate_stream_validation_error(llama_7b_url, hf_headers):
client = Client(flan_t5_xxl_url, hf_headers) client = Client(llama_7b_url, hf_headers)
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
list(client.generate_stream("test", max_new_tokens=10_000)) list(client.generate_stream("test", max_new_tokens=10_000))
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_async(flan_t5_xxl_url, hf_headers): async def test_generate_async(llama_7b_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers) client = AsyncClient(llama_7b_url, hf_headers)
response = await client.generate( response = await client.generate(
"test", max_new_tokens=1, decoder_input_details=True "test", max_new_tokens=1, decoder_input_details=True
) )
assert response.generated_text == "" assert response.generated_text == "_"
assert response.details.finish_reason == FinishReason.Length assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1 assert response.details.generated_tokens == 1
assert response.details.seed is None assert response.details.seed is None
assert len(response.details.prefill) == 1 assert len(response.details.prefill) == 2
assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None) assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
assert response.details.prefill[1] == InputToken(
id=1243, text="test", logprob=-10.96875
)
assert len(response.details.tokens) == 1 assert len(response.details.tokens) == 1
assert response.details.tokens[0].id == 3 assert response.details.tokens[0].id == 29918
assert response.details.tokens[0].text == " " assert response.details.tokens[0].text == "_"
assert not response.details.tokens[0].special assert not response.details.tokens[0].special
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers): async def test_generate_async_best_of(llama_7b_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers) client = AsyncClient(llama_7b_url, hf_headers)
response = await client.generate( response = await client.generate(
"test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
) )
@ -112,15 +115,15 @@ async def test_generate_async_not_found(fake_url, hf_headers):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers): async def test_generate_async_validation_error(llama_7b_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers) client = AsyncClient(llama_7b_url, hf_headers)
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
await client.generate("test", max_new_tokens=10_000) await client.generate("test", max_new_tokens=10_000)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_stream_async(flan_t5_xxl_url, hf_headers): async def test_generate_stream_async(llama_7b_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers) client = AsyncClient(llama_7b_url, hf_headers)
responses = [ responses = [
response async for response in client.generate_stream("test", max_new_tokens=1) response async for response in client.generate_stream("test", max_new_tokens=1)
] ]
@ -128,7 +131,7 @@ async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
assert len(responses) == 1 assert len(responses) == 1
response = responses[0] response = responses[0]
assert response.generated_text == "" assert response.generated_text == "_"
assert response.details.finish_reason == FinishReason.Length assert response.details.finish_reason == FinishReason.Length
assert response.details.generated_tokens == 1 assert response.details.generated_tokens == 1
assert response.details.seed is None assert response.details.seed is None
@ -143,8 +146,8 @@ async def test_generate_stream_async_not_found(fake_url, hf_headers):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers): async def test_generate_stream_async_validation_error(llama_7b_url, hf_headers):
client = AsyncClient(flan_t5_xxl_url, hf_headers) client = AsyncClient(llama_7b_url, hf_headers)
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
async for _ in client.generate_stream("test", max_new_tokens=10_000): async for _ in client.generate_stream("test", max_new_tokens=10_000):
pass pass