diff --git a/clients/python/tests/conftest.py b/clients/python/tests/conftest.py index 48734f0d..17bb73b5 100644 --- a/clients/python/tests/conftest.py +++ b/clients/python/tests/conftest.py @@ -9,6 +9,11 @@ def flan_t5_xxl(): return "google/flan-t5-xxl" +@pytest.fixture +def llama_7b(): + return "meta-llama/Llama-2-7b-chat-hf" + + @pytest.fixture def fake_model(): return "fake/model" @@ -34,6 +39,11 @@ def flan_t5_xxl_url(base_url, flan_t5_xxl): return f"{base_url}/{flan_t5_xxl}" +@pytest.fixture +def llama_7b_url(base_url, llama_7b): + return f"{base_url}/{llama_7b}" + + @pytest.fixture def fake_url(base_url, fake_model): return f"{base_url}/{fake_model}" diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 1e25e1b1..8aed865b 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -5,24 +5,24 @@ from text_generation.errors import NotFoundError, ValidationError from text_generation.types import FinishReason, InputToken -def test_generate(flan_t5_xxl_url, hf_headers): - client = Client(flan_t5_xxl_url, hf_headers) +def test_generate(llama_7b_url, hf_headers): + client = Client(llama_7b_url, hf_headers) response = client.generate("test", max_new_tokens=1, decoder_input_details=True) - assert response.generated_text == "" + assert response.generated_text == "_" assert response.details.finish_reason == FinishReason.Length assert response.details.generated_tokens == 1 assert response.details.seed is None - assert len(response.details.prefill) == 1 - assert response.details.prefill[0] == InputToken(id=0, text="", logprob=None) + assert len(response.details.prefill) == 2 + assert response.details.prefill[0] == InputToken(id=1, text="", logprob=None) assert len(response.details.tokens) == 1 - assert response.details.tokens[0].id == 3 - assert response.details.tokens[0].text == " " + assert response.details.tokens[0].id == 29918 + assert response.details.tokens[0].text == "_" assert not response.details.tokens[0].special -def test_generate_best_of(flan_t5_xxl_url, hf_headers): - client = Client(flan_t5_xxl_url, hf_headers) +def test_generate_best_of(llama_7b_url, hf_headers): + client = Client(llama_7b_url, hf_headers) response = client.generate( "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True ) @@ -39,14 +39,14 @@ def test_generate_not_found(fake_url, hf_headers): client.generate("test") -def test_generate_validation_error(flan_t5_xxl_url, hf_headers): - client = Client(flan_t5_xxl_url, hf_headers) +def test_generate_validation_error(llama_7b_url, hf_headers): + client = Client(llama_7b_url, hf_headers) with pytest.raises(ValidationError): client.generate("test", max_new_tokens=10_000) -def test_generate_stream(flan_t5_xxl_url, hf_headers): - client = Client(flan_t5_xxl_url, hf_headers) +def test_generate_stream(llama_7b_url, hf_headers): + client = Client(llama_7b_url, hf_headers) responses = [ response for response in client.generate_stream("test", max_new_tokens=1) ] @@ -54,7 +54,7 @@ def test_generate_stream(flan_t5_xxl_url, hf_headers): assert len(responses) == 1 response = responses[0] - assert response.generated_text == "" + assert response.generated_text == "_" assert response.details.finish_reason == FinishReason.Length assert response.details.generated_tokens == 1 assert response.details.seed is None @@ -66,34 +66,37 @@ def test_generate_stream_not_found(fake_url, hf_headers): list(client.generate_stream("test")) -def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers): - client = Client(flan_t5_xxl_url, hf_headers) +def test_generate_stream_validation_error(llama_7b_url, hf_headers): + client = Client(llama_7b_url, hf_headers) with pytest.raises(ValidationError): list(client.generate_stream("test", max_new_tokens=10_000)) @pytest.mark.asyncio -async def test_generate_async(flan_t5_xxl_url, hf_headers): - client = AsyncClient(flan_t5_xxl_url, hf_headers) +async def test_generate_async(llama_7b_url, hf_headers): + client = AsyncClient(llama_7b_url, hf_headers) response = await client.generate( "test", max_new_tokens=1, decoder_input_details=True ) - assert response.generated_text == "" + assert response.generated_text == "_" assert response.details.finish_reason == FinishReason.Length assert response.details.generated_tokens == 1 assert response.details.seed is None - assert len(response.details.prefill) == 1 - assert response.details.prefill[0] == InputToken(id=0, text="", logprob=None) + assert len(response.details.prefill) == 2 + assert response.details.prefill[0] == InputToken(id=1, text="", logprob=None) + assert response.details.prefill[1] == InputToken( + id=1243, text="test", logprob=-10.96875 + ) assert len(response.details.tokens) == 1 - assert response.details.tokens[0].id == 3 - assert response.details.tokens[0].text == " " + assert response.details.tokens[0].id == 29918 + assert response.details.tokens[0].text == "_" assert not response.details.tokens[0].special @pytest.mark.asyncio -async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers): - client = AsyncClient(flan_t5_xxl_url, hf_headers) +async def test_generate_async_best_of(llama_7b_url, hf_headers): + client = AsyncClient(llama_7b_url, hf_headers) response = await client.generate( "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True ) @@ -112,15 +115,15 @@ async def test_generate_async_not_found(fake_url, hf_headers): @pytest.mark.asyncio -async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers): - client = AsyncClient(flan_t5_xxl_url, hf_headers) +async def test_generate_async_validation_error(llama_7b_url, hf_headers): + client = AsyncClient(llama_7b_url, hf_headers) with pytest.raises(ValidationError): await client.generate("test", max_new_tokens=10_000) @pytest.mark.asyncio -async def test_generate_stream_async(flan_t5_xxl_url, hf_headers): - client = AsyncClient(flan_t5_xxl_url, hf_headers) +async def test_generate_stream_async(llama_7b_url, hf_headers): + client = AsyncClient(llama_7b_url, hf_headers) responses = [ response async for response in client.generate_stream("test", max_new_tokens=1) ] @@ -128,7 +131,7 @@ async def test_generate_stream_async(flan_t5_xxl_url, hf_headers): assert len(responses) == 1 response = responses[0] - assert response.generated_text == "" + assert response.generated_text == "_" assert response.details.finish_reason == FinishReason.Length assert response.details.generated_tokens == 1 assert response.details.seed is None @@ -143,8 +146,8 @@ async def test_generate_stream_async_not_found(fake_url, hf_headers): @pytest.mark.asyncio -async def test_generate_stream_async_validation_error(flan_t5_xxl_url, hf_headers): - client = AsyncClient(flan_t5_xxl_url, hf_headers) +async def test_generate_stream_async_validation_error(llama_7b_url, hf_headers): + client = AsyncClient(llama_7b_url, hf_headers) with pytest.raises(ValidationError): async for _ in client.generate_stream("test", max_new_tokens=10_000): pass