import json

import pytest
import requests
from aiohttp import ClientSession

from text_generation.types import (
    Completion,
)


@pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher):
    with launcher(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_completion(flash_llama_completion_handle):
    # wait up to 300 seconds for the launched server to report healthy
    await flash_llama_completion_handle.health(300)
    return flash_llama_completion_handle.client


# NOTE: since `v1/completions` is a deprecated interface/endpoint, we do not provide
# a convenience method for it. Instead, we use the `requests` library to make the
# HTTP request directly.
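#
# For reference, a minimal round-trip against the endpoint looks roughly like the
# sketch below (illustrative only; `base_url` and `headers` stand in for the values
# the fixtures above provide, and the response line is a truncated OpenAI-style
# `text_completion` object rather than an exact server transcript):
#
#     response = requests.post(
#         f"{base_url}/v1/completions",
#         json={"model": "tgi", "prompt": "Hello", "max_tokens": 5},
#         headers=headers,
#     )
#     # -> {"object": "text_completion", "choices": [{"index": 0, "text": "..."}], ...}

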
@pytest.mark.release
def test_flash_llama_completion_single_prompt(
    flash_llama_completion, response_snapshot
):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": "What is Deep Learning?",
            "max_tokens": 10,
            "temperature": 0.0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()
    assert len(response["choices"]) == 1
    assert (
        response["choices"][0]["text"]
        == " A Beginner’s Guide\nDeep learning is a subset"
    )
    assert response == response_snapshot


@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": [
                "What is Deep Learning?",
                "Is water wet?",
                "What is the capital of France?",
                "def mai",
            ],
            "max_tokens": 10,
            "seed": 0,
            "temperature": 0.0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()
    assert len(response["choices"]) == 4

    # normalize ordering: sort the choices by prompt index before comparing
    all_indexes = [(choice["index"], choice["text"]) for choice in response["choices"]]
    all_indexes.sort()
    all_indices, all_strings = zip(*all_indexes)
    assert list(all_indices) == [0, 1, 2, 3]
    assert list(all_strings) == [
        " A Beginner’s Guide\nDeep learning is a subset",
        " This is a question that has puzzled many people for",
        " Paris\nWhat is the capital of France?\nThe",
        'usculas_minusculas(s):\n """\n',
    ]

    assert response == response_snapshot


@pytest.mark.release
async def test_flash_llama_completion_many_prompts_stream(
    flash_llama_completion, response_snapshot
):
    request = {
        "model": "tgi",
        "prompt": [
            "What is Deep Learning?",
            "Is water wet?",
            "What is the capital of France?",
            "def mai",
        ],
        "max_tokens": 10,
        "seed": 0,
        "temperature": 0.0,
        "stream": True,
    }

    url = f"{flash_llama_completion.base_url}/v1/completions"

    chunks = []
    strings = [""] * 4

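    # Each streamed event uses server-sent-events framing, roughly
    # `data: {...}\n\n` (illustrative; a single network read from `iter_any()`
    # may carry several events at once). The loop below splits on the blank-line
    # separator, strips the "data:" prefix, drops the terminating "[DONE]"
    # marker, and JSON-decodes each remaining payload.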
    async with ClientSession(headers=flash_llama_completion.headers) as session:
        async with session.post(url, json=request) as response:
            # iterate over the stream
            async for chunk in response.content.iter_any():
                # split the raw read into individual SSE events
                chunk = chunk.decode().split("\n\n")
                # strip the "data:" prefix if present
                chunk = [c.replace("data:", "") for c in chunk]
                # remove empty strings
                chunk = [c for c in chunk if c]
                # drop the completion-terminating "[DONE]" marker
                chunk = [c for c in chunk if c != " [DONE]"]
                # parse json
                chunk = [json.loads(c) for c in chunk]

                for c in chunk:
                    chunks.append(Completion(**c))
                    assert "choices" in c
                    index = c["choices"][0]["index"]
                    # four prompts were sent, so valid indices are 0..3
                    assert 0 <= index < 4
                    strings[index] += c["choices"][0]["text"]

            assert response.status == 200

    assert list(strings) == [
        " A Beginner’s Guide\nDeep learning is a subset",
        " This is a question that has puzzled many people for",
        " Paris\nWhat is the capital of France?\nThe",
        'usculas_minusculas(s):\n """\n',
    ]
    assert chunks == response_snapshot
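

# A note on running these tests: in this repository's integration-test setup they
# are typically selected through pytest with the `release` marker used above, e.g.
# (the path is an assumption about the repo layout, not part of this file):
#
#     pytest integration-tests -m release -k completion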