fix post merge

Felix Marty 2024-07-01 12:20:29 +00:00
parent 9fd395fae4
commit 00cc73b7b7
3 changed files with 2 additions and 70 deletions


@@ -13,7 +13,7 @@ REQUIRED_MODELS = {
     "openai-community/gpt2": "main",
     "turboderp/Llama-3-8B-Instruct-exl2": "2.5bpw",
     "huggingface/llama-7b-gptq": "main",
-    "neuralmagic/llama-2-7b-chat-marlin": "main",
+    "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit": "main",
     "huggingface/llama-7b": "main",
     "FasterDecoding/medusa-vicuna-7b-v1.3": "refs/pr/1",
     "mistralai/Mistral-7B-Instruct-v0.1": "main",


@@ -1,68 +0,0 @@
-import pytest
-
-
-@pytest.fixture(scope="module")
-def flash_llama_gptq_marlin_handle(launcher):
-    with launcher(
-        "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin"
-    ) as handle:
-        yield handle
-
-
-@pytest.fixture(scope="module")
-async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
-    await flash_llama_gptq_marlin_handle.health()
-    return flash_llama_gptq_marlin_handle.client
-
-
-@pytest.mark.release
-@pytest.mark.asyncio
-@pytest.mark.private
-async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
-    response = await flash_llama_gptq_marlin.generate(
-        "Test request", max_new_tokens=10, decoder_input_details=True
-    )
-
-    assert response.details.generated_tokens == 10
-    assert response == response_snapshot
-
-
-@pytest.mark.release
-@pytest.mark.asyncio
-@pytest.mark.private
-async def test_flash_llama_gptq_marlin_all_params(
-    flash_llama_gptq_marlin, response_snapshot
-):
-    response = await flash_llama_gptq_marlin.generate(
-        "Test request",
-        max_new_tokens=10,
-        repetition_penalty=1.2,
-        return_full_text=True,
-        temperature=0.5,
-        top_p=0.9,
-        top_k=10,
-        truncate=5,
-        typical_p=0.9,
-        watermark=True,
-        decoder_input_details=True,
-        seed=0,
-    )
-
-    assert response.details.generated_tokens == 10
-    assert response == response_snapshot
-
-
-@pytest.mark.release
-@pytest.mark.asyncio
-@pytest.mark.private
-async def test_flash_llama_gptq_marlin_load(
-    flash_llama_gptq_marlin, generate_load, response_snapshot
-):
-    responses = await generate_load(
-        flash_llama_gptq_marlin, "Test request", max_new_tokens=10, n=4
-    )
-
-    assert len(responses) == 4
-    assert all([r.generated_text == responses[0].generated_text for r in responses])
-
-    assert responses == response_snapshot


@@ -6,7 +6,7 @@ from testing_utils import SYSTEM
 
 @pytest.fixture(scope="module")
 def flash_llama_marlin_handle(launcher):
     with launcher(
-        "neuralmagic/llama-2-7b-chat-marlin", num_shard=2, quantize="marlin"
+        "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin"
     ) as handle:
         yield handle