diff --git a/integration-tests/models/test_bloom_560m.py b/integration-tests/models/test_bloom_560m.py
index 17b4cd00..f0174730 100644
--- a/integration-tests/models/test_bloom_560m.py
+++ b/integration-tests/models/test_bloom_560m.py
@@ -1,10 +1,10 @@
 import pytest
 
-from testing_utils import require_backend_async
+from testing_utils import require_backend_async, require_backend
 
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda")
+@require_backend("cuda")
 def bloom_560_handle(launcher):
     with launcher("bigscience/bloom-560m") as handle:
         yield handle
diff --git a/integration-tests/models/test_flash_awq_sharded.py b/integration-tests/models/test_flash_awq_sharded.py
index b5dc21a9..d7104ff1 100644
--- a/integration-tests/models/test_flash_awq_sharded.py
+++ b/integration-tests/models/test_flash_awq_sharded.py
@@ -1,10 +1,10 @@
 import pytest
 
-from testing_utils import SYSTEM, is_flaky_async, require_backend_async
+from testing_utils import SYSTEM, is_flaky_async, require_backend_async, require_backend
 
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda", "rocm")
+@require_backend("cuda", "rocm")
 def flash_llama_awq_handle_sharded(launcher):
     if SYSTEM == "rocm":
         # On ROCm, for awq checkpoints, we need to use gptq kernel that supports ROCm.
diff --git a/integration-tests/models/test_flash_gemma.py b/integration-tests/models/test_flash_gemma.py
index a3a9a910..f6888efe 100644
--- a/integration-tests/models/test_flash_gemma.py
+++ b/integration-tests/models/test_flash_gemma.py
@@ -1,12 +1,12 @@
 import pytest
 
-from testing_utils import require_backend_async
+from testing_utils import require_backend_async, require_backend
 
 # These tests do not pass on ROCm, that does not support head_dim > 128 (2b model is 256).
 
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda", "xpu")
+@require_backend("cuda", "xpu")
 def flash_gemma_handle(launcher):
     with launcher("google/gemma-2b", num_shard=1) as handle:
         yield handle
diff --git a/integration-tests/models/test_flash_gemma_gptq.py b/integration-tests/models/test_flash_gemma_gptq.py
index 95c7d551..14a8075d 100644
--- a/integration-tests/models/test_flash_gemma_gptq.py
+++ b/integration-tests/models/test_flash_gemma_gptq.py
@@ -1,10 +1,11 @@
 import pytest
 
-from testing_utils import require_backend_async
+from testing_utils import require_backend_async, require_backend
+# These tests do not pass on ROCm, that does not support head_dim > 128 (2b model is 256).
 
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda", "xpu")
+@require_backend("cuda", "xpu")
 def flash_gemma_gptq_handle(launcher):
     with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle:
         yield handle
diff --git a/integration-tests/models/test_flash_llama_exl2.py b/integration-tests/models/test_flash_llama_exl2.py
index 4109dc36..2db40257 100644
--- a/integration-tests/models/test_flash_llama_exl2.py
+++ b/integration-tests/models/test_flash_llama_exl2.py
@@ -1,9 +1,9 @@
 import pytest
 
-from testing_utils import require_backend_async
+from testing_utils import require_backend_async, require_backend
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda")
+@require_backend("cuda")
 def flash_llama_exl2_handle(launcher):
     with launcher(
         "turboderp/Llama-3-8B-Instruct-exl2",
diff --git a/integration-tests/models/test_flash_pali_gemma.py b/integration-tests/models/test_flash_pali_gemma.py
index 23dc5b0e..00c12821 100644
--- a/integration-tests/models/test_flash_pali_gemma.py
+++ b/integration-tests/models/test_flash_pali_gemma.py
@@ -3,13 +3,13 @@ import requests
 import io
 import base64
 
-from testing_utils import require_backend_async
+from testing_utils import require_backend_async, require_backend
 
 # These tests do not pass on ROCm, that does not support head_dim > 128 (2b model is 256).
 
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda", "xpu")
+@require_backend("cuda", "xpu")
 def flash_pali_gemma_handle(launcher):
     with launcher(
         "google/paligemma-3b-pt-224",
diff --git a/integration-tests/models/test_flash_phi.py b/integration-tests/models/test_flash_phi.py
index 7b482eb8..9d0abfb3 100644
--- a/integration-tests/models/test_flash_phi.py
+++ b/integration-tests/models/test_flash_phi.py
@@ -1,12 +1,12 @@
 import pytest
 
-from testing_utils import require_backend_async
+from testing_utils import require_backend_async, require_backend
 
 # These tests do not pass on ROCm, with different generations.
 
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda")
+@require_backend("cuda")
 def flash_phi_handle(launcher):
     with launcher("microsoft/phi-2", num_shard=1) as handle:
         yield handle
diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py
index e0d0f58e..0f939705 100644
--- a/integration-tests/models/test_mamba.py
+++ b/integration-tests/models/test_mamba.py
@@ -1,10 +1,10 @@
 import pytest
 
-from testing_utils import require_backend_async
+from testing_utils import require_backend_async, require_backend
 
 
 @pytest.fixture(scope="module")
-@require_backend_async("cuda")
+@require_backend("cuda")
 def fused_kernel_mamba_handle(launcher):
     with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
         yield handle
diff --git a/integration-tests/models/testing_utils.py b/integration-tests/models/testing_utils.py
index de76463c..606a24c0 100644
--- a/integration-tests/models/testing_utils.py
+++ b/integration-tests/models/testing_utils.py
@@ -51,6 +51,20 @@ def is_flaky_async(
     return decorator
 
 
+def require_backend(*args):
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*wrapper_args, **wrapper_kwargs):
+            if SYSTEM not in args:
+                pytest.skip(
+                    f"Skipping as this test requires the backend {args} to be run, but current system is SYSTEM={SYSTEM}."
+                )
+            return func(*wrapper_args, **wrapper_kwargs)
+
+        return wrapper
+
+    return decorator
+
 def require_backend_async(*args):
     def decorator(func):
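
Reviewer note: `require_backend` is the synchronous counterpart of `require_backend_async`. The fixtures changed above are plain generator functions (`def ... yield`), and wrapping them in an `async def` wrapper turns the fixture into a coroutine function that pytest cannot drive as a regular yield fixture, which is presumably why the fixture decorators are being swapped while `require_backend_async` remains for the `async def` test bodies. Below is a minimal usage sketch following the pattern in this diff; the `bloom_560` helper fixture, its `health`/`client` calls, and the test body are illustrative rather than part of the patch, and running them assumes pytest is configured to execute async tests and fixtures (e.g. via pytest-asyncio).

import pytest

from testing_utils import require_backend, require_backend_async


@pytest.fixture(scope="module")
@require_backend("cuda")  # sync wrapper: skips (or runs) the fixture at setup time
def bloom_560_handle(launcher):
    with launcher("bigscience/bloom-560m") as handle:
        yield handle


@pytest.fixture(scope="module")
async def bloom_560(bloom_560_handle):
    # Illustrative follow-up fixture: wait for the server, then expose its client.
    await bloom_560_handle.health(240)
    return bloom_560_handle.client


@require_backend_async("cuda")  # async wrapper: the decorated test body is a coroutine
async def test_bloom_560m_greedy(bloom_560):
    # Hypothetical test, shown only to contrast the two decorators.
    response = await bloom_560.generate("Test request", max_new_tokens=10)
    assert response.details.generated_tokens == 10

Either decorator skips by calling `pytest.skip` from inside its wrapper, so the wrapper has to match how pytest invokes the decorated object: a plain call for fixtures, an awaited call for async tests.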