do not use tunableop for non-flash-causal-lm models
commit add4d42cb3
parent c2f4b7f93e
@@ -265,4 +265,10 @@ jobs:
 echo "DOCKER_VOLUME:"
 echo $DOCKER_VOLUME
 
+# TunableOp warmup is rather slow, do it only for a few seqlens.
+if [[ ${{ inputs.hardware }} == "rocm" ]]
+then
+  PYTORCH_TUNABLEOP_SEQLENS=2,4
+fi
+
 pytest -s -vvvvv integration-tests ${PYTEST_FLAGS}
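As an illustration of how PYTORCH_TUNABLEOP_SEQLENS could be consumed on the Python side, the sketch below parses the variable into a list of sequence lengths and runs one dummy GEMM per length so TunableOp has shapes to tune. This is a minimal sketch only; the helper names tunableop_seqlens and tunableop_warmup are hypothetical and not the actual text-generation-inference implementation.

# Hedged sketch: hypothetical consumer of PYTORCH_TUNABLEOP_SEQLENS.
import os
from typing import List

import torch


def tunableop_seqlens(default: str = "1,2,4,8,16,32") -> List[int]:
    # Parse the comma-separated override set by the CI step above,
    # falling back to a fuller sweep when the variable is unset.
    raw = os.environ.get("PYTORCH_TUNABLEOP_SEQLENS", default)
    return [int(s) for s in raw.split(",") if s.strip()]


def tunableop_warmup(linear: torch.nn.Linear) -> None:
    # One dummy forward per sequence length so TunableOp can benchmark
    # and cache the best GEMM algorithm for each shape.
    for seqlen in tunableop_seqlens():
        dummy = torch.zeros(
            seqlen,
            linear.in_features,
            device=linear.weight.device,
            dtype=linear.weight.dtype,
        )
        with torch.no_grad():
            linear(dummy)
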
@@ -466,6 +466,16 @@ class Mamba(Model):
         return MambaBatch
 
     def warmup(self, batch) -> Optional[int]:
+        if SYSTEM == "rocm" and (
+            os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
+            or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
+        ):
+            logger.info(
+                f"ROCm: Got PYTORCH_TUNABLEOP_ENABLED=1 but TunableOp is not supported for {self.model_id} (instance of {self.__class__.__name__}). Disabling TunableOp."
+            )
+            torch.cuda.tunable.tuning_enable(False)
+            torch.cuda.tunable.enable(False)
+
         # TODO: implement warmup for Mamba if needed
         if CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
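Standalone, the guard added above reduces to the snippet below: a minimal sketch assuming PyTorch >= 2.3, where torch.cuda.tunable exposes enable/tuning_enable toggles and the matching is_enabled/tuning_is_enabled queries. The helper name maybe_disable_tunableop is introduced only for this sketch.

# Minimal sketch of the guard: turn TunableOp off when the environment
# leaves it at its ROCm default (unset) or explicitly requests it ("1")
# but the model class does not support it.
import os

import torch


def maybe_disable_tunableop() -> None:
    value = os.environ.get("PYTORCH_TUNABLEOP_ENABLED")
    if value is None or value == "1":
        # Stop tuning new shapes and disable the tunable-op path entirely.
        torch.cuda.tunable.tuning_enable(False)
        torch.cuda.tunable.enable(False)


if __name__ == "__main__":
    maybe_disable_tunableop()
    print("TunableOp enabled:", torch.cuda.tunable.is_enabled())
    print("Tuning enabled:", torch.cuda.tunable.tuning_is_enabled())
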
@@ -15,6 +15,9 @@ from text_generation_server.utils.adapter import (
     AdapterParameters,
     AdapterSource,
 )
+from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import CUDA_GRAPHS
+import os
 from loguru import logger
 
 
@@ -100,7 +103,23 @@ class Model(ABC):
         raise NotImplementedError
 
     def warmup(self, batch: B) -> Optional[int]:
+        if SYSTEM == "rocm" and (
+            os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
+            or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
+        ):
+            logger.info(
+                f"ROCm: Got PYTORCH_TUNABLEOP_ENABLED=1 but TunableOp is not supported for {self.model_id} (instance of {self.__class__.__name__}). Disabling TunableOp."
+            )
+            torch.cuda.tunable.tuning_enable(False)
+            torch.cuda.tunable.enable(False)
+
         self.generate_token(batch)
+
+        if CUDA_GRAPHS:
+            logger.info(
+                f"Got CUDA_GRAPHS={CUDA_GRAPHS} but cuda graphs are not supported for {self.model_id} (instance of {self.__class__.__name__}). Cuda graphs will not be used."
+            )
+
         return None
 
     def decode_token(
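The CUDA_GRAPHS value logged above is a module-level global imported from text_generation_server.models.globals. As a rough illustration of what such a flag can look like, the sketch below derives a list of batch sizes from a CUDA_GRAPHS environment variable; the real parsing in globals.py may differ, and parse_cuda_graphs is a hypothetical helper used only here.

# Hedged sketch only: one plausible way a CUDA_GRAPHS global could be
# built from an environment variable.
import os
from typing import List, Optional


def parse_cuda_graphs(raw: Optional[str]) -> List[int]:
    # Unset, empty, or "0" means "do not capture cuda graphs"; otherwise
    # treat the value as a comma-separated list of batch sizes.
    if not raw or raw.strip() == "0":
        return []
    return sorted({int(x) for x in raw.split(",") if x.strip()})


CUDA_GRAPHS = parse_cuda_graphs(os.environ.get("CUDA_GRAPHS"))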