diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index c0c4da4d..dff742dc 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -34,7 +34,6 @@ def reshape_and_cache(
 
 
 def paged_attention(
-    out: torch.Tensor,
     query: torch.Tensor,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
@@ -85,7 +84,7 @@ def paged_attention(
         # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
         if softcap is None:
             softcap = 0.0
-        out2 = flash_attn_2_cuda.varlen_fwd(
+        out = flash_attn_2_cuda.varlen_fwd(
             query,
             key_cache,
             value_cache,
@@ -108,13 +107,15 @@ def paged_attention(
             False,  # return softmax
             None,  # generator
         )
-        return out2[0]
+        return out[0]
     else:
         if softcap is not None:
             raise RuntimeError("Paged attention doesn't support softcapping")
         input_lengths = seqlen.input_lengths
         from vllm._C import ops
 
+        out = torch.empty_like(query)
+
         use_v1 = max_s <= 8192 and (
             max_num_partitions == 1 or num_seqs * num_heads > 512
         )
@@ -200,13 +201,13 @@ except ImportError:
 
 SUPPORTS_WINDOWING = V2
 
+
 if V2:
 
     def attention(
         q,
         k,
         v,
-        out,
         cu_seqlens,
         max_s,
         softmax_scale,
@@ -214,6 +215,7 @@ if V2:
         causal=True,
         softcap=0.0,
     ):
+        out = torch.empty_like(q)
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
         return flash_attn_2_cuda.varlen_fwd(
@@ -238,7 +240,7 @@ if V2:
             softcap,
             False,
             None,
-        )
+        )[0]
 
 else:
 
@@ -246,7 +248,6 @@ else:
         q,
         k,
         v,
-        out,
         cu_seqlens,
         max_s,
         softmax_scale,
@@ -286,6 +287,8 @@ else:
             .reshape(original_shape[0], -1, original_shape[2])
         )
 
+        out = torch.empty_like(q)
+
         return flash_attn_cuda.fwd(
             q,
             k,
@@ -302,4 +305,4 @@ else:
             False,
             0,
             None,
-        )
+        )[0]
diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py
index 45a0a03e..e0956b26 100644
--- a/server/text_generation_server/layers/attention/ipex.py
+++ b/server/text_generation_server/layers/attention/ipex.py
@@ -10,13 +10,14 @@ def attention(
     q,
     k,
     v,
-    out,
     cu_seqlens,
     max_s,
     softmax_scale,
     window_size_left=-1,
     causal=True,
 ):
+    out = torch.empty_like(q)
+
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     return ipex.llm.functional.varlen_attention(
         q,
@@ -49,7 +50,6 @@ def reshape_and_cache(
 
 
 def paged_attention(
-    out: torch.Tensor,
     query: torch.Tensor,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
@@ -59,6 +59,7 @@ def paged_attention(
     seqlen: Seqlen,
     max_s: int,
 ):
+    out = torch.empty_like(query)
     ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
         out,
         query,
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 3ebe492a..69e64162 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -39,7 +39,6 @@ def reshape_and_cache(
 
 
 def paged_attention(
-    out: torch.Tensor,
     query: torch.Tensor,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
@@ -72,6 +71,8 @@ def paged_attention(
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
     input_lengths = input_lengths.input_lengths
 
+    out = torch.empty_like(query)
+
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
@@ -174,7 +175,6 @@ if ENGINE == "ck":
         q,
         k,
         v,
-        out,
         cu_seqlens,
         max_s,
         softmax_scale,
@@ -184,6 +184,8 @@ if ENGINE == "ck":
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
 
+        out = torch.empty_like(q)
+
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             q,
@@ -209,13 +211,14 @@ elif ENGINE == "triton":
         q,
         k,
         v,
-        out,
         cu_seqlens,
         max_s,
         softmax_scale,
         window_size_left=-1,
         causal=True,
     ):
+        out = torch.empty_like(q)
+
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         output, _ = triton_attention(
             q,
diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index b73dada6..e02a31d9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -291,17 +291,13 @@ class FlashCohereAttention(torch.nn.Module):
 
         reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
 
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 key,
                 value,
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -309,7 +305,6 @@ class FlashCohereAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 3c301c8d..d3d1d1ef 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -330,17 +330,13 @@ class DbrxAttention(torch.nn.Module):
 
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
 
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -348,7 +344,6 @@ class DbrxAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
index 3e84b4a8..0905d3c2 100644
--- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
@@ -358,25 +358,20 @@ class DeepseekV2Attention(torch.nn.Module):
 
         reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
 
-        # Output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 key,
                 value,
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
             )
         # Decode
         else:
-            paged_attention(
-                attn_output,
+            attn_output = paged_attention(
                 query,
                 kv_cache[0],
                 kv_cache[1],
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
index 0dc5b9cf..de86f514 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -231,17 +231,13 @@ class FlashGemma2Attention(torch.nn.Module):
 
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
 
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -252,7 +248,6 @@ class FlashGemma2Attention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index dfe6510c..178efadb 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -225,17 +225,13 @@ class FlashGemmaAttention(torch.nn.Module):
 
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
 
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -244,7 +240,6 @@ class FlashGemmaAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index a55a4af3..a19cff8c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -225,17 +225,13 @@ class FlashGPT2Attention(torch.nn.Module):
 
         reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
 
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 key,
                 value,
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -243,7 +239,6 @@ class FlashGPT2Attention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index b55ddc23..9ea19a87 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -213,17 +213,13 @@ class FlashLlamaAttention(torch.nn.Module):
 
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
 
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -231,7 +227,6 @@ class FlashLlamaAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 3d8f6bf4..dda53ff3 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -212,17 +212,13 @@ class MistralAttention(torch.nn.Module):
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -231,7 +227,6 @@ class MistralAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 7cdca553..85431c6c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -269,17 +269,13 @@ class MixtralAttention(torch.nn.Module):
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -288,7 +284,6 @@ class MixtralAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 623b164c..b1b03ad7 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -158,17 +158,13 @@ class FlashNeoxAttention(torch.nn.Module):
 
         reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
 
-        # output tensor
-        attn_output = torch.empty_like(qkv[:, 0])
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -176,7 +172,6 @@ class FlashNeoxAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 qkv[:, 0],
                 kv_cache[0],
                 kv_cache[1],
diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index a1ce03b9..a9e18348 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -188,16 +188,12 @@ class FlashPhiAttention(torch.nn.Module):
         # Reshape key and value and cache
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
 
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -205,7 +201,6 @@ class FlashPhiAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index e357a287..865cc85d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -130,17 +130,13 @@ class Qwen2Attention(torch.nn.Module):
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -149,7 +145,6 @@ class Qwen2Attention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index d7cad480..708641e7 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -201,17 +201,13 @@ class FlashRWAttention(torch.nn.Module):
 
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
 
-        # output
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -219,7 +215,6 @@ class FlashRWAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -324,17 +319,13 @@ class FlashRWLargeAttention(torch.nn.Module):
             slots,
         )
 
-        # output
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -342,7 +333,6 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 2b939a10..c2676782 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -286,17 +286,13 @@ class FlashMQAttention(torch.nn.Module):
             key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
-        # output
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -304,7 +300,6 @@ class FlashMQAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index cfa891d4..e562eb89 100644
--- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -235,17 +235,13 @@ class Starcoder2Attention(torch.nn.Module):
            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -254,7 +250,6 @@ class Starcoder2Attention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
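
Note for reviewers: across every backend and model above, this diff moves output allocation out of the call sites and into the attention kernels, which now allocate and return their own result. A minimal sketch of the new calling convention follows; the `attention` function below is a stub standing in for the real kernels, not the production implementation.

import torch


def attention(q, k, v, cu_seqlens, max_s, softmax_scale):
    # The kernel now owns its output buffer and returns it.
    out = torch.empty_like(q)
    # ... the real kernels write the attention result into `out` here ...
    return out


# Before this diff (caller pre-allocates the buffer and passes it in):
#     attn_output = torch.empty_like(query)
#     attention(query, key, value, attn_output, cu_seqlen_prefill, max_s, scale)
#
# After this diff (caller simply binds the returned tensor):
#     attn_output = attention(query, key, value, cu_seqlen_prefill, max_s, scale)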