diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 4e0202a8..a6911df8 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -28,7 +28,6 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -155,72 +154,34 @@ class FlashLlamaAttention(torch.nn.Module):
         # output tensor
         attn_output = torch.empty_like(query)
 
-        if FLASH_DECODING:
-            # Prefill
-            kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 0
-            ]
-            kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 1
-            ]
+        paged_attention.reshape_and_cache(
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+        )
 
-            if cu_seqlen_prefill is not None:
-                # flash attention
-                flash_attn.attention(
-                    query,
-                    # torch.select(kv, dim=1, index=0),
-                    # torch.select(kv, dim=1, index=1),
-                    kv_cache[0],
-                    kv_cache[1],
-                    attn_output,
-                    cu_seqlen_prefill,
-                    block_tables,
-                    max_s,
-                    self.softmax_scale,
-                )
-            # Decode
-            else:
-                paged_attention.attention(
-                    attn_output,
-                    query,
-                    kv_cache[0],
-                    kv_cache[1],
-                    self.kv_head_mapping,
-                    self.softmax_scale,
-                    block_tables,
-                    input_lengths,
-                    max_s,
-                )
-        else:
-            paged_attention.reshape_and_cache(
-                kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+        if cu_seqlen_prefill is not None:
+            # flash attention
+            flash_attn.attention(
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
+                attn_output,
+                cu_seqlen_prefill,
+                max_s,
+                self.softmax_scale,
+            )
+        # Decode
+        else:
+            paged_attention.attention(
+                attn_output,
+                query,
+                kv_cache[0],
+                kv_cache[1],
+                self.kv_head_mapping,
+                self.softmax_scale,
+                block_tables,
+                input_lengths,
+                max_s,
             )
-            # Prefill
-            if cu_seqlen_prefill is not None:
-                # flash attention
-                flash_attn.attention(
-                    query,
-                    torch.select(kv, dim=1, index=0),
-                    torch.select(kv, dim=1, index=1),
-                    attn_output,
-                    cu_seqlen_prefill,
-                    None,
-                    max_s,
-                    self.softmax_scale,
-                )
-            # Decode
-            else:
-                paged_attention.attention(
-                    attn_output,
-                    query,
-                    kv_cache[0],
-                    kv_cache[1],
-                    self.kv_head_mapping,
-                    self.softmax_scale,
-                    block_tables,
-                    input_lengths,
-                    max_s,
-                )
 
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index d1cd4418..cc51fe29 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -27,7 +27,6 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -217,77 +216,39 @@ class MistralAttention(torch.nn.Module):
 
         attn_output = torch.empty_like(query)
 
-        if FLASH_DECODING:
-            # Prefill
-            kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 0
-            ]
-            kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 1
-            ]
-
-            if cu_seqlen_prefill is not None:
-                # flash attention
-                flash_attn.attention(
-                    query,
-                    # torch.select(kv, dim=1, index=0),
-                    # torch.select(kv, dim=1, index=1),
-                    kv_cache[0],
-                    kv_cache[1],
-                    attn_output,
-                    cu_seqlen_prefill,
-                    block_tables,
-                    max_s,
-                    self.softmax_scale,
-                )
-            # Decode
-            else:
-                paged_attention.attention(
-                    attn_output,
-                    query,
-                    kv_cache[0],
-                    kv_cache[1],
-                    self.kv_head_mapping,
-                    self.softmax_scale,
-                    block_tables,
-                    input_lengths,
-                    max_s,
-                )
+        if prefill_cache_indices is not None:
+            kv_to_cache = kv[prefill_cache_indices]
         else:
-            if prefill_cache_indices is not None:
-                kv_to_cache = kv[prefill_cache_indices]
-            else:
-                kv_to_cache = kv
+            kv_to_cache = kv
 
-            paged_attention.reshape_and_cache(
-                kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
+        paged_attention.reshape_and_cache(
+            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
+        )
+        # Prefill
+        if cu_seqlen_prefill is not None:
+            # flash attention
+            flash_attn.attention(
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
+                attn_output,
+                cu_seqlen_prefill,
+                max_s,
+                self.softmax_scale,
+            )
+        # Decode
+        else:
+            paged_attention.attention(
+                attn_output,
+                query,
+                kv_cache[0],
+                kv_cache[1],
+                self.kv_head_mapping,
+                self.softmax_scale,
+                block_tables,
+                input_lengths,
+                max_s,
             )
-            # Prefill
-            if cu_seqlen_prefill is not None:
-                # flash attention
-                flash_attn.attention(
-                    query,
-                    torch.select(kv, dim=1, index=0),
-                    torch.select(kv, dim=1, index=1),
-                    attn_output,
-                    cu_seqlen_prefill,
-                    None,
-                    max_s,
-                    self.softmax_scale,
-                )
-            # Decode
-            else:
-                paged_attention.attention(
-                    attn_output,
-                    query,
-                    kv_cache[0],
-                    kv_cache[1],
-                    self.kv_head_mapping,
-                    self.softmax_scale,
-                    block_tables,
-                    input_lengths,
-                    max_s,
-                )
 
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
 
diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py
index df3b61af..4f5cf10b 100644
--- a/server/text_generation_server/utils/flash_attn.py
+++ b/server/text_generation_server/utils/flash_attn.py
@@ -134,7 +134,6 @@ elif HAS_FLASH_ATTN_V2_CUDA:
         v,
         out,
         cu_seqlens,
-        block_tables,
         max_s,
         softmax_scale,
         window_size_left=-1,
@@ -150,7 +149,7 @@ elif HAS_FLASH_ATTN_V2_CUDA:
             cu_seqlens,
             cu_seqlens,
             None,
-            block_tables,
+            None,
             None,
             max_s,
             max_s,
diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py
index ea5c4558..f8af5dc4 100644
--- a/server/text_generation_server/utils/paged_attention.py
+++ b/server/text_generation_server/utils/paged_attention.py
@@ -28,9 +28,14 @@ def reshape_and_cache(
             key, value, key_cache, value_cache, slots
         )
     else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
+        if FLASH_DECODING:
+            shape = key_cache.shape
+            key_cache.view(-1, shape[-2], shape[-1])[slots] = key
+            value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+        else:
+            cache_ops.reshape_and_cache(
+                key, value, key_cache, value_cache, slots, "auto", 1.0
+            )
 
 
 def attention(
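Note (not part of the diff): the `FLASH_DECODING` branch that this change moves into `paged_attention.reshape_and_cache` writes keys and values into a flattened view of the paged KV cache, indexed directly by slot, instead of calling the `cache_ops` kernel. A minimal standalone sketch of that write pattern follows; the cache layout and sizes are illustrative assumptions, not values taken from the repository.

    # Illustration only: mimic the FLASH_DECODING cache write shown in the diff.
    # Shapes below are assumptions chosen for the example.
    import torch

    num_blocks, block_size, num_heads, head_size = 4, 16, 8, 64
    key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
    value_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)

    num_tokens = 5
    key = torch.randn(num_tokens, num_heads, head_size)
    value = torch.randn(num_tokens, num_heads, head_size)
    # Flat slot index = block_id * block_size + offset_in_block.
    slots = torch.tensor([0, 1, 2, 17, 18])

    # Same pattern as the new branch in reshape_and_cache: view the cache as
    # (num_slots, num_heads, head_size) and scatter rows by slot index.
    shape = key_cache.shape
    key_cache.view(-1, shape[-2], shape[-1])[slots] = key
    value_cache.view(-1, shape[-2], shape[-1])[slots] = value

    # Slot 17 lands in block 1, offset 1.
    assert torch.equal(key_cache[1, 1], key[3])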