diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index c40a8461..e4d5bb85 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -481,6 +481,7 @@ fn shard_manager(
     rope_factor: Option<f32>,
     max_total_tokens: usize,
     max_batch_size: Option<usize>,
+    max_input_tokens: usize,
     otlp_endpoint: Option<String>,
     log_level: LevelFilter,
     status_sender: mpsc::Sender<ShardStatus>,
@@ -553,6 +554,10 @@ fn shard_manager(
         shard_args.push(otlp_endpoint);
     }
 
+    // In case the model uses a sliding window, some backends may ignore it in flash attention depending on this parameter.
+    shard_args.push("--max-input-tokens".to_string());
+    shard_args.push(max_input_tokens.to_string());
+
     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
 
@@ -1009,6 +1014,7 @@ fn spawn_shards(
     args: &Args,
     cuda_graphs: Vec<usize>,
     max_total_tokens: usize,
+    max_input_tokens: usize,
     max_log_level: LevelFilter,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
@@ -1066,6 +1072,7 @@ fn spawn_shards(
                 rope_factor,
                 max_total_tokens,
                 max_batch_size,
+                max_input_tokens,
                 otlp_endpoint,
                 max_log_level,
                 status_sender,
@@ -1540,6 +1547,7 @@ fn main() -> Result<(), LauncherError> {
         &args,
         cuda_graphs,
         max_total_tokens,
+        max_input_tokens,
         max_log_level,
         shutdown.clone(),
         &shutdown_receiver,
diff --git a/server/Makefile-vllm b/server/Makefile-vllm
index ded2f5d2..8c0437ea 100644
--- a/server/Makefile-vllm
+++ b/server/Makefile-vllm
@@ -1,5 +1,5 @@
 commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa
-commit_rocm := ca6913b3c2ffacdcb7d15e914dc34adbc6c89479
+commit_rocm := 559200c1a028de990c1ddea761b0ccd62109e3a0
 build-vllm-cuda:
 	if [ ! -d 'vllm' ]; then \
 	pip install -U ninja packaging --no-cache-dir && \
@@ -19,5 +19,5 @@ build-vllm-rocm:
 	PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
 
 install-vllm-rocm: build-vllm-rocm
-	cd vllm && git fetch && git checkout $(commit_rocm) && \
+	cd vllm && git fetch && git checkout $(commit_rocm) && \
 	PYTORCH_ROCM_ARCH="gfx90a;gfx942" pip install -e .
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 68b429d0..430323bc 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -42,6 +42,7 @@ def serve(
     logger_level: str = "INFO",
     json_output: bool = False,
     otlp_endpoint: Optional[str] = None,
+    max_input_tokens: Optional[int] = None,
 ):
     if sharded:
         assert (
@@ -98,6 +99,7 @@ def serve(
         dtype,
         trust_remote_code,
         uds_path,
+        max_input_tokens,
     )
 
 
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 535810aa..91ed5818 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -169,10 +169,8 @@ if ENGINE == "ck":
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
-        if window_size_left != -1:
-            raise ValueError(
-                f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
+
+        # We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             q,
             k,
@@ -204,10 +202,7 @@ elif ENGINE == "triton":
         window_size_left=-1,
         causal=True,
     ):
-        if window_size_left != -1:
-            raise ValueError(
-                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
+        # We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
         output, _ = triton_attention(
             q,
             k,
diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/xpu.py
index d9a096f9..8b6cb87b 100644
--- a/server/text_generation_server/layers/attention/xpu.py
+++ b/server/text_generation_server/layers/attention/xpu.py
@@ -14,10 +14,7 @@ def attention(
     softmax_scale,
     window_size_left=-1,
 ):
-    if window_size_left != -1:
-        raise ValueError(
-            f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-        )
+    # We do not need to check window_size_left (not supported) here, as it is already checked ahead of time at model load.
     return ipex.llm.functional.varlen_attention(
         q,
         k,
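With the per-backend ValueError gone from the ROCm, Triton, and XPU attention paths above, the windowing constraint is enforced once, at model load, as the next hunk shows. Below is a minimal sketch of that guard, assuming the SUPPORTS_WINDOWING and SYSTEM flags used in models/__init__.py; check_sliding_window is a hypothetical helper name, not the actual implementation.

```python
# Sketch of the load-time guard that replaces the per-backend checks.
# `check_sliding_window` is a hypothetical name; SUPPORTS_WINDOWING and SYSTEM
# mirror the flags used in models/__init__.py.
def check_sliding_window(
    sliding_window, supports_windowing, max_input_tokens, system, model_type
):
    uses_window = sliding_window is not None and sliding_window != -1
    if uses_window and not supports_windowing and max_input_tokens > sliding_window:
        raise ValueError(
            f"The backend {system} does not support sliding window attention used by "
            f"{model_type}; relaunch with `--max-input-tokens` no larger than "
            f"{sliding_window} (got {max_input_tokens})."
        )
    # Otherwise the attention kernels may safely ignore window_size_left: either
    # the backend supports windowing, or no request can exceed the window.
```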
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index ba353c11..a61cb83b 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -24,6 +24,8 @@ from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded
 from text_generation_server.models.phi import Phi
 
+from text_generation_server.utils.import_utils import SYSTEM
+
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -257,6 +259,7 @@ def get_model(
     speculate: Optional[int],
     dtype: Optional[str],
     trust_remote_code: bool,
+    max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
     if dtype is None:
@@ -410,11 +413,15 @@ def get_model(
                 "Sharding is currently not supported with `exl2` quantization"
             )
         sliding_window = config_dict.get("sliding_window", -1)
-        if sliding_window != -1 and not SUPPORTS_WINDOWING:
-            logger.warning(
-                f"Flash attention is available, but doesn't support windowing which is required by model {model_id}"
+
+        if (
+            (sliding_window is not None and sliding_window != -1)
+            and not SUPPORTS_WINDOWING
+            and max_input_tokens > sliding_window
+        ):
+            raise ValueError(
+                f"The backend {SYSTEM} does not support sliding window attention, which is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with `--max-input-tokens` no larger than sliding_window={sliding_window} (got max_input_tokens={max_input_tokens})."
             )
-            FLASH_ATTENTION = False
 
     if model_type == MAMBA:
         return Mamba(
@@ -701,7 +708,6 @@ def get_model(
         )
 
     if model_type == MISTRAL:
-        sliding_window = config_dict.get("sliding_window", -1)
         if FLASH_ATTENTION:
             return FlashMistral(
                 model_id,
@@ -724,7 +730,6 @@ def get_model(
         )
 
     if model_type == MIXTRAL:
-        sliding_window = config_dict.get("sliding_window", -1)
         if FLASH_ATTENTION:
             return FlashMixtral(
                 model_id,
@@ -747,7 +752,6 @@ def get_model(
         )
 
     if model_type == STARCODER2:
-        sliding_window = config_dict.get("sliding_window", -1)
         if FLASH_ATTENTION:
             return FlashStarcoder2(
                 model_id,
@@ -771,8 +775,7 @@ def get_model(
         )
 
     if model_type == QWEN2:
-        sliding_window = config_dict.get("sliding_window", -1)
-        if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
+        if FLASH_ATTENTION:
             return FlashQwen2(
                 model_id,
                 revision,
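If the new error is raised, the operator needs the model's window size to pick a compatible `--max-input-tokens`. A small example of reading it from the Hugging Face config follows; the model id is illustrative (Mistral-7B-v0.1 ships with sliding_window=4096).

```python
# Illustrative only: look up a model's sliding window so a compatible
# --max-input-tokens value can be chosen on a backend without windowing support.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
print(getattr(config, "sliding_window", None))  # 4096 here; None or -1 means no window
```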
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index acf77b09..d16d3710 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -902,6 +902,8 @@ class FlashCausalLM(Model):
             os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
             or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
         ):
+            torch.cuda.tunable.enable()
+
             if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
                 torch.cuda.tunable.tuning_enable(True)
 
@@ -910,8 +912,11 @@ class FlashCausalLM(Model):
                     int(val)
                     for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
                 ]
-            else:
+            elif CUDA_GRAPHS is not None:
                 tuning_sequences = CUDA_GRAPHS
+            else:
+                # For seqlen = 1, we dispatch to LLMM1 kernel.
+                tuning_sequences = [2, 3, 4, 5, 6, 7]
 
             tunableop_filepath = os.path.join(
                 HUGGINGFACE_HUB_CACHE,
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 4118b3f6..569b6925 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -199,6 +199,7 @@ def serve(
     dtype: Optional[str],
     trust_remote_code: bool,
     uds_path: Path,
+    max_input_tokens: int,
 ):
     async def serve_inner(
         model_id: str,
@@ -229,6 +230,7 @@ def serve(
                 speculate,
                 dtype,
                 trust_remote_code,
+                max_input_tokens,
             )
         except Exception:
             logger.exception("Error when initializing model")
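The flash_causal_lm.py hunk above also changes how TunableOp tuning sequence lengths are picked. Here is a minimal sketch of the resulting precedence (environment override, then the CUDA graph sizes, then a small default range), assuming CUDA_GRAPHS is either a list of lengths or None; pick_tuning_sequences is a hypothetical helper.

```python
import os

# Precedence sketch for TunableOp tuning sequence lengths, mirroring the
# flash_causal_lm.py hunk; `pick_tuning_sequences` is a hypothetical helper.
def pick_tuning_sequences(cuda_graphs):
    if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
        return [int(v) for v in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")]
    elif cuda_graphs is not None:
        return cuda_graphs
    else:
        # Sequence length 1 dispatches to the LLMM1 kernel, so tuning starts at 2.
        return [2, 3, 4, 5, 6, 7]
```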