diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index ba6f520d..e4c3de26 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -39,7 +39,14 @@ impl SchedulerV2 { speculate: u32, generation_health: Arc, ) -> Self { - let queue = Queue::new(requires_padding, 16, window_size, speculate); + // Infer shared state + let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { + matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + } else { + false + }; + let block_size = if flashdecoding { 256 } else { 16 }; + let queue = Queue::new(requires_padding, block_size, window_size, speculate); let batching_task_notifier = Arc::new(Notify::new()); // Spawn batching background task that contains all the inference logic diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index ad03dd83..543ce89f 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -39,9 +39,15 @@ impl SchedulerV3 { speculate: u32, generation_health: Arc, ) -> Self { + let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") { + matches!(flashdecoding.to_lowercase().as_str(), "1" | "true") + } else { + false + }; + let block_size = if flashdecoding { 256 } else { 16 }; let queue = Queue::new( requires_padding, - 16, + block_size, window_size, speculate, max_batch_total_tokens, diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index e74180e7..c8bccefe 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -1,6 +1,8 @@ from text_generation_server.utils.import_utils import SYSTEM import os +from .common import Seqlen + if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") if SYSTEM == "cuda": diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py new file mode 100644 index 00000000..bd0717ce --- /dev/null +++ b/server/text_generation_server/layers/attention/common.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass +from text_generation_server.models.globals import FLASH_DECODING +import torch +from typing import Optional + + +if FLASH_DECODING: + + @dataclass + class Seqlen: + input_lengths: torch.Tensor + cu_seqlen_q: Optional[torch.Tensor] + cu_seqlen_k: Optional[torch.Tensor] + + def __init__(self, input_lengths): + self.input_lengths = input_lengths + device = self.input_lengths.device + shape = self.input_lengths.shape + cu_seqlen_q = torch.arange( + shape[0] + 1, + device=device, + dtype=torch.int32, + ) + cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) + # cuda graphs don't like this and this is necessary to clamp within mistral + # Although FA2 might not want the clamping + # cu_seqlen_k[0] = 0 + torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:]) + + self.cu_seqlen_q = cu_seqlen_q + self.cu_seqlen_k = cu_seqlen_k + + def clamp(self, max): + # Flash decoding doesn't need to clamp + return self + +else: + + @dataclass + class Seqlen: + input_lengths: torch.Tensor + + def clamp(self, max): + return Seqlen(torch.clamp(self.input_lengths, max=max)) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 583337bd..94b69899 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -1,5 +1,7 @@ import torch from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE +from text_generation_server.layers.attention import Seqlen major, minor = torch.cuda.get_device_capability() is_sm75 = major == 7 and minor == 5 @@ -21,7 +23,14 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) + if FLASH_DECODING: + shape = key_cache.shape + key_cache.view(-1, shape[-2], shape[-1])[slots] = key + value_cache.view(-1, shape[-2], shape[-1])[slots] = value + else: + cache_ops.reshape_and_cache( + key, value, key_cache, value_cache, slots, "auto", 1.0 + ) def paged_attention( @@ -32,7 +41,7 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py @@ -53,7 +62,8 @@ def paged_attention( # # value_cache => [num_blocks, num_heads, head_size, block_size] - block_size = value_cache.shape[3] + # block_size = value_cache.shape[3] + block_size = BLOCK_SIZE num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE @@ -62,58 +72,95 @@ def paged_attention( # V1 to avoid the overhead of reduction. Also, if the number of # sequences or heads is large, we use V1 since there is enough work # to parallelize. - from vllm._C import ops + if FLASH_DECODING: + max_q = 1 + max_k = max_s + import flash_attn_2_cuda - use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512) - if use_v1: - ops.paged_attention_v1( - out, + # TODO fixme when flash contains the fix. + # Number of splits is not correctly handled + # by the current path + # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577 + # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied. + out2 = flash_attn_2_cuda.varlen_fwd( query, key_cache, value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, None, - "auto", - 1.0, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, + None, + block_tables, + None, + max_q, + max_k, + 0.0, # dropout + softmax_scale, + False, # zero_tensors + True, # causal + -1, # Window_left + -1, # Window right + False, # return softmax + None, # generator ) + return out2[0] else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=out.dtype, - device=out.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=out.device, - ) - max_logits = torch.empty_like(exp_sums) + input_lengths = seqlen.input_lengths + from vllm._C import ops - ops.paged_attention_v2( - out, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - kv_head_mapping, - softmax_scale, - block_tables, - input_lengths, - block_size, - max_s, - None, - "auto", - 1.0, + use_v1 = max_s <= 8192 and ( + max_num_partitions == 1 or num_seqs * num_heads > 512 ) + if use_v1: + ops.paged_attention_v1( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=out.dtype, + device=out.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=out.device, + ) + max_logits = torch.empty_like(exp_sums) + + ops.paged_attention_v2( + out, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + "auto", + 1.0, + ) + return out try: diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 7f086b68..db79c589 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -55,7 +55,8 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - input_lengths: torch.Tensor, + cu_seqlen_q: torch.Tensor, + cu_seqlen_k: torch.Tensor, max_s: int, ): return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( @@ -66,7 +67,7 @@ def paged_attention( kv_head_mapping, softmax_scale, block_tables, - input_lengths, + cu_seqlen_q, BLOCK_SIZE, max_s, None, diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 91ed5818..36db12d0 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -1,6 +1,7 @@ import os import torch from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.models.globals import FLASH_DECODING from loguru import logger major, minor = torch.cuda.get_device_capability() @@ -26,7 +27,14 @@ def reshape_and_cache( value_cache: torch.Tensor, slots: torch.Tensor, ): - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0) + if FLASH_DECODING: + shape = key_cache.shape + key_cache.view(-1, shape[-2], shape[-1])[slots] = key + value_cache.view(-1, shape[-2], shape[-1])[slots] = value + else: + cache_ops.reshape_and_cache( + key, value, key_cache, value_cache, slots, "auto", 1.0 + ) def paged_attention( @@ -37,7 +45,8 @@ def paged_attention( kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, - input_lengths: torch.Tensor, + cu_seqlen_q: torch.Tensor, + cu_seqlen_k: torch.Tensor, max_s: int, ): # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py @@ -61,6 +70,7 @@ def paged_attention( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE + input_lengths = cu_seqlen_k # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use @@ -119,6 +129,7 @@ def paged_attention( "auto", 1.0, ) + return out if ENGINE != "triton": diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index f2f0f457..5ea43290 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -12,7 +12,6 @@ from pathlib import Path from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM -from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.bloom import BLOOMSharded from text_generation_server.models.mpt import MPTSharded from text_generation_server.models.seq2seq_lm import Seq2SeqLM @@ -53,6 +52,7 @@ FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." FLASH_ATTENTION = True try: + from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.flash_rw import FlashRWSharded from text_generation_server.models.flash_gpt2 import FlashGPT2 from text_generation_server.models.flash_neox import FlashNeoXSharded @@ -92,6 +92,7 @@ except ImportError as e: FLASH_ATTENTION = False if FLASH_ATTENTION: + __all__.append(FlashCausalLM) __all__.append(FlashGPT2) __all__.append(FlashNeoXSharded) __all__.append(FlashRWSharded) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 2850a6f3..e088f9aa 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -30,6 +30,7 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, ) +from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( TensorParallelRowLinear, @@ -259,8 +260,8 @@ class FlashCohereAttention(torch.nn.Module): cu_seqlen_prefill, kv_cache, block_tables, - slots, input_lengths, + slots, max_s, ): qkv = self.query_key_value(hidden_states) @@ -304,7 +305,7 @@ class FlashCohereAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], @@ -464,6 +465,7 @@ class FlashCohereModel(torch.nn.Module): ) residual = None + for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 9d56e4ef..aea7f399 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -336,7 +336,7 @@ class DbrxAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index a71de61f..cfa6b2fe 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -251,7 +251,7 @@ class FlashGemma2Attention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 82891823..842df0d4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -245,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 7e7510c7..9f800146 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -245,7 +245,7 @@ class FlashGPT2Attention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 0ea9f623..77a7e2d5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -33,6 +33,7 @@ from text_generation_server.layers.attention import ( attention, reshape_and_cache, ) +from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -213,7 +214,7 @@ class FlashLlamaAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index d1ba5564..69ed5f64 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -28,6 +28,7 @@ from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( + Seqlen, paged_attention, attention, reshape_and_cache, @@ -229,7 +230,7 @@ class MistralAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], @@ -512,7 +513,7 @@ class FlashMistralForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) + input_lengths = input_lengths.clamp(max=self.max_past_tensor) inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 2e839d15..2d6a7f97 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -291,7 +291,7 @@ class MixtralAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], @@ -647,7 +647,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor) + input_lengths = input_lengths.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b87fd4ca..33aebc2b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -168,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, qkv[:, 0], kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 3f445f97..f237ea37 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -207,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 69f38c3a..2e281386 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -149,7 +149,7 @@ class Qwen2Attention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 04d4ba51..e7614232 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -217,7 +217,7 @@ class FlashRWAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], @@ -340,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index badfc367..30989a37 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -301,7 +301,7 @@ class FlashMQAttention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index f6a2e15d..df864bc1 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -255,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module): ) # Decode else: - paged_attention( + attn_output = paged_attention( attn_output, query, kv_cache[0], diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a0a78b33..49a088a1 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -30,10 +30,13 @@ from text_generation_server.models.types import ( from text_generation_server.pb import generate_pb2 from text_generation_server.models.globals import ( MEM_POOL, + FLASH_DECODING, + BLOCK_SIZE, CUDA_GRAPHS, get_adapter_to_index, MODEL_ID, ) +from text_generation_server.layers.attention import Seqlen from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments @@ -46,7 +49,6 @@ from text_generation_server.utils.import_utils import ( tracer = trace.get_tracer(__name__) -BLOCK_SIZE: int = 16 # Will be set in init SLIDING_WINDOW: Optional[int] = None @@ -856,7 +858,23 @@ class FlashCausalLM(Model): else: x = BLOCK_SIZE // element_size - if SYSTEM == "ipex" and device == torch.device("cpu"): + if FLASH_DECODING: + self.kv_cache = [ + ( + torch.empty( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, BLOCK_SIZE, num_heads, head_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + elif SYSTEM == "ipex" and device == torch.device("cpu"): self.kv_cache = [ ( torch.empty( @@ -908,6 +926,7 @@ class FlashCausalLM(Model): "slots": slots, "input_lengths": input_lengths, } + input_lengths = Seqlen(input_lengths=input_lengths) graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph @@ -1067,6 +1086,7 @@ class FlashCausalLM(Model): # Dummy value, some models (starcoder2) don't accept `None`. input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) + input_lengths = Seqlen(input_lengths=input_lengths) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. self.model.forward( @@ -1153,6 +1173,7 @@ class FlashCausalLM(Model): cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: + input_lengths = Seqlen(input_lengths=input_lengths) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index bde86e6e..06035ccd 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,6 +5,12 @@ from typing import Dict MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli +FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} +BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 +if FLASH_DECODING: + logger.info("Using FLASH_DECODING") + + cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: try: @@ -15,8 +21,6 @@ if cuda_graphs is not None: ) else: cuda_graphs = None - - # sorting the cuda graphs in descending order helps reduce the # memory impact and results in less memory usage if cuda_graphs is not None: