fix decorators
This commit is contained in:
parent
4616c62914
commit
3de8f3647b
|
@ -1,10 +1,10 @@
|
|||
import pytest
|
||||
|
||||
from testing_utils import require_backend_async
|
||||
from testing_utils import require_backend_async, require_backend
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
@require_backend("cuda")
def bloom_560_handle(launcher):
    """Module-scoped server handle for ``bigscience/bloom-560m``.

    Uses the synchronous ``require_backend`` (not ``require_backend_async``)
    because this fixture is a plain generator function, not a coroutine —
    the async variant's skip logic would never run here.

    Yields:
        The launched server handle; torn down when the module's tests finish.
    """
    with launcher("bigscience/bloom-560m") as handle:
        yield handle
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import pytest
|
||||
|
||||
from testing_utils import SYSTEM, is_flaky_async, require_backend_async
|
||||
from testing_utils import SYSTEM, is_flaky_async, require_backend_async, require_backend
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@require_backend_async("cuda", "rocm")
|
||||
@require_backend("cuda", "rocm")
|
||||
def flash_llama_awq_handle_sharded(launcher):
|
||||
if SYSTEM == "rocm":
|
||||
# On ROCm, for awq checkpoints, we need to use gptq kernel that supports ROCm.
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import pytest
|
||||
|
||||
from testing_utils import require_backend_async
|
||||
from testing_utils import require_backend_async, require_backend
|
||||
|
||||
# These tests do not pass on ROCm, that does not support head_dim > 128 (2b model is 256).
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
@require_backend("cuda", "xpu")
def flash_gemma_handle(launcher):
    """Module-scoped server handle for ``google/gemma-2b`` (single shard).

    Restricted to CUDA/XPU: per the note in this file, ROCm does not support
    head_dim > 128 and the 2b model uses 256. Uses the synchronous
    ``require_backend`` since this is a plain generator fixture.

    Yields:
        The launched server handle.
    """
    with launcher("google/gemma-2b", num_shard=1) as handle:
        yield handle
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
import pytest
|
||||
|
||||
from testing_utils import require_backend_async
|
||||
from testing_utils import require_backend_async, require_backend
|
||||
|
||||
# These tests do not pass on ROCm, that does not support head_dim > 128 (2b model is 256).
|
||||
|
||||
@pytest.fixture(scope="module")
@require_backend("cuda", "xpu")
def flash_gemma_gptq_handle(launcher):
    """Module-scoped handle for the GPTQ-quantized Gemma 2b checkpoint.

    Restricted to CUDA/XPU: per the note in this file, ROCm does not support
    head_dim > 128 and the 2b model uses 256. Uses the synchronous
    ``require_backend`` since this is a plain generator fixture.

    Yields:
        The launched server handle.
    """
    with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle:
        yield handle
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import pytest
|
||||
from testing_utils import require_backend_async
|
||||
from testing_utils import require_backend_async, require_backend
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@require_backend_async("cuda")
|
||||
@require_backend("cuda")
|
||||
def flash_llama_exl2_handle(launcher):
|
||||
with launcher(
|
||||
"turboderp/Llama-3-8B-Instruct-exl2",
|
||||
|
|
|
@ -3,13 +3,13 @@ import requests
|
|||
import io
|
||||
import base64
|
||||
|
||||
from testing_utils import require_backend_async
|
||||
from testing_utils import require_backend_async, require_backend
|
||||
|
||||
# These tests do not pass on ROCm, that does not support head_dim > 128 (2b model is 256).
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@require_backend_async("cuda", "xpu")
|
||||
@require_backend("cuda", "xpu")
|
||||
def flash_pali_gemma_handle(launcher):
|
||||
with launcher(
|
||||
"google/paligemma-3b-pt-224",
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import pytest
|
||||
|
||||
from testing_utils import require_backend_async
|
||||
from testing_utils import require_backend_async, require_backend
|
||||
|
||||
# These tests do not pass on ROCm, with different generations.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
@require_backend("cuda")
def flash_phi_handle(launcher):
    """Module-scoped server handle for ``microsoft/phi-2`` (single shard).

    CUDA-only: per the note in this file, these tests produce different
    generations on ROCm. Uses the synchronous ``require_backend`` since this
    is a plain generator fixture.

    Yields:
        The launched server handle.
    """
    with launcher("microsoft/phi-2", num_shard=1) as handle:
        yield handle
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import pytest
|
||||
|
||||
from testing_utils import require_backend_async
|
||||
from testing_utils import require_backend_async, require_backend
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
@require_backend("cuda")
def fused_kernel_mamba_handle(launcher):
    """Module-scoped server handle for ``state-spaces/mamba-130m``.

    CUDA-only (the fused mamba kernels require CUDA — assumption from the
    backend restriction; confirm against the kernel implementation). Uses the
    synchronous ``require_backend`` since this is a plain generator fixture.

    Yields:
        The launched server handle.
    """
    with launcher("state-spaces/mamba-130m", num_shard=1) as handle:
        yield handle
|
||||
|
|
|
@ -51,6 +51,20 @@ def is_flaky_async(
|
|||
|
||||
return decorator
|
||||
|
||||
def require_backend(*args):
    """Decorator for synchronous tests: skip unless ``SYSTEM`` is in *args*.

    Counterpart of ``require_backend_async`` for plain (non-coroutine) test
    functions and fixtures. The decorated function runs unchanged when the
    current backend matches; otherwise ``pytest.skip`` is raised at call time.
    """

    def decorator(func):
        # functools.wraps keeps the test's name/docstring visible to pytest.
        @functools.wraps(func)
        def wrapper(*inner_args, **inner_kwargs):
            if SYSTEM in args:
                return func(*inner_args, **inner_kwargs)
            pytest.skip(
                f"Skipping as this test requires the backend {args} to be run, but current system is SYSTEM={SYSTEM}."
            )

        return wrapper

    return decorator
|
||||
|
||||
|
||||
def require_backend_async(*args):
|
||||
def decorator(func):
|
||||
|
|
Loading…
Reference in New Issue