From 47447ef017d6cdf205be795c7cf7f1086367aa24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 1 Aug 2024 17:03:28 +0200 Subject: [PATCH] Unify attention output handling (#2343) - Always return the hidden states. - Create the output tensor inside the `attention` and `paged_attention` functions. This removes the difference in how the output is handled between attention (output parameter) and paged attention (return value). It also removes the assumption that the attention implementation can write to a caller-provided output tensor (in preparation for FlashInfer). --- .../layers/attention/cuda.py | 17 ++++++++++------- .../layers/attention/ipex.py | 5 +++-- .../layers/attention/rocm.py | 9 ++++++--- .../custom_modeling/flash_cohere_modeling.py | 7 +------ .../custom_modeling/flash_dbrx_modeling.py | 7 +------ .../flash_deepseek_v2_modeling.py | 9 ++------- .../custom_modeling/flash_gemma2_modeling.py | 7 +------ .../custom_modeling/flash_gemma_modeling.py | 7 +------ .../custom_modeling/flash_gpt2_modeling.py | 7 +------ .../custom_modeling/flash_llama_modeling.py | 7 +------ .../custom_modeling/flash_mistral_modeling.py | 7 +------ .../custom_modeling/flash_mixtral_modeling.py | 7 +------ .../custom_modeling/flash_neox_modeling.py | 7 +------ .../custom_modeling/flash_phi_modeling.py | 7 +------ .../custom_modeling/flash_qwen2_modeling.py | 7 +------ .../models/custom_modeling/flash_rw_modeling.py | 14 ++------------ .../flash_santacoder_modeling.py | 7 +------ .../flash_starcoder2_modeling.py | 7 +------ 18 files changed, 36 insertions(+), 109 deletions(-) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index c0c4da4d..dff742dc 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -34,7 +34,6 @@ def reshape_and_cache( def paged_attention( - out: torch.Tensor, query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, @@ -85,7 +84,7 @@ def paged_attention( # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied. 
if softcap is None: softcap = 0.0 - out2 = flash_attn_2_cuda.varlen_fwd( + out = flash_attn_2_cuda.varlen_fwd( query, key_cache, value_cache, @@ -108,13 +107,15 @@ def paged_attention( False, # return softmax None, # generator ) - return out2[0] + return out[0] else: if softcap is not None: raise RuntimeError("Paged attention doesn't support softcapping") input_lengths = seqlen.input_lengths from vllm._C import ops + out = torch.empty_like(query) + use_v1 = max_s <= 8192 and ( max_num_partitions == 1 or num_seqs * num_heads > 512 ) @@ -200,13 +201,13 @@ except ImportError: SUPPORTS_WINDOWING = V2 + if V2: def attention( q, k, v, - out, cu_seqlens, max_s, softmax_scale, @@ -214,6 +215,7 @@ if V2: causal=True, softcap=0.0, ): + out = torch.empty_like(q) if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") return flash_attn_2_cuda.varlen_fwd( @@ -238,7 +240,7 @@ if V2: softcap, False, None, - ) + )[0] else: @@ -246,7 +248,6 @@ else: q, k, v, - out, cu_seqlens, max_s, softmax_scale, @@ -286,6 +287,8 @@ else: .reshape(original_shape[0], -1, original_shape[2]) ) + out = torch.empty_like(q) + return flash_attn_cuda.fwd( q, k, @@ -302,4 +305,4 @@ else: False, 0, None, - ) + )[0] diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 45a0a03e..e0956b26 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -10,13 +10,14 @@ def attention( q, k, v, - out, cu_seqlens, max_s, softmax_scale, window_size_left=-1, causal=True, ): + out = torch.empty_like(q) + # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. return ipex.llm.functional.varlen_attention( q, @@ -49,7 +50,6 @@ def reshape_and_cache( def paged_attention( - out: torch.Tensor, query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, @@ -59,6 +59,7 @@ def paged_attention( seqlen: Seqlen, max_s: int, ): + out = torch.empty_like(query) ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( out, query, diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 3ebe492a..69e64162 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -39,7 +39,6 @@ def reshape_and_cache( def paged_attention( - out: torch.Tensor, query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, @@ -72,6 +71,8 @@ def paged_attention( max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE input_lengths = input_lengths.input_lengths + out = torch.empty_like(query) + # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of @@ -174,7 +175,6 @@ if ENGINE == "ck": q, k, v, - out, cu_seqlens, max_s, softmax_scale, @@ -184,6 +184,8 @@ if ENGINE == "ck": if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") + out = torch.empty_like(q) + # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. 
return flash_attn_2_cuda.varlen_fwd( q, @@ -209,13 +211,14 @@ elif ENGINE == "triton": q, k, v, - out, cu_seqlens, max_s, softmax_scale, window_size_left=-1, causal=True, ): + out = torch.empty_like(q) + # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. output, _ = triton_attention( q, diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index b73dada6..e02a31d9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -291,17 +291,13 @@ class FlashCohereAttention(torch.nn.Module): reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, key, value, - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -309,7 +305,6 @@ class FlashCohereAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 3c301c8d..d3d1d1ef 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -330,17 +330,13 @@ class DbrxAttention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -348,7 +344,6 @@ class DbrxAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 3e84b4a8..0905d3c2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -358,25 +358,20 @@ class DeepseekV2Attention(torch.nn.Module): reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) - # Output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, key, value, - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, ) # Decode else: - paged_attention( - attn_output, + attn_output = paged_attention( query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 0dc5b9cf..de86f514 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -231,17 +231,13 @@ class FlashGemma2Attention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 
1], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -252,7 +248,6 @@ class FlashGemma2Attention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index dfe6510c..178efadb 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -225,17 +225,13 @@ class FlashGemmaAttention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -244,7 +240,6 @@ class FlashGemmaAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index a55a4af3..a19cff8c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -225,17 +225,13 @@ class FlashGPT2Attention(torch.nn.Module): reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, key, value, - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -243,7 +239,6 @@ class FlashGPT2Attention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index b55ddc23..9ea19a87 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -213,17 +213,13 @@ class FlashLlamaAttention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -231,7 +227,6 @@ class FlashLlamaAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 3d8f6bf4..dda53ff3 100644 --- 
a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -212,17 +212,13 @@ class MistralAttention(torch.nn.Module): kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -231,7 +227,6 @@ class MistralAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 7cdca553..85431c6c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -269,17 +269,13 @@ class MixtralAttention(torch.nn.Module): kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -288,7 +284,6 @@ class MixtralAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 623b164c..b1b03ad7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -158,17 +158,13 @@ class FlashNeoxAttention(torch.nn.Module): reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(qkv[:, 0]) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( qkv[:, 0], qkv[:, 1], qkv[:, 2], - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -176,7 +172,6 @@ class FlashNeoxAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, qkv[:, 0], kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index a1ce03b9..a9e18348 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -188,16 +188,12 @@ class FlashPhiAttention(torch.nn.Module): # Reshape key and value and cache reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -205,7 +201,6 @@ class FlashPhiAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( 
- attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index e357a287..865cc85d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -130,17 +130,13 @@ class Qwen2Attention(torch.nn.Module): kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -149,7 +145,6 @@ class Qwen2Attention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index d7cad480..708641e7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -201,17 +201,13 @@ class FlashRWAttention(torch.nn.Module): reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) - # output - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -219,7 +215,6 @@ class FlashRWAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], @@ -324,17 +319,13 @@ class FlashRWLargeAttention(torch.nn.Module): slots, ) - # output - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -342,7 +333,6 @@ class FlashRWLargeAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 2b939a10..c2676782 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -286,17 +286,13 @@ class FlashMQAttention(torch.nn.Module): key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots ) - # output - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -304,7 +300,6 @@ class FlashMQAttention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1], diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py 
b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index cfa891d4..e562eb89 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -235,17 +235,13 @@ class Starcoder2Attention(torch.nn.Module): kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots ) - # output tensor - attn_output = torch.empty_like(query) - # Prefill if cu_seqlen_prefill is not None: # flash attention - attention( + attn_output = attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), - attn_output, cu_seqlen_prefill, max_s, self.softmax_scale, @@ -254,7 +250,6 @@ class Starcoder2Attention(torch.nn.Module): # Decode else: attn_output = paged_attention( - attn_output, query, kv_cache[0], kv_cache[1],
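
Note (added for readers, not part of the patch): the sketch below illustrates the calling convention this change standardizes on. It is a minimal stand-in, assuming PyTorch 2.1+ and using torch.nn.functional.scaled_dot_product_attention in place of the flash_attn_2_cuda / ipex / vLLM kernels invoked by the real code; cu_seqlens and max_s are kept only to mirror the real signature and are unused here, and window_size_left/softcap are omitted.

    import torch


    def attention(q, k, v, cu_seqlens, max_s, softmax_scale, causal=True):
        # After this patch: the output tensor is allocated inside the
        # attention function instead of being passed in by the caller ...
        out = torch.empty_like(q)
        # ... filled with the attention result (SDPA here is only a
        # placeholder for the flash-attention kernel) ...
        out[:] = torch.nn.functional.scaled_dot_product_attention(
            q.transpose(0, 1),  # (heads, tokens, head_dim)
            k.transpose(0, 1),
            v.transpose(0, 1),
            scale=softmax_scale,
            is_causal=causal,
        ).transpose(0, 1)
        # ... and the hidden states are always returned.
        return out


    # Call sites no longer preallocate `attn_output = torch.empty_like(query)`
    # and pass it in; they simply bind the return value.
    query = key = value = torch.randn(8, 4, 64)  # (tokens, heads, head_dim)
    attn_output = attention(query, key, value, cu_seqlens=None, max_s=8,
                            softmax_scale=64**-0.5)
    print(attn_output.shape)  # torch.Size([8, 4, 64])

The decode path follows the same pattern: paged_attention now creates out = torch.empty_like(query) internally (see the cuda.py, ipex.py and rocm.py hunks) and returns it, which is why every model file above switches from passing attn_output into attention/paged_attention to binding attn_output = attention(...) or attn_output = paged_attention(...).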