Do not use TunableOp for non flash-causal-lm models

This commit is contained in:
Felix Marty 2024-07-02 12:52:55 +00:00
parent c2f4b7f93e
commit add4d42cb3
3 changed files with 35 additions and 0 deletions

View File

@ -265,4 +265,10 @@ jobs:
echo "DOCKER_VOLUME:"
echo $DOCKER_VOLUME
# TunableOp warmup is rather slow, do it only for a few seqlens.
if [[ ${{ inputs.hardware }} == "rocm" ]]
then
  # `export` is required: pytest runs as a child process of this shell,
  # and a plain `VAR=value` assignment is not inherited by children.
  export PYTORCH_TUNABLEOP_SEQLENS=2,4
fi
pytest -s -vvvvv integration-tests ${PYTEST_FLAGS}

View File

@ -466,6 +466,16 @@ class Mamba(Model):
return MambaBatch
# Warmup for Mamba: disables TunableOp on ROCm (unsupported for this
# architecture), mirroring the guard in the generic Model.warmup.
def warmup(self, batch) -> Optional[int]:
# TunableOp counts as active when the env var is unset (default) or "1".
if SYSTEM == "rocm" and (
os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
):
logger.info(
f"ROCm: Got PYTORCH_TUNABLEOP_ENABLED=1 but TunableOp is not supported for {self.model_id} (instance of {self.__class__.__name__}). Disabling TunableOp."
)
# Turn off both the tuning of new GEMM shapes and the use of tuned results.
torch.cuda.tunable.tuning_enable(False)
torch.cuda.tunable.enable(False)
# TODO: implement warmup for Mamba if needed
# NOTE(review): the diff hunk is truncated below — the CUDA_GRAPHS branch
# continues past this view.
if CUDA_GRAPHS:
if self.speculate is None or self.speculate == 0:

View File

@ -15,6 +15,9 @@ from text_generation_server.utils.adapter import (
AdapterParameters,
AdapterSource,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import CUDA_GRAPHS
import os
from loguru import logger
@ -100,7 +103,23 @@ class Model(ABC):
raise NotImplementedError
def warmup(self, batch: B) -> Optional[int]:
    """Default warmup for models without flash-causal-lm support.

    Runs one token-generation pass on ``batch`` and returns ``None``
    (no supported maximum total tokens). On ROCm, PyTorch TunableOp and
    CUDA graphs are only supported for flash-causal-lm models, so
    TunableOp is disabled here and a CUDA_GRAPHS request is ignored with
    an informational log.
    """
    # Read the variable once; TunableOp is considered active when it is
    # unset (ROCm default-on) or explicitly set to "1".
    tunableop_env = os.environ.get("PYTORCH_TUNABLEOP_ENABLED")
    if SYSTEM == "rocm" and tunableop_env in (None, "1"):
        # Report the actual state: the original message claimed "=1" even
        # when the variable was unset.
        logger.info(
            f"ROCm: TunableOp is enabled (PYTORCH_TUNABLEOP_ENABLED={tunableop_env if tunableop_env is not None else 'unset, enabled by default'}) but TunableOp is not supported for {self.model_id} (instance of {self.__class__.__name__}). Disabling TunableOp."
        )
        # Disable both tuning of new GEMM shapes and use of tuned results.
        torch.cuda.tunable.tuning_enable(False)
        torch.cuda.tunable.enable(False)

    self.generate_token(batch)

    if CUDA_GRAPHS:
        logger.info(
            f"Got CUDA_GRAPHS={CUDA_GRAPHS} but cuda graphs are not supported for {self.model_id} (instance of {self.__class__.__name__}). Cuda graphs will not be used."
        )
    return None
def decode_token(