fix decorators

This commit is contained in:
fxmarty 2024-06-14 07:45:58 +00:00 committed by Nicolas Patry
parent 4616c62914
commit 3de8f3647b
No known key found for this signature in database
GPG Key ID: E939E8CC91A1C674
9 changed files with 31 additions and 16 deletions

View File

@@ -1,10 +1,10 @@
import pytest
from testing_utils import require_backend_async
from testing_utils import require_backend_async, require_backend
@pytest.fixture(scope="module")
# Sync generator fixture, so it must use the sync `require_backend` decorator
# (the async variant does not guard a plain `yield` fixture correctly).
@require_backend("cuda")
def bloom_560_handle(launcher):
    """Module-scoped handle to a launched bigscience/bloom-560m server (CUDA only)."""
    with launcher("bigscience/bloom-560m") as handle:
        yield handle

View File

@@ -1,10 +1,10 @@
import pytest
from testing_utils import SYSTEM, is_flaky_async, require_backend_async
from testing_utils import SYSTEM, is_flaky_async, require_backend_async, require_backend
@pytest.fixture(scope="module")
@require_backend_async("cuda", "rocm")
@require_backend("cuda", "rocm")
def flash_llama_awq_handle_sharded(launcher):
if SYSTEM == "rocm":
# On ROCm, for awq checkpoints, we need to use gptq kernel that supports ROCm.

View File

@@ -1,12 +1,12 @@
import pytest
from testing_utils import require_backend_async
from testing_utils import require_backend_async, require_backend
# These tests do not pass on ROCm, that does not support head_dim > 128 (2b model is 256).
@pytest.fixture(scope="module")
# Sync generator fixture: use the sync `require_backend` decorator, not the async one.
@require_backend("cuda", "xpu")
def flash_gemma_handle(launcher):
    """Module-scoped handle to a launched google/gemma-2b server (CUDA/XPU only)."""
    with launcher("google/gemma-2b", num_shard=1) as handle:
        yield handle

View File

@@ -1,10 +1,11 @@
import pytest
from testing_utils import require_backend_async
from testing_utils import require_backend_async, require_backend
# These tests do not pass on ROCm, that does not support head_dim > 128 (2b model is 256).
@pytest.fixture(scope="module")
# Sync generator fixture: use the sync `require_backend` decorator, not the async one.
@require_backend("cuda", "xpu")
def flash_gemma_gptq_handle(launcher):
    """Module-scoped handle to a launched GPTQ-quantized gemma-2b server (CUDA/XPU only)."""
    with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle:
        yield handle

View File

@@ -1,9 +1,9 @@
import pytest
from testing_utils import require_backend_async
from testing_utils import require_backend_async, require_backend
@pytest.fixture(scope="module")
@require_backend_async("cuda")
@require_backend("cuda")
def flash_llama_exl2_handle(launcher):
with launcher(
"turboderp/Llama-3-8B-Instruct-exl2",

View File

@@ -3,13 +3,13 @@ import requests
import io
import base64
from testing_utils import require_backend_async
from testing_utils import require_backend_async, require_backend
# These tests do not pass on ROCm, that does not support head_dim > 128 (2b model is 256).
@pytest.fixture(scope="module")
@require_backend_async("cuda", "xpu")
@require_backend("cuda", "xpu")
def flash_pali_gemma_handle(launcher):
with launcher(
"google/paligemma-3b-pt-224",

View File

@@ -1,12 +1,12 @@
import pytest
from testing_utils import require_backend_async
from testing_utils import require_backend_async, require_backend
# These tests do not pass on ROCm, with different generations.
@pytest.fixture(scope="module")
# Sync generator fixture: use the sync `require_backend` decorator, not the async one.
@require_backend("cuda")
def flash_phi_handle(launcher):
    """Module-scoped handle to a launched microsoft/phi-2 server (CUDA only)."""
    with launcher("microsoft/phi-2", num_shard=1) as handle:
        yield handle

View File

@@ -1,10 +1,10 @@
import pytest
from testing_utils import require_backend_async
from testing_utils import require_backend_async, require_backend
@pytest.fixture(scope="module")
# Sync generator fixture: use the sync `require_backend` decorator, not the async one.
@require_backend("cuda")
def fused_kernel_mamba_handle(launcher):
    """Module-scoped handle to a launched state-spaces/mamba-130m server (CUDA only)."""
    with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
        yield handle

View File

@@ -51,6 +51,20 @@ def is_flaky_async(
return decorator
def require_backend(*backends):
    """Decorator factory for synchronous tests: skip the test unless the
    current SYSTEM backend is one of *backends*.

    Counterpart of ``require_backend_async`` for non-async test functions
    and fixtures.
    """

    def decorator(test_fn):
        @functools.wraps(test_fn)
        def guarded(*call_args, **call_kwargs):
            # Run the test only on a matching backend; otherwise skip it.
            if SYSTEM in backends:
                return test_fn(*call_args, **call_kwargs)
            pytest.skip(
                f"Skipping as this test requires the backend {backends} to be run, but current system is SYSTEM={SYSTEM}."
            )

        return guarded

    return decorator
def require_backend_async(*args):
def decorator(func):