disable marlin tests on rocm/xpu
This commit is contained in:
parent
41699e9bbf
commit
de6f2cd08d
|
@ -35,6 +35,12 @@ DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
|
||||||
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
|
HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
|
||||||
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
|
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
|
||||||
DOCKER_DEVICES = os.getenv("DOCKER_DEVICES")
|
DOCKER_DEVICES = os.getenv("DOCKER_DEVICES")
|
||||||
|
SYSTEM = os.getenv("SYSTEM", None)
|
||||||
|
|
||||||
|
if SYSTEM is None:
|
||||||
|
raise ValueError(
|
||||||
|
"The environment variable `SYSTEM` needs to be set to run TGI integration tests (one of 'cuda', 'rocm', 'xpu')."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ResponseComparator(JSONSnapshotExtension):
|
class ResponseComparator(JSONSnapshotExtension):
|
||||||
|
@ -314,6 +320,7 @@ def launcher(event_loop):
|
||||||
max_batch_prefill_tokens: Optional[int] = None,
|
max_batch_prefill_tokens: Optional[int] = None,
|
||||||
max_total_tokens: Optional[int] = None,
|
max_total_tokens: Optional[int] = None,
|
||||||
):
|
):
|
||||||
|
print("call local_launcher")
|
||||||
port = random.randint(8000, 10_000)
|
port = random.randint(8000, 10_000)
|
||||||
master_port = random.randint(10_000, 20_000)
|
master_port = random.randint(10_000, 20_000)
|
||||||
|
|
||||||
|
@ -368,6 +375,7 @@ def launcher(event_loop):
|
||||||
with tempfile.TemporaryFile("w+") as tmp:
|
with tempfile.TemporaryFile("w+") as tmp:
|
||||||
# We'll output stdout/stderr to a temporary file. Using a pipe
|
# We'll output stdout/stderr to a temporary file. Using a pipe
|
||||||
# cause the process to block until stdout is read.
|
# cause the process to block until stdout is read.
|
||||||
|
print("call subprocess.Popen, with args", args)
|
||||||
with subprocess.Popen(
|
with subprocess.Popen(
|
||||||
args,
|
args,
|
||||||
stdout=tmp,
|
stdout=tmp,
|
||||||
|
@ -399,6 +407,7 @@ def launcher(event_loop):
|
||||||
max_batch_prefill_tokens: Optional[int] = None,
|
max_batch_prefill_tokens: Optional[int] = None,
|
||||||
max_total_tokens: Optional[int] = None,
|
max_total_tokens: Optional[int] = None,
|
||||||
):
|
):
|
||||||
|
print("call docker launcher")
|
||||||
port = random.randint(8000, 10_000)
|
port = random.randint(8000, 10_000)
|
||||||
|
|
||||||
args = ["--model-id", model_id, "--env"]
|
args = ["--model-id", model_id, "--env"]
|
||||||
|
@ -466,6 +475,8 @@ def launcher(event_loop):
|
||||||
docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
|
docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
|
||||||
]
|
]
|
||||||
|
|
||||||
|
print("call client.containers.run")
|
||||||
|
print("container_name", container_name)
|
||||||
container = client.containers.run(
|
container = client.containers.run(
|
||||||
DOCKER_IMAGE,
|
DOCKER_IMAGE,
|
||||||
command=args,
|
command=args,
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from testing_utils import SYSTEM
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def flash_llama_marlin_handle(launcher):
|
def flash_llama_marlin_handle(launcher):
|
||||||
|
@ -11,7 +13,13 @@ def flash_llama_marlin_handle(launcher):
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
async def flash_llama_marlin(flash_llama_marlin_handle):
|
async def flash_llama_marlin(flash_llama_marlin_handle):
|
||||||
await flash_llama_marlin_handle.health(300)
|
if SYSTEM != "cuda":
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await flash_llama_marlin_handle.health(300)
|
||||||
|
assert exc_info.value.args[0] == "only available on Nvidia"
|
||||||
|
pytest.skip(f"Marlin not supported on SYSTEM={SYSTEM}")
|
||||||
|
else:
|
||||||
|
await flash_llama_marlin_handle.health(300)
|
||||||
return flash_llama_marlin_handle.client
|
return flash_llama_marlin_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,8 @@ from typing import Optional
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import marlin
|
import marlin
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -38,6 +40,11 @@ class MarlinLinear(nn.Module):
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
if SYSTEM != "cuda":
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Marlin quantization kernel is only available on Nvidia GPUs, not on the current {SYSTEM} backend."
|
||||||
|
)
|
||||||
|
|
||||||
if not has_sm_8_0:
|
if not has_sm_8_0:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Using quantized marlin models requires CUDA capability 8.0 or later"
|
"Using quantized marlin models requires CUDA capability 8.0 or later"
|
||||||
|
|
Loading…
Reference in New Issue