diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 18b3a09f..121917f0 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -265,4 +265,10 @@ jobs:
           echo "DOCKER_VOLUME:"
           echo $DOCKER_VOLUME

+          # TunableOp warmup is rather slow, do it only for a few seqlens.
+          if [[ ${{ inputs.hardware }} == "rocm" ]]
+          then
+            export PYTORCH_TUNABLEOP_SEQLENS=2,4
+          fi
+
           pytest -s -vvvvv integration-tests ${PYTEST_FLAGS}
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 9189b45c..1ce8346c 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -466,6 +466,16 @@ class Mamba(Model):
         return MambaBatch

     def warmup(self, batch) -> Optional[int]:
+        if SYSTEM == "rocm" and (
+            os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
+            or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
+        ):
+            logger.info(
+                f"ROCm: TunableOp is enabled by default (PYTORCH_TUNABLEOP_ENABLED is unset or set to 1), but TunableOp is not supported for {self.model_id} (instance of {self.__class__.__name__}). Disabling TunableOp."
+            )
+            torch.cuda.tunable.tuning_enable(False)
+            torch.cuda.tunable.enable(False)
+
         # TODO: implement warmup for Mamba if needed
         if CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index c90fd38a..819414aa 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -15,6 +15,9 @@ from text_generation_server.utils.adapter import (
     AdapterParameters,
     AdapterSource,
 )
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import CUDA_GRAPHS
+import os

 from loguru import logger

@@ -100,7 +103,23 @@ class Model(ABC):
         raise NotImplementedError

     def warmup(self, batch: B) -> Optional[int]:
+        if SYSTEM == "rocm" and (
+            os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
+            or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
+        ):
+            logger.info(
+                f"ROCm: TunableOp is enabled by default (PYTORCH_TUNABLEOP_ENABLED is unset or set to 1), but TunableOp is not supported for {self.model_id} (instance of {self.__class__.__name__}). Disabling TunableOp."
+            )
+            torch.cuda.tunable.tuning_enable(False)
+            torch.cuda.tunable.enable(False)
+
         self.generate_token(batch)
+
+        if CUDA_GRAPHS:
+            logger.info(
+                f"Got CUDA_GRAPHS={CUDA_GRAPHS} but CUDA graphs are not supported for {self.model_id} (instance of {self.__class__.__name__}). CUDA graphs will not be used."
+            )
+
         return None

     def decode_token(
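
For reference, the gating logic shared by the two Python hunks can be read as a standalone helper. The sketch below is hypothetical (the helper name and signature are not part of the diff) and assumes PyTorch >= 2.3, where the `torch.cuda.tunable` module is available; the unset-or-"1" check mirrors the diff's assumption that the ROCm build treats TunableOp as enabled by default.

```python
# Hypothetical standalone version of the gating logic in the two warmup()
# hunks above; a sketch, not part of the diff.
import os

import torch
from loguru import logger


def disable_tunableop_if_active(model_id: str, model_cls: type) -> None:
    enabled = os.environ.get("PYTORCH_TUNABLEOP_ENABLED")
    if enabled is not None and enabled != "1":
        return  # explicitly disabled by the user; nothing to do
    logger.info(
        f"ROCm: TunableOp is not supported for {model_id} "
        f"(instance of {model_cls.__name__}). Disabling TunableOp."
    )
    # tuning_enable(False) stops the slow search for new GEMM solutions;
    # enable(False) turns TunableOp off entirely, so previously tuned
    # solutions are not consulted either.
    torch.cuda.tunable.tuning_enable(False)
    torch.cuda.tunable.enable(False)
```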