110 lines
3.2 KiB
Python
110 lines
3.2 KiB
Python
|
import pytest
|
||
|
import requests
|
||
|
import json
|
||
|
from aiohttp import ClientSession
|
||
|
|
||
|
from text_generation.types import (
|
||
|
Completion,
|
||
|
)
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher):
    """Launch the TinyLlama chat model once per test module and yield its handle.

    The handle is produced by the ``launcher`` fixture's context manager and is
    yielded from inside the ``with`` block, so whatever cleanup that context
    manager performs (presumably stopping the server) runs after the module's
    last test.
    """
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    with launcher(model_id) as handle:
        yield handle
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module")
async def flash_llama_completion(flash_llama_completion_handle):
    """Return the client of the launched server once it reports healthy.

    Awaits ``health(300)`` on the launcher handle (300 is presumably a timeout
    in seconds — confirm against the launcher implementation) before exposing
    ``handle.client``, which the tests use for ``base_url`` and ``headers``.
    """
    handle = flash_llama_completion_handle
    await handle.health(300)
    client = handle.client
    return client
|
||
|
|
||
|
|
||
|
# NOTE: since `v1/completions` is a deprecated interface/endpoint, we do not provide a convenience
|
||
|
# method for it. Instead, we use the `requests` library to make the HTTP request directly.
|
||
|
|
||
|
|
||
|
def test_flash_llama_completion_single_prompt(
    flash_llama_completion, response_snapshot
):
    """Complete a single string prompt via ``v1/completions`` and snapshot it.

    The endpoint is called directly with ``requests`` because the client has
    no convenience method for it. A single string prompt must yield exactly
    one choice.
    """
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": "Say this is a test",
            "max_tokens": 5,
            "seed": 0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    # Check the status first so a server-side failure is reported with its
    # body, instead of surfacing later as a JSONDecodeError or KeyError.
    assert response.status_code == 200, response.text

    body = response.json()
    assert len(body["choices"]) == 1

    assert body == response_snapshot
|
||
|
|
||
|
|
||
|
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
    """Complete a batch of prompts in one non-streaming request and snapshot it.

    Sending a list as ``prompt`` must produce one choice per prompt, and the
    set of choice indexes must be exactly ``0 .. len(prompts) - 1``.
    """
    prompts = ["Say", "this", "is", "a"]
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": prompts,
            "max_tokens": 10,
            "seed": 0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    # Report a server error directly rather than as a downstream parse error.
    assert response.status_code == 200, response.text

    body = response.json()
    choices = body["choices"]
    assert len(choices) == len(prompts)

    # Choice order is not asserted; only that every prompt index appears once.
    assert sorted(choice["index"] for choice in choices) == list(range(len(prompts)))

    assert body == response_snapshot
|
||
|
|
||
|
|
||
|
def _drain_sse_events(buffer):
    """Parse every complete server-sent event at the front of ``buffer``.

    Events are separated by a blank line (``b"\\n\\n"``). Bytes after the last
    separator belong to an event that has not fully arrived yet; they are
    returned untouched so the caller can prepend them to the next read.

    Args:
        buffer: Raw response-body bytes accumulated so far.

    Returns:
        ``(payloads, remainder)``: ``payloads`` is the list of JSON-decoded
        ``data:`` payloads of each complete event, in arrival order, and
        ``remainder`` is the unconsumed tail of ``buffer``.
    """
    *raw_events, remainder = buffer.split(b"\n\n")
    payloads = []
    for raw_event in raw_events:
        # Decoding only complete events avoids splitting a UTF-8 sequence.
        event = raw_event.decode()
        # Per SSE, only lines starting with "data:" carry the payload; strip
        # just that prefix so "data:" inside generated text is left intact.
        data_lines = [
            line[len("data:"):].strip()
            for line in event.splitlines()
            if line.startswith("data:")
        ]
        payload = "\n".join(data_lines)
        # Skip events without data and the OpenAI-style end-of-stream marker.
        if not payload or payload == "[DONE]":
            continue
        payloads.append(json.loads(payload))
    return payloads, remainder


async def test_flash_llama_completion_many_prompts_stream(
    flash_llama_completion, response_snapshot
):
    """Stream completions for several prompts and snapshot the parsed chunks.

    The response body is read as server-sent events. Every chunk must carry
    ``choices`` whose first entry's ``index`` refers to one of the prompts.
    """
    prompts = [
        "What color is the sky?",
        "Is water wet?",
        "What is the capital of France?",
        "def mai",
    ]
    request = {
        "model": "tgi",
        "prompt": prompts,
        "max_tokens": 10,
        "seed": 0,
        "stream": True,
    }

    url = f"{flash_llama_completion.base_url}/v1/completions"

    payloads = []
    async with ClientSession(headers=flash_llama_completion.headers) as session:
        async with session.post(url, json=request) as response:
            # Fail on the real status before trying to parse an error body.
            assert response.status == 200

            buffer = b""
            # iter_any() yields whatever bytes are available, which can end in
            # the middle of an event, so keep the incomplete tail between reads.
            async for data in response.content.iter_any():
                buffer += data
                complete, buffer = _drain_sse_events(buffer)
                payloads.extend(complete)
            # Flush a final event that was not followed by a blank line.
            complete, _ = _drain_sse_events(buffer + b"\n\n")
            payloads.extend(complete)

    chunks = []
    for payload in payloads:
        assert "choices" in payload
        assert payload["choices"], payload
        # Indexes are 0-based, one per prompt.
        assert 0 <= payload["choices"][0]["index"] < len(prompts)
        chunks.append(Completion(**payload))

    assert chunks == response_snapshot
|