diff --git a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json
new file mode 100644
index 00000000..a452399e
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json
@@ -0,0 +1,23 @@
+{
+  "choices": [
+    {
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": "Hi, I hope this is the right place for your written question. Please provide the maximum possible length to help me complete the message for you! Based",
+        "role": "assistant"
+      }
+    }
+  ],
+  "created": 1731082056,
+  "id": "",
+  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "object": "chat.completion",
+  "system_fingerprint": "2.4.1-dev0-native",
+  "usage": {
+    "completion_tokens": 30,
+    "prompt_tokens": 57,
+    "total_tokens": 87
+  }
+}
diff --git a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json
new file mode 100644
index 00000000..3e48bc37
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json
@@ -0,0 +1,23 @@
+{
+  "choices": [
+    {
+      "finish_reason": "length",
+      "index": 0,
+      "logprobs": null,
+      "message": {
+        "content": ": Thanks for the awesome slides, they were just what we needed to produce the presentation we needed to deliver for our company's budgeting system",
+        "role": "assistant"
+      }
+    }
+  ],
+  "created": 1731082129,
+  "id": "",
+  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+  "object": "chat.completion",
+  "system_fingerprint": "2.4.1-dev0-native",
+  "usage": {
+    "completion_tokens": 30,
+    "prompt_tokens": 44,
+    "total_tokens": 74
+  }
+}
diff --git a/integration-tests/models/test_continue_final_message.py b/integration-tests/models/test_continue_final_message.py
new file mode 100644
index 00000000..9a8b07ad
--- /dev/null
+++ b/integration-tests/models/test_continue_final_message.py
@@ -0,0 +1,81 @@
+import pytest
+import requests
+
+
+@pytest.fixture(scope="module")
+def llama_continue_final_message_handle(launcher):
+    with launcher(
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        num_shard=1,
+        disable_grammar_support=False,
+        use_flash_attention=False,
+    ) as handle:
+        yield handle
+
+
+@pytest.fixture(scope="module")
+async def llama_continue_final_message(llama_continue_final_message_handle):
+    await llama_continue_final_message_handle.health(300)
+    return llama_continue_final_message_handle.client
+
+
+def test_llama_completion_single_prompt(
+    llama_continue_final_message, response_snapshot
+):
+    response = requests.post(
+        f"{llama_continue_final_message.base_url}/v1/chat/completions",
+        json={
+            "model": "tgi",
+            "messages": [
+                {"role": "system", "content": "system message"},
+                {"role": "user", "content": "user message"},
+                {"role": "assistant", "content": "assistant message"},
+            ],
+            "max_tokens": 30,
+            "stream": False,
+            "seed": 1337,
+            "continue_final_message": False,
+        },
+        headers=llama_continue_final_message.headers,
+        stream=False,
+    )
+    response = response.json()
+    print(response)
+    assert len(response["choices"]) == 1
+    content = response["choices"][0]["message"]["content"]
+    assert (
+        content
+        == "Hi, I hope this is the right place for your written question. Please provide the maximum possible length to help me complete the message for you! Based"
+    )
+    assert response == response_snapshot
+
+
+def test_llama_completion_single_prompt_continue(
+    llama_continue_final_message, response_snapshot
+):
+    response = requests.post(
+        f"{llama_continue_final_message.base_url}/v1/chat/completions",
+        json={
+            "model": "tgi",
+            "messages": [
+                {"role": "system", "content": "system message"},
+                {"role": "user", "content": "user message"},
+                {"role": "assistant", "content": "assistant message"},
+            ],
+            "max_tokens": 30,
+            "stream": False,
+            "seed": 1337,
+            "continue_final_message": True,
+        },
+        headers=llama_continue_final_message.headers,
+        stream=False,
+    )
+    response = response.json()
+    print(response)
+    assert len(response["choices"]) == 1
+    content = response["choices"][0]["message"]["content"]
+    assert (
+        content
+        == ": Thanks for the awesome slides, they were just what we needed to produce the presentation we needed to deliver for our company's budgeting system"
+    )
+    assert response == response_snapshot