skip exl2 tests on rocm

fxmarty 2024-06-11 09:29:08 +00:00 committed by Nicolas Patry
parent 5a4b798f98
commit 406885638b
3 changed files with 12 additions and 7 deletions

View File

@@ -9,7 +9,7 @@ def flash_llama_awq_handle(launcher):
         # On ROCm, for awq checkpoints, we need to use gptq kernel that supports ROCm.
         quantize = "gptq"
     elif SYSTEM == "xpu":
-        pytest.skiptest("AWQ is not supported on xpu")
+        pytest.skip("AWQ is not supported on xpu")
     else:
         quantize = "awq"

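The only change in this first file is replacing pytest.skiptest(...) with pytest.skip(...): pytest has no skiptest function, so on XPU the old call would raise an AttributeError at fixture setup instead of skipping the AWQ tests. Below is a minimal, self-contained sketch of the pattern the fixture relies on; SYSTEM and the fixture name are stand-ins for the suite's own helpers, not the repository's code.

```python
import pytest

SYSTEM = "xpu"  # stand-in for testing_utils.SYSTEM, e.g. "cuda", "rocm" or "xpu"


@pytest.fixture(scope="module")
def quantize():
    if SYSTEM == "rocm":
        # AWQ checkpoints are routed through the ROCm-capable GPTQ kernel.
        return "gptq"
    elif SYSTEM == "xpu":
        # pytest.skip() inside a fixture skips every test that requests it;
        # pytest.skiptest() does not exist.
        pytest.skip("AWQ is not supported on xpu")
    return "awq"


def test_quantize_mode(quantize):
    assert quantize in ("awq", "gptq")
```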
View File

@@ -4,12 +4,11 @@ from testing_utils import SYSTEM, is_flaky_async
 @pytest.fixture(scope="module")
+@require_backend_async("cuda", "rocm")
 def flash_llama_awq_handle_sharded(launcher):
     if SYSTEM == "rocm":
         # On ROCm, for awq checkpoints, we need to use gptq kernel that supports ROCm.
         quantize = "gptq"
-    elif SYSTEM == "xpu":
-        pytest.skiptest("AWQ is not supported on xpu")
     else:
         quantize = "awq"
@@ -22,6 +21,7 @@ def flash_llama_awq_handle_sharded(launcher):
 @pytest.fixture(scope="module")
+@require_backend_async("cuda", "rocm")
 async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
     await flash_llama_awq_handle_sharded.health(300)
     return flash_llama_awq_handle_sharded.client
@@ -29,6 +29,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
 @is_flaky_async(max_attempts=5)
 @pytest.mark.asyncio
+@require_backend_async("cuda", "rocm")
 async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
     response = await flash_llama_awq_sharded.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -47,14 +48,12 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
     assert response == response_snapshot
+@require_backend_async("cuda")
 @pytest.mark.asyncio
 async def test_flash_llama_awq_load_sharded(
     flash_llama_awq_sharded, generate_load, response_snapshot
 ):
-    if SYSTEM == "rocm":
-        pytest.skiptest(
-            "This test relies on ExllamaV2 on ROCm systems, which is highly non-determinstic (flaky)"
-        )
+    # TODO: This test is highly non-deterministic on ROCm.
     responses = await generate_load(
         flash_llama_awq_sharded, "What is Deep Learning?", max_new_tokens=10, n=4

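The sharded AWQ tests above also use is_flaky_async(max_attempts=5), another helper imported from testing_utils that is not part of this diff. A plausible sketch of such a retry decorator, for illustration only; the real helper may differ.

```python
import functools


def is_flaky_async(max_attempts: int = 5):
    """Hypothetical sketch: re-run a non-deterministic async test a few times
    before letting the last failure propagate."""

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            last_error = None
            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except AssertionError as err:
                    last_error = err
                    print(f"Flaky test failed on attempt {attempt}/{max_attempts}: {err}")
            raise last_error

        return wrapper

    return decorator
```

Because functools.wraps copies the wrapped function's attributes, marks such as pytest.mark.asyncio applied below the retry decorator are preserved on the wrapper.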
View File

@@ -1,7 +1,9 @@
 import pytest
+from testing_utils import require_backend_async
 @pytest.fixture(scope="module")
+@require_backend_async("cuda")
 def flash_llama_exl2_handle(launcher):
     with launcher(
         "turboderp/Llama-3-8B-Instruct-exl2",
@@ -16,11 +18,13 @@ def flash_llama_exl2_handle(launcher):
 @pytest.fixture(scope="module")
+@require_backend_async("cuda")
 async def flash_llama_exl2(flash_llama_exl2_handle):
     await flash_llama_exl2_handle.health(300)
     return flash_llama_exl2_handle.client
+@require_backend_async("cuda")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
@@ -32,6 +36,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
     assert response == ignore_logprob_response_snapshot
+@require_backend_async("cuda")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2_all_params(
@@ -58,6 +63,7 @@ async def test_flash_llama_exl2_all_params(
     assert response == ignore_logprob_response_snapshot
+@require_backend_async("cuda")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2_load(
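Every exl2 fixture and test in this last file is now gated with @require_backend_async("cuda"), replacing in-test SYSTEM checks like the one removed from the AWQ load test above. The decorator itself lives in testing_utils and is not shown in the diff; below is a minimal sketch of what such a helper could look like, assuming SYSTEM is a backend string and that the decorator has to work on both plain fixtures and async tests. The names and behaviour here are assumptions, not the repository's implementation.

```python
import asyncio
import functools

import pytest

SYSTEM = "rocm"  # stand-in for testing_utils.SYSTEM, e.g. "cuda", "rocm" or "xpu"


def require_backend_async(*backends):
    """Hypothetical sketch: skip the decorated test or fixture unless the
    current SYSTEM is one of the allowed backends."""

    def decorator(func):
        reason = f"requires one of {backends}, running on {SYSTEM}"
        if asyncio.iscoroutinefunction(func):

            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                if SYSTEM not in backends:
                    pytest.skip(reason)
                return await func(*args, **kwargs)

        else:

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if SYSTEM not in backends:
                    pytest.skip(reason)
                return func(*args, **kwargs)

        return wrapper

    return decorator
```

Since the skip happens inside the wrapper, applying the decorator to the module-scoped flash_llama_exl2_handle fixture skips every test that depends on it, which turns the whole exl2 suite off on ROCm and XPU in one place.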