do not use tunableop for non flash-causal-lm models
This commit is contained in:
parent
c2f4b7f93e
commit
add4d42cb3
|
@ -265,4 +265,10 @@ jobs:
|
||||||
echo "DOCKER_VOLUME:"
|
echo "DOCKER_VOLUME:"
|
||||||
echo $DOCKER_VOLUME
|
echo $DOCKER_VOLUME
|
||||||
|
|
||||||
|
# TunableOp warmup is rather slow, do it only for a few seqlens.
|
||||||
|
if [[ ${{ inputs.hardware }} == "rocm" ]]
|
||||||
|
then
|
||||||
|
PYTORCH_TUNABLEOP_SEQLENS=2,4
|
||||||
|
fi
|
||||||
|
|
||||||
pytest -s -vvvvv integration-tests ${PYTEST_FLAGS}
|
pytest -s -vvvvv integration-tests ${PYTEST_FLAGS}
|
||||||
|
|
|
@ -466,6 +466,16 @@ class Mamba(Model):
|
||||||
return MambaBatch
|
return MambaBatch
|
||||||
|
|
||||||
def warmup(self, batch) -> Optional[int]:
|
def warmup(self, batch) -> Optional[int]:
|
||||||
|
if SYSTEM == "rocm" and (
|
||||||
|
os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
|
||||||
|
or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
f"ROCm: Got PYTORCH_TUNABLEOP_ENABLED=1 but TunableOp is not supported for {self.model_id} (instance of {self.__class__.__name__}). Disabling TunableOp."
|
||||||
|
)
|
||||||
|
torch.cuda.tunable.tuning_enable(False)
|
||||||
|
torch.cuda.tunable.enable(False)
|
||||||
|
|
||||||
# TODO: implement warmup for Mamba if needed
|
# TODO: implement warmup for Mamba if needed
|
||||||
if CUDA_GRAPHS:
|
if CUDA_GRAPHS:
|
||||||
if self.speculate is None or self.speculate == 0:
|
if self.speculate is None or self.speculate == 0:
|
||||||
|
|
|
@ -15,6 +15,9 @@ from text_generation_server.utils.adapter import (
|
||||||
AdapterParameters,
|
AdapterParameters,
|
||||||
AdapterSource,
|
AdapterSource,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
from text_generation_server.models.globals import CUDA_GRAPHS
|
||||||
|
import os
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
@ -100,7 +103,23 @@ class Model(ABC):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def warmup(self, batch: B) -> Optional[int]:
|
def warmup(self, batch: B) -> Optional[int]:
|
||||||
|
if SYSTEM == "rocm" and (
|
||||||
|
os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
|
||||||
|
or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
f"ROCm: Got PYTORCH_TUNABLEOP_ENABLED=1 but TunableOp is not supported for {self.model_id} (instance of {self.__class__.__name__}). Disabling TunableOp."
|
||||||
|
)
|
||||||
|
torch.cuda.tunable.tuning_enable(False)
|
||||||
|
torch.cuda.tunable.enable(False)
|
||||||
|
|
||||||
self.generate_token(batch)
|
self.generate_token(batch)
|
||||||
|
|
||||||
|
if CUDA_GRAPHS:
|
||||||
|
logger.info(
|
||||||
|
f"Got CUDA_GRAPHS={CUDA_GRAPHS} but cuda graphs are not supported for {self.model_id} (instance of {self.__class__.__name__}). Cuda graphs will not be used."
|
||||||
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def decode_token(
|
def decode_token(
|
||||||
|
|
Loading…
Reference in New Issue