From 406885638bbc2fc664b2ea7f98974e053d8d487a Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 11 Jun 2024 09:29:08 +0000 Subject: [PATCH] skip exl2 tests on rocm --- integration-tests/models/test_flash_awq.py | 2 +- integration-tests/models/test_flash_awq_sharded.py | 11 +++++------ integration-tests/models/test_flash_llama_exl2.py | 6 ++++++ 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/integration-tests/models/test_flash_awq.py b/integration-tests/models/test_flash_awq.py index d3a17ab5..ff4c761c 100644 --- a/integration-tests/models/test_flash_awq.py +++ b/integration-tests/models/test_flash_awq.py @@ -9,7 +9,7 @@ def flash_llama_awq_handle(launcher): # On ROCm, for awq checkpoints, we need to use gptq kernel that supports ROCm. quantize = "gptq" elif SYSTEM == "xpu": - pytest.skiptest("AWQ is not supported on xpu") + pytest.skip("AWQ is not supported on xpu") else: quantize = "awq" diff --git a/integration-tests/models/test_flash_awq_sharded.py b/integration-tests/models/test_flash_awq_sharded.py index 4486fb6f..a6d5db1e 100644 --- a/integration-tests/models/test_flash_awq_sharded.py +++ b/integration-tests/models/test_flash_awq_sharded.py @@ -4,12 +4,11 @@ from testing_utils import SYSTEM, is_flaky_async @pytest.fixture(scope="module") +@require_backend_async("cuda", "rocm") def flash_llama_awq_handle_sharded(launcher): if SYSTEM == "rocm": # On ROCm, for awq checkpoints, we need to use gptq kernel that supports ROCm. quantize = "gptq" - elif SYSTEM == "xpu": - pytest.skiptest("AWQ is not supported on xpu") else: quantize = "awq" @@ -22,6 +21,7 @@ def flash_llama_awq_handle_sharded(launcher): @pytest.fixture(scope="module") +@require_backend_async("cuda", "rocm") async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded): await flash_llama_awq_handle_sharded.health(300) return flash_llama_awq_handle_sharded.client @@ -29,6 +29,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded): @is_flaky_async(max_attempts=5) @pytest.mark.asyncio +@require_backend_async("cuda", "rocm") async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot): response = await flash_llama_awq_sharded.generate( "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True @@ -47,14 +48,12 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho assert response == response_snapshot +@require_backend_async("cuda") @pytest.mark.asyncio async def test_flash_llama_awq_load_sharded( flash_llama_awq_sharded, generate_load, response_snapshot ): - if SYSTEM == "rocm": - pytest.skiptest( - "This test relies on ExllamaV2 on ROCm systems, which is highly non-determinstic (flaky)" - ) + # TODO: This test is highly non-deterministic on ROCm. responses = await generate_load( flash_llama_awq_sharded, "What is Deep Learning?", max_new_tokens=10, n=4 diff --git a/integration-tests/models/test_flash_llama_exl2.py b/integration-tests/models/test_flash_llama_exl2.py index 18319f60..4109dc36 100644 --- a/integration-tests/models/test_flash_llama_exl2.py +++ b/integration-tests/models/test_flash_llama_exl2.py @@ -1,7 +1,9 @@ import pytest +from testing_utils import require_backend_async @pytest.fixture(scope="module") +@require_backend_async("cuda") def flash_llama_exl2_handle(launcher): with launcher( "turboderp/Llama-3-8B-Instruct-exl2", @@ -16,11 +18,13 @@ def flash_llama_exl2_handle(launcher): @pytest.fixture(scope="module") +@require_backend_async("cuda") async def flash_llama_exl2(flash_llama_exl2_handle): await flash_llama_exl2_handle.health(300) return flash_llama_exl2_handle.client +@require_backend_async("cuda") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot): @@ -32,6 +36,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh assert response == ignore_logprob_response_snapshot +@require_backend_async("cuda") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_exl2_all_params( @@ -58,6 +63,7 @@ async def test_flash_llama_exl2_all_params( assert response == ignore_logprob_response_snapshot +@require_backend_async("cuda") @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_exl2_load(