skip exl2 tests on rocm
parent 5a4b798f98
commit 406885638b

@@ -9,7 +9,7 @@ def flash_llama_awq_handle(launcher):
         # On ROCm, for awq checkpoints, we need to use gptq kernel that supports ROCm.
         quantize = "gptq"
     elif SYSTEM == "xpu":
-        pytest.skiptest("AWQ is not supported on xpu")
+        pytest.skip("AWQ is not supported on xpu")
     else:
         quantize = "awq"

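Note on the hunk above: pytest exposes pytest.skip, not pytest.skiptest, so the old call would have raised AttributeError as soon as that branch ran on xpu; the corrected call raises pytest's Skipped exception and marks every test that uses the fixture as skipped. A minimal sketch of the same pattern (the fixture name and the hard-coded system value are hypothetical, for illustration only):

import pytest


@pytest.fixture
def quantize_mode():
    # Hypothetical stand-in for the SYSTEM detection done in testing_utils.
    system = "xpu"
    if system == "xpu":
        # pytest.skip() raises a Skipped exception; tests requesting this fixture are skipped.
        pytest.skip("AWQ is not supported on xpu")
    return "awq"
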
@@ -4,12 +4,11 @@ from testing_utils import SYSTEM, is_flaky_async


 @pytest.fixture(scope="module")
+@require_backend_async("cuda", "rocm")
 def flash_llama_awq_handle_sharded(launcher):
     if SYSTEM == "rocm":
         # On ROCm, for awq checkpoints, we need to use gptq kernel that supports ROCm.
         quantize = "gptq"
-    elif SYSTEM == "xpu":
-        pytest.skiptest("AWQ is not supported on xpu")
     else:
         quantize = "awq"

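require_backend_async comes from testing_utils, which is not part of this diff. As a rough mental model only — the names, backend detection, and details below are assumptions, not the actual helper — such a gate can be written as a decorator that skips whenever the detected system is not in the allowed set:

import functools
import inspect

import pytest

# Assumed: testing_utils resolves this once (e.g. from CUDA / ROCm / XPU availability checks).
SYSTEM = "cuda"


def require_backend_async(*allowed):
    def decorator(func):
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                if SYSTEM not in allowed:
                    pytest.skip(f"Skipping on {SYSTEM}: requires one of {allowed}")
                return await func(*args, **kwargs)

            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            if SYSTEM not in allowed:
                pytest.skip(f"Skipping on {SYSTEM}: requires one of {allowed}")
            return func(*args, **kwargs)

        return sync_wrapper

    return decorator

With the decorators applied in this commit, the exl2 tests and the awq load-sharded test only run on CUDA, while the remaining awq sharded tests still run on both CUDA and ROCm.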
@@ -22,6 +21,7 @@ def flash_llama_awq_handle_sharded(launcher):


 @pytest.fixture(scope="module")
+@require_backend_async("cuda", "rocm")
 async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
     await flash_llama_awq_handle_sharded.health(300)
     return flash_llama_awq_handle_sharded.client
@@ -29,6 +29,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
 @is_flaky_async(max_attempts=5)
 @pytest.mark.asyncio
+@require_backend_async("cuda", "rocm")
 async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
     response = await flash_llama_awq_sharded.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
     )
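is_flaky_async(max_attempts=5) is also defined in testing_utils and not shown in this diff. A retry decorator in that spirit could look like the sketch below (an assumed shape, not the real implementation):

import functools


def is_flaky_async(max_attempts: int = 5):
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            last_error = None
            # Re-run the test body up to max_attempts times before reporting a failure.
            for _ in range(max_attempts):
                try:
                    return await func(*args, **kwargs)
                except AssertionError as err:
                    last_error = err
            raise last_error

        return wrapper

    return decorator
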
@@ -47,14 +48,12 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
     assert response == response_snapshot


+@require_backend_async("cuda")
 @pytest.mark.asyncio
 async def test_flash_llama_awq_load_sharded(
     flash_llama_awq_sharded, generate_load, response_snapshot
 ):
-    if SYSTEM == "rocm":
-        pytest.skiptest(
-            "This test relies on ExllamaV2 on ROCm systems, which is highly non-determinstic (flaky)"
-        )
+    # TODO: This test is highly non-deterministic on ROCm.

     responses = await generate_load(
         flash_llama_awq_sharded, "What is Deep Learning?", max_new_tokens=10, n=4
@@ -1,7 +1,9 @@
 import pytest
+from testing_utils import require_backend_async


 @pytest.fixture(scope="module")
+@require_backend_async("cuda")
 def flash_llama_exl2_handle(launcher):
     with launcher(
         "turboderp/Llama-3-8B-Instruct-exl2",
@@ -16,11 +18,13 @@ def flash_llama_exl2_handle(launcher):


 @pytest.fixture(scope="module")
+@require_backend_async("cuda")
 async def flash_llama_exl2(flash_llama_exl2_handle):
     await flash_llama_exl2_handle.health(300)
     return flash_llama_exl2_handle.client


+@require_backend_async("cuda")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
@@ -32,6 +36,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh
     assert response == ignore_logprob_response_snapshot


+@require_backend_async("cuda")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2_all_params(
@@ -58,6 +63,7 @@ async def test_flash_llama_exl2_all_params(
     assert response == ignore_logprob_response_snapshot


+@require_backend_async("cuda")
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_exl2_load(