From 00cc73b7b72f7afe485e571736a7c8d2ce1e0c06 Mon Sep 17 00:00:00 2001
From: Felix Marty <9808326+fxmarty@users.noreply.github.com>
Date: Mon, 1 Jul 2024 12:20:29 +0000
Subject: [PATCH] fix post merge

---
 integration-tests/clean_cache_and_download.py  |  2 +-
 .../models/test_flash_llama_gptq_marlin.py     | 68 -------------------
 .../models/test_flash_llama_marlin.py          |  2 +-
 3 files changed, 2 insertions(+), 70 deletions(-)
 delete mode 100644 integration-tests/models/test_flash_llama_gptq_marlin.py

diff --git a/integration-tests/clean_cache_and_download.py b/integration-tests/clean_cache_and_download.py
index 0be55de0..1f4e942b 100644
--- a/integration-tests/clean_cache_and_download.py
+++ b/integration-tests/clean_cache_and_download.py
@@ -13,7 +13,7 @@ REQUIRED_MODELS = {
     "openai-community/gpt2": "main",
     "turboderp/Llama-3-8B-Instruct-exl2": "2.5bpw",
     "huggingface/llama-7b-gptq": "main",
-    "neuralmagic/llama-2-7b-chat-marlin": "main",
+    "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit": "main",
     "huggingface/llama-7b": "main",
     "FasterDecoding/medusa-vicuna-7b-v1.3": "refs/pr/1",
     "mistralai/Mistral-7B-Instruct-v0.1": "main",
diff --git a/integration-tests/models/test_flash_llama_gptq_marlin.py b/integration-tests/models/test_flash_llama_gptq_marlin.py
deleted file mode 100644
index ef56371d..00000000
--- a/integration-tests/models/test_flash_llama_gptq_marlin.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import pytest
-
-
-@pytest.fixture(scope="module")
-def flash_llama_gptq_marlin_handle(launcher):
-    with launcher(
-        "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin"
-    ) as handle:
-        yield handle
-
-
-@pytest.fixture(scope="module")
-async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
-    await flash_llama_gptq_marlin_handle.health()
-    return flash_llama_gptq_marlin_handle.client
-
-
-@pytest.mark.release
-@pytest.mark.asyncio
-@pytest.mark.private
-async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
-    response = await flash_llama_gptq_marlin.generate(
-        "Test request", max_new_tokens=10, decoder_input_details=True
-    )
-
-    assert response.details.generated_tokens == 10
-    assert response == response_snapshot
-
-
-@pytest.mark.release
-@pytest.mark.asyncio
-@pytest.mark.private
-async def test_flash_llama_gptq_marlin_all_params(
-    flash_llama_gptq_marlin, response_snapshot
-):
-    response = await flash_llama_gptq_marlin.generate(
-        "Test request",
-        max_new_tokens=10,
-        repetition_penalty=1.2,
-        return_full_text=True,
-        temperature=0.5,
-        top_p=0.9,
-        top_k=10,
-        truncate=5,
-        typical_p=0.9,
-        watermark=True,
-        decoder_input_details=True,
-        seed=0,
-    )
-
-    assert response.details.generated_tokens == 10
-    assert response == response_snapshot
-
-
-@pytest.mark.release
-@pytest.mark.asyncio
-@pytest.mark.private
-async def test_flash_llama_gptq_marlin_load(
-    flash_llama_gptq_marlin, generate_load, response_snapshot
-):
-    responses = await generate_load(
-        flash_llama_gptq_marlin, "Test request", max_new_tokens=10, n=4
-    )
-
-    assert len(responses) == 4
-    assert all([r.generated_text == responses[0].generated_text for r in responses])
-
-    assert responses == response_snapshot
diff --git a/integration-tests/models/test_flash_llama_marlin.py b/integration-tests/models/test_flash_llama_marlin.py
index 13c51b05..6af9feef 100644
--- a/integration-tests/models/test_flash_llama_marlin.py
+++ b/integration-tests/models/test_flash_llama_marlin.py
@@ -6,7 +6,7 @@ from testing_utils import SYSTEM
 @pytest.fixture(scope="module")
 def flash_llama_marlin_handle(launcher):
     with launcher(
-        "neuralmagic/llama-2-7b-chat-marlin", num_shard=2, quantize="marlin"
+        "astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="marlin"
     ) as handle:
         yield handle