diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 0b239484..206ac84c 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -35,6 +35,12 @@ DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
 HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)
 DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
 DOCKER_DEVICES = os.getenv("DOCKER_DEVICES")
+SYSTEM = os.getenv("SYSTEM", None)
+
+if SYSTEM is None:
+    raise ValueError(
+        "The environment variable `SYSTEM` needs to be set to run TGI integration tests (one of 'cuda', 'rocm', 'xpu')."
+    )
 
 
 class ResponseComparator(JSONSnapshotExtension):
@@ -314,6 +320,7 @@ def launcher(event_loop):
         max_batch_prefill_tokens: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
     ):
+        print("call local_launcher")
         port = random.randint(8000, 10_000)
         master_port = random.randint(10_000, 20_000)
 
@@ -368,6 +375,7 @@ def launcher(event_loop):
         with tempfile.TemporaryFile("w+") as tmp:
             # We'll output stdout/stderr to a temporary file. Using a pipe
             # cause the process to block until stdout is read.
+            print("call subprocess.Popen, with args", args)
             with subprocess.Popen(
                 args,
                 stdout=tmp,
@@ -399,6 +407,7 @@ def launcher(event_loop):
         max_batch_prefill_tokens: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
     ):
+        print("call docker launcher")
        port = random.randint(8000, 10_000)
 
         args = ["--model-id", model_id, "--env"]
@@ -466,6 +475,8 @@ def launcher(event_loop):
                 docker.types.DeviceRequest(count=gpu_count, capabilities=[["gpu"]])
             ]
 
+        print("call client.containers.run")
+        print("container_name", container_name)
         container = client.containers.run(
             DOCKER_IMAGE,
             command=args,
diff --git a/integration-tests/models/test_flash_llama_marlin.py b/integration-tests/models/test_flash_llama_marlin.py
index e7c5ccbd..32fc7a02 100644
--- a/integration-tests/models/test_flash_llama_marlin.py
+++ b/integration-tests/models/test_flash_llama_marlin.py
@@ -1,5 +1,7 @@
 import pytest
 
+from testing_utils import SYSTEM
+
 
 @pytest.fixture(scope="module")
 def flash_llama_marlin_handle(launcher):
@@ -11,7 +13,13 @@
 
 @pytest.fixture(scope="module")
 async def flash_llama_marlin(flash_llama_marlin_handle):
-    await flash_llama_marlin_handle.health(300)
+    if SYSTEM != "cuda":
+        with pytest.raises(Exception) as exc_info:
+            await flash_llama_marlin_handle.health(300)
+        assert exc_info.value.args[0] == "only available on Nvidia"
+        pytest.skip(f"Marlin not supported on SYSTEM={SYSTEM}")
+    else:
+        await flash_llama_marlin_handle.health(300)
     return flash_llama_marlin_handle.client
 
 
diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py
index a860d84b..4193f347 100644
--- a/server/text_generation_server/layers/marlin.py
+++ b/server/text_generation_server/layers/marlin.py
@@ -4,6 +4,8 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
+from text_generation_server.utils.import_utils import SYSTEM
+
 try:
     import marlin
 except ImportError:
@@ -38,6 +40,11 @@ class MarlinLinear(nn.Module):
     ):
         super().__init__()
 
+        if SYSTEM != "cuda":
+            raise NotImplementedError(
+                f"Marlin quantization kernel is only available on Nvidia GPUs, not on the current {SYSTEM} backend."
+            )
+
         if not has_sm_8_0:
             raise NotImplementedError(
                 "Using quantized marlin models requires CUDA capability 8.0 or later"
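
Note: test_flash_llama_marlin.py imports SYSTEM from a testing_utils helper module that is not shown in this diff. As a rough illustration only (the module path and contents below are assumptions, not the PR's actual implementation), such a helper could simply mirror the SYSTEM environment-variable check added to conftest.py above:

```python
# Hypothetical sketch of integration-tests/testing_utils.py -- not part of this
# diff; the real module added by the PR may differ.
import os

# Mirror the check added to conftest.py: the integration tests are told which
# backend they target via the SYSTEM environment variable.
SYSTEM = os.getenv("SYSTEM", None)

if SYSTEM is None:
    raise ValueError(
        "The environment variable `SYSTEM` needs to be set to run TGI "
        "integration tests (one of 'cuda', 'rocm', 'xpu')."
    )
```

With SYSTEM exported (for example, SYSTEM=rocm pytest integration-tests/models/test_flash_llama_marlin.py), the flash_llama_marlin fixture above takes the pytest.skip path on non-CUDA backends instead of waiting on a health check for a server that cannot start.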