disable marlin tests on rocm/xpu

This commit is contained in:
fxmarty 2024-06-10 13:06:11 +00:00
parent 41699e9bbf
commit de6f2cd08d
3 changed files with 27 additions and 1 deletions

View File

@ -35,6 +35,12 @@ DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
DOCKER_DEVICES = os.getenv("DOCKER_DEVICES")
SYSTEM = os.getenv("SYSTEM", None)
if SYSTEM is None:
raise ValueError(
"The environment variable `SYSTEM` needs to be set to run TGI integration tests (one of 'cuda', 'rocm', 'xpu')."
)
class ResponseComparator(JSONSnapshotExtension):
@ -314,6 +320,7 @@ def launcher(event_loop):
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
):
print("call local_launcher")
port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000)
@ -368,6 +375,7 @@ def launcher(event_loop):
with tempfile.TemporaryFile("w+") as tmp:
# We'll output stdout/stderr to a temporary file. Using a pipe
# cause the process to block until stdout is read.
print("call subprocess.Popen, with args", args)
with subprocess.Popen(
args,
stdout=tmp,
@ -399,6 +407,7 @@ def launcher(event_loop):
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
):
print("call docker launcher")
port = random.randint(8000, 10_000)
args = ["--model-id", model_id, "--env"]
@ -466,6 +475,8 @@ def launcher(event_loop):
docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
]
print("call client.containers.run")
print("container_name", container_name)
container = client.containers.run(
DOCKER_IMAGE,
command=args,

View File

@ -1,5 +1,7 @@
import pytest
from testing_utils import SYSTEM
@pytest.fixture(scope="module")
def flash_llama_marlin_handle(launcher):
@ -11,7 +13,13 @@ def flash_llama_marlin_handle(launcher):
@pytest.fixture(scope="module")
async def flash_llama_marlin(flash_llama_marlin_handle):
await flash_llama_marlin_handle.health(300)
if SYSTEM != "cuda":
with pytest.raises(Exception) as exc_info:
await flash_llama_marlin_handle.health(300)
assert exc_info.value.args[0] == "only available on Nvidia"
pytest.skip(f"Marlin not supported on SYSTEM={SYSTEM}")
else:
await flash_llama_marlin_handle.health(300)
return flash_llama_marlin_handle.client

View File

@ -4,6 +4,8 @@ from typing import Optional
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
try:
import marlin
except ImportError:
@ -38,6 +40,11 @@ class MarlinLinear(nn.Module):
):
super().__init__()
if SYSTEM != "cuda":
raise NotImplementedError(
f"Marlin quantization kernel is only available on Nvidia GPUs, not on the current {SYSTEM} backend."
)
if not has_sm_8_0:
raise NotImplementedError(
"Using quantized marlin models requires CUDA capability 8.0 or later"