From 59ea38cbca3f1996f75bfd9b0ba9579a30a5558f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 17 Oct 2024 10:42:52 +0200 Subject: [PATCH] Simplify the `attention` function (#2609) * Simplify the `attention` function - Use one definition rather than multiple. - Add `key`/`value` arguments, so that we don't need the `PREFILL_IN_KVCACHE` constant. - Make it kwargs-only (to avoid mixing up the various `Tensor` args). * Fixup flashinfer support --- .../layers/attention/__init__.py | 4 - .../layers/attention/cuda.py | 352 ++++++------------ .../layers/attention/ipex.py | 38 +- .../layers/attention/kv_cache.py | 3 +- .../layers/attention/rocm.py | 89 ++--- .../custom_modeling/flash_cohere_modeling.py | 17 +- .../custom_modeling/flash_dbrx_modeling.py | 17 +- .../flash_deepseek_v2_modeling.py | 17 +- .../custom_modeling/flash_gemma2_modeling.py | 18 +- .../custom_modeling/flash_gemma_modeling.py | 17 +- .../custom_modeling/flash_gpt2_modeling.py | 17 +- .../custom_modeling/flash_gptj_modeling.py | 17 +- .../custom_modeling/flash_llama_modeling.py | 18 +- .../custom_modeling/flash_mistral_modeling.py | 17 +- .../custom_modeling/flash_mixtral_modeling.py | 17 +- .../custom_modeling/flash_neox_modeling.py | 17 +- .../custom_modeling/flash_phi_modeling.py | 17 +- .../custom_modeling/flash_qwen2_modeling.py | 17 +- .../custom_modeling/flash_rw_modeling.py | 33 +- .../flash_santacoder_modeling.py | 17 +- .../flash_starcoder2_modeling.py | 17 +- 21 files changed, 313 insertions(+), 463 deletions(-) diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index cc7f0caa..b7ca36f1 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -8,7 +8,6 @@ if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") if SYSTEM == "cuda": from .cuda import ( - PREFILL_IN_KV_CACHE, SUPPORTS_WINDOWING, attention, paged_attention, @@ -16,7 +15,6 @@ if SYSTEM == "cuda": ) elif SYSTEM == "rocm": from .rocm import ( - PREFILL_IN_KV_CACHE, SUPPORTS_WINDOWING, attention, paged_attention, @@ -24,7 +22,6 @@ elif SYSTEM == "rocm": ) elif SYSTEM == "ipex": from .ipex import ( - PREFILL_IN_KV_CACHE, SUPPORTS_WINDOWING, attention, paged_attention, @@ -40,7 +37,6 @@ __all__ = [ "attention", "paged_attention", "reshape_and_cache", - "PREFILL_IN_KV_CACHE", "SUPPORTS_WINDOWING", "KVCache", "Seqlen", diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index 265a8ae4..5846bfe5 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -1,4 +1,5 @@ import torch +from text_generation_server.layers.attention.kv_cache import KVCache from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import ( ATTENTION, @@ -38,8 +39,7 @@ def reshape_and_cache( def paged_attention( query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, + kv_cache: KVCache, kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, @@ -80,7 +80,7 @@ def paged_attention( return decode_state.get().forward( query.contiguous(), - paged_kv_cache=(key_cache, value_cache), + paged_kv_cache=(kv_cache.key, kv_cache.value), logits_soft_cap=softcap, sm_scale=softmax_scale, ) @@ -98,8 +98,8 @@ def paged_attention( softcap = 0.0 out = flash_attn_2_cuda.varlen_fwd( query, - key_cache, - value_cache, + kv_cache.key, + kv_cache.value, None, seqlen.cu_seqlen_q, seqlen.cu_seqlen_k, @@ -135,8 +135,8 @@ def paged_attention( ops.paged_attention_v1( out, query, - key_cache, - value_cache, + kv_cache.key, + kv_cache.value, kv_head_mapping, softmax_scale, block_tables, @@ -168,8 +168,8 @@ def paged_attention( max_logits, tmp_output, query, - key_cache, - value_cache, + kv_cache.key, + kv_cache.value, kv_head_mapping, softmax_scale, block_tables, @@ -216,263 +216,133 @@ except ImportError: ) from e +if ATTENTION == "flashdecoding" and not V2: + raise ValueError("Flash decoding requires Flash Attention V2") + SUPPORTS_WINDOWING = V2 -if ATTENTION == "flashinfer": - def attention( - q: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, - softcap=0.0, - ): +def attention( + *, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: KVCache, + seqlen: Seqlen, + block_tables: torch.Tensor, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap: Optional[float] = None, +): + if ATTENTION == "flashinfer": from text_generation_server.layers.attention.flashinfer import ( prefill_with_paged_kv_state, ) + if softcap is None: + softcap = 0.0 + return prefill_with_paged_kv_state.get().forward( - q.contiguous(), + query.contiguous(), causal=causal, - paged_kv_cache=(key_cache, value_cache), + paged_kv_cache=(kv_cache.key, kv_cache.value), logits_soft_cap=softcap, sm_scale=softmax_scale, window_left=window_size_left, ) -elif ATTENTION == "flashdecoding": - if V2: + # If we are using flashdecoding or paged, we always use flash-attn for + # the prefill. We have to branch on whether we use flash-attn v1 or v2. + elif V2: + out = torch.empty_like(query) + if window_size_left <= 0 and window_size_left != -1: + raise ValueError("`window_size_left` must be > 0 or -1") - def attention( - q, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, + if softcap is None: + softcap = 0.0 + + return flash_attn_2_cuda.varlen_fwd( + query, + # flashdecoding: pass the KV caches, paged: pass the KV. + kv_cache.key if ATTENTION == "flashdecoding" else key, + kv_cache.value if ATTENTION == "flashdecoding" else value, + out, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, + None, + None, + block_tables if ATTENTION == "flashdecoding" else None, + None, + seqlen.max_q, + seqlen.max_k, + 0.0, softmax_scale, - window_size_left=-1, - causal=True, - softcap=0.0, - ): - out = torch.empty_like(q) - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - return flash_attn_2_cuda.varlen_fwd( - q, - key_cache, - value_cache, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_k, - None, - None, - block_tables, - None, - seqlen.max_q, - seqlen.max_k, - 0.0, - softmax_scale, - False, - causal, - window_size_left, - 0, - softcap, - False, - None, - )[0] + False, + causal, + window_size_left, + 0, + softcap, + False, + None, + )[0] else: - - def attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale: float, - window_size_left: int = -1, - causal: bool = True, - softcap=None, - ): - if window_size_left != -1: - raise NotImplementedError( - "window_size_left is only available with flash attn v2" - ) - if softcap is not None: - raise NotImplementedError( - "softcap is only available with flash attn v2" - ) - - # Flash attention v1 requires q, k and v to have the same number of heads - if k.shape[1] != q.shape[1]: - # MQA expand - if k.shape[1] == 1: - k = k.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = k.shape - k = ( - k.unsqueeze(2) - .expand(-1, -1, q.shape[1] // k.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - if v.shape[1] != q.shape[1]: - # MQA expand - if v.shape[1] == 1: - v = v.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = v.shape - v = ( - v.unsqueeze(2) - .expand(-1, -1, q.shape[1] // v.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - - out = torch.empty_like(q) - flash_attn_cuda.fwd( - q, - k, - v, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - seqlen.max_q, - seqlen.max_k, - 0.0, - softmax_scale, - False, - causal, - False, - 0, - None, + if window_size_left != -1: + raise NotImplementedError( + "window_size_left is only available with flash attn v2" ) - return out + if softcap is not None: + raise NotImplementedError("softcap is not available in flash attn v1") -elif ATTENTION == "paged": - if V2: + # Flash attention v1 requires q, k and v to have the same number of heads + if key.shape[1] != query.shape[1]: + # MQA expand + if key.shape[1] == 1: + key = key.expand(-1, query.shape[1], -1) + # Grouped attention reshape + else: + original_shape = key.shape + key = ( + key.unsqueeze(2) + .expand(-1, -1, query.shape[1] // key.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + if value.shape[1] != query.shape[1]: + # MQA expand + if value.shape[1] == 1: + value = value.expand(-1, query.shape[1], -1) + # Grouped attention reshape + else: + original_shape = value.shape + value = ( + value.unsqueeze(2) + .expand(-1, -1, query.shape[1] // value.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) - def attention( - q, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, + out = torch.empty_like(query) + flash_attn_cuda.fwd( + query, + key, + value, + out, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_q, + seqlen.max_q, + seqlen.max_k, + 0.0, softmax_scale, - window_size_left=-1, - causal=True, - softcap=0.0, - ): - out = torch.empty_like(q) - if window_size_left <= 0 and window_size_left != -1: - raise ValueError("`window_size_left` must be > 0 or -1") - return flash_attn_2_cuda.varlen_fwd( - q, - key_cache, - value_cache, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_k, - None, - None, - None, # block_tables, - None, - seqlen.max_q, - seqlen.max_k, - 0.0, - softmax_scale, - False, - causal, - window_size_left, - 0, - softcap, - False, - None, - )[0] + False, + causal, + False, + 0, + None, + ) + return out - else: - - def attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale: float, - window_size_left: int = -1, - causal: bool = True, - softcap=None, - ): - if window_size_left != -1: - raise NotImplementedError( - "window_size_left is only available with flash attn v2" - ) - if softcap is not None: - raise NotImplementedError( - "softcap is only available with flash attn v2" - ) - - # Flash attention v1 requires q, k and v to have the same number of heads - if k.shape[1] != q.shape[1]: - # MQA expand - if k.shape[1] == 1: - k = k.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = k.shape - k = ( - k.unsqueeze(2) - .expand(-1, -1, q.shape[1] // k.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - if v.shape[1] != q.shape[1]: - # MQA expand - if v.shape[1] == 1: - v = v.expand(-1, q.shape[1], -1) - # Grouped attention reshape - else: - original_shape = v.shape - v = ( - v.unsqueeze(2) - .expand(-1, -1, q.shape[1] // v.shape[1], -1) - .reshape(original_shape[0], -1, original_shape[2]) - ) - - out = torch.empty_like(q) - flash_attn_cuda.fwd( - q, - k, - v, - out, - seqlen.cu_seqlen_q, - seqlen.cu_seqlen_q, - seqlen.max_q, - seqlen.max_k, - 0.0, - softmax_scale, - False, - causal, - False, - 0, - None, - ) - return out - -else: - raise RuntimeError(f"Unknwon attention {ATTENTION}") - - -# Prefill in the cache with every kind of attention, unless we -# have a configuration that requires flash-attention v1, which -# does not support block tables. -PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2) __all__ = [ - "PREFILL_IN_KV_CACHE", "SUPPORTS_WINDOWING", "attention", "paged_attention", diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py index 17f6a7f1..5d159796 100644 --- a/server/text_generation_server/layers/attention/ipex.py +++ b/server/text_generation_server/layers/attention/ipex.py @@ -1,31 +1,36 @@ import intel_extension_for_pytorch as ipex import torch +from text_generation_server.layers.attention.kv_cache import KVCache from text_generation_server.models.flash_causal_lm import BLOCK_SIZE from text_generation_server.layers.attention import Seqlen from typing import Optional SUPPORTS_WINDOWING = False -PREFILL_IN_KV_CACHE = False def attention( - q: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, + *, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: KVCache, seqlen: Seqlen, block_tables: torch.Tensor, - softmax_scale, - window_size_left=-1, - causal=True, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, softcap: Optional[float] = None, ): - out = torch.empty_like(q) + if softcap is not None: + raise NotImplementedError("softcap is not available in IPEX") + + out = torch.empty_like(query) # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. ipex.llm.functional.varlen_attention( - q.contiguous() if q.device.type == "xpu" else q, - key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache, - value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache, + query.contiguous() if query.device.type == "xpu" else query, + key.contiguous() if key.device.type == "xpu" else key, + value.contiguous() if value.device.type == "xpu" else value, out, seqlen.cu_seqlen_q, seqlen.cu_seqlen_q, @@ -56,8 +61,7 @@ def reshape_and_cache( def paged_attention( query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, + kv_cache: KVCache, kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, @@ -65,13 +69,16 @@ def paged_attention( max_s: int, softcap: Optional[float] = None, ): + if softcap is not None: + raise NotImplementedError("softcap is not available in IPEX") + out = torch.empty_like(query) input_lengths = seqlen.input_lengths + seqlen.cache_lengths ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( out, query, - key_cache, - value_cache, + kv_cache.key, + kv_cache.value, kv_head_mapping, softmax_scale, block_tables, @@ -84,7 +91,6 @@ def paged_attention( __all__ = [ - "PREFILL_IN_KV_CACHE", "SUPPORTS_WINDOWING", "attention", "paged_attention", diff --git a/server/text_generation_server/layers/attention/kv_cache.py b/server/text_generation_server/layers/attention/kv_cache.py index 7f1dd370..e6091a5f 100644 --- a/server/text_generation_server/layers/attention/kv_cache.py +++ b/server/text_generation_server/layers/attention/kv_cache.py @@ -3,7 +3,6 @@ from typing import Tuple import torch from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.layers.attention import reshape_and_cache class KVCache: @@ -116,4 +115,6 @@ class KVCache: key_cache.view(-1, shape[-2], shape[-1])[slots] = key value_cache.view(-1, shape[-2], shape[-1])[slots] = value else: + from text_generation_server.layers.attention import reshape_and_cache + reshape_and_cache(key, value, key_cache, value_cache, slots) diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 27e7638a..986b16e8 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -1,6 +1,7 @@ import os from typing import Optional import torch +from text_generation_server.layers.attention.kv_cache import KVCache from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import ATTENTION from text_generation_server.layers.attention import Seqlen @@ -16,8 +17,6 @@ _PARTITION_SIZE_CUSTOM = 256 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"} ENGINE = "triton" if use_triton else "ck" -PREFILL_IN_KV_CACHE = False - use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0" try: if use_rocm_custom_paged_attn: @@ -54,8 +53,7 @@ def reshape_and_cache( def paged_attention( query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, + kv_cache: KVCache, kv_head_mapping: torch.Tensor, softmax_scale: float, block_tables: torch.Tensor, @@ -84,10 +82,10 @@ def paged_attention( raise RuntimeError("Paged attention doesn't support softcapping") # value_cache => [num_blocks, num_heads, head_size, block_size] - block_size = value_cache.shape[3] + block_size = kv_cache.value.shape[3] num_seqs, num_heads, head_size = query.shape - num_kv_heads = key_cache.shape[1] + num_kv_heads = kv_cache.key.shape[1] gqa_ratio = num_heads // num_kv_heads use_custom = ( use_rocm_custom_paged_attn @@ -124,8 +122,8 @@ def paged_attention( ops.paged_attention_v1( out, query, - key_cache, - value_cache, + kv_cache.key, + kv_cache.value, kv_head_mapping, softmax_scale, block_tables, @@ -158,8 +156,8 @@ def paged_attention( max_logits, tmp_output, query, - key_cache, - value_cache, + kv_cache.key, + kv_cache.value, kv_head_mapping, softmax_scale, block_tables, @@ -177,8 +175,8 @@ def paged_attention( max_logits, tmp_output, query, - key_cache, - value_cache, + kv_cache.key, + kv_cache.value, num_kv_heads, softmax_scale, block_tables, @@ -227,29 +225,35 @@ if ENGINE != "triton": SUPPORTS_WINDOWING = False -if ENGINE == "ck": - def attention( - q, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale: float, - window_size_left: int = -1, - causal: bool = True, - softcap: float = 0.0, - ): + +def attention( + *, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: KVCache, + seqlen: Seqlen, + block_tables: torch.Tensor, + softmax_scale: float, + window_size_left: int = -1, + causal: bool = True, + softcap: Optional[float] = None, +): + if ENGINE == "ck": if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") - out = torch.empty_like(q) + out = torch.empty_like(query) + + if softcap is None: + softcap = 0.0 # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. return flash_attn_2_cuda.varlen_fwd( - q, - key_cache, - value_cache, + query, + key, + value, out, seqlen.cu_seqlen_q, seqlen.cu_seqlen_q, @@ -270,30 +274,19 @@ if ENGINE == "ck": None, )[0] -elif ENGINE == "triton": - from .flash_attn_triton import triton_attention + elif ENGINE == "triton": + from .flash_attn_triton import triton_attention - def attention( - q, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - seqlen: Seqlen, - block_tables: torch.Tensor, - softmax_scale: float, - window_size_left: int = -1, - causal: bool = True, - softcap: Optional[float] = None, - ): if softcap is not None: raise NotImplementedError("softcap is only available with CK flash attn") - out = torch.empty_like(q) + out = torch.empty_like(query) # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. output, _ = triton_attention( - q, - key_cache, - value_cache, + query, + key, + value, out, seqlen.cu_seqlen_q, seqlen.cu_seqlen_q, @@ -304,11 +297,11 @@ elif ENGINE == "triton": ) return output -else: - raise RuntimeError(f"Unknown attention engine {ENGINE}") + else: + raise RuntimeError(f"Unknown attention engine {ENGINE}") + __all__ = [ - "PREFILL_IN_KV_CACHE", "SUPPORTS_WINDOWING", "attention", "paged_attention", diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index d0425fec..4eee5c20 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -38,7 +38,6 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -296,19 +295,19 @@ class FlashCohereAttention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else key, - kv_cache.value if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index b2b0cecb..4ee67741 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -29,7 +29,6 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, - PREFILL_IN_KV_CACHE, ) from text_generation_server.layers import ( FastLinear, @@ -335,19 +334,19 @@ class DbrxAttention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index af77af8e..97b3ea96 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -34,7 +34,6 @@ from text_generation_server.layers.attention import ( attention, paged_attention, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale @@ -326,19 +325,19 @@ class DeepseekV2Attention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else key, - kv_cache.value if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 03b9b2a0..c962a2af 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -39,7 +39,6 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -258,13 +257,13 @@ class FlashGemma2Attention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, - causal=self.causal, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, window_size_left=self.window_size, softcap=self.softcap, ) @@ -272,8 +271,7 @@ class FlashGemma2Attention(torch.nn.Module): else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index f3c46901..b127f284 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -29,7 +29,6 @@ from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, - PREFILL_IN_KV_CACHE, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -229,20 +228,20 @@ class FlashGemmaAttention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, causal=self.causal, ) # Decode else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 94a8898d..2d005734 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -24,7 +24,6 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention import ( paged_attention, attention, @@ -229,19 +228,19 @@ class FlashGPT2Attention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else key, - kv_cache.value if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index f0a1270e..2eef1ded 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -37,7 +37,6 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) @@ -191,19 +190,19 @@ class FlashGPTJAttention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else key, - kv_cache.value if PREFILL_IN_KV_CACHE else value, - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key, + value=value, + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index fbe45d79..5c820bb6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -27,7 +27,7 @@ import torch.distributed from torch import nn from transformers.activations import ACT2FN -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE, KVCache +from text_generation_server.layers.attention import KVCache from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( @@ -227,19 +227,19 @@ class FlashLlamaAttention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 8974035e..7bad429c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -40,7 +40,6 @@ from text_generation_server.layers import ( TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -215,20 +214,20 @@ class MistralAttention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index e7bc8320..712b7bc4 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -38,7 +38,6 @@ from text_generation_server.layers.attention import ( attention, paged_attention, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.rotary import PositionRotaryEmbedding @@ -263,20 +262,20 @@ class MixtralAttention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index bcbea442..2ce69d8e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -38,7 +38,6 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -170,19 +169,19 @@ class FlashNeoxAttention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - qkv[:, 0], - kv_cache.key if PREFILL_IN_KV_CACHE else qkv[:, 1], - kv_cache.value if PREFILL_IN_KV_CACHE else qkv[:, 2], - seqlen, - block_tables, - self.softmax_scale, + query=qkv[:, 0], + key=qkv[:, 1], + value=qkv[:, 2], + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( qkv[:, 0], - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index cb7b6ee2..62d524c9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -18,7 +18,6 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import ( FastLayerNorm, ) @@ -192,19 +191,19 @@ class FlashPhiAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 8185885f..905dd98f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -16,7 +16,6 @@ from text_generation_server.layers import ( TensorParallelEmbedding, SpeculativeHead, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, @@ -133,20 +132,20 @@ class Qwen2Attention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index dac8ecf9..8085ff89 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -12,7 +12,6 @@ from text_generation_server.layers import ( TensorParallelRowLinear, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import FastLayerNorm from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.attention import ( @@ -205,19 +204,19 @@ class FlashRWAttention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0], - kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, 0], + value=kv[:, 1], + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, @@ -319,19 +318,19 @@ class FlashRWLargeAttention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(), - kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(), - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv[:, :, 0], + value=kv[:, :, 1], + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 5972d436..52119b64 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -17,7 +17,6 @@ from text_generation_server.layers import ( TensorParallelEmbedding, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.layernorm import ( FastLayerNorm, @@ -289,19 +288,19 @@ class FlashMQAttention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else key_value[:, 0], - kv_cache.value if PREFILL_IN_KV_CACHE else key_value[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=key_value[:, 0], + value=key_value[:, 1], + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 037238b8..fe339aee 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -38,7 +38,6 @@ from text_generation_server.layers import ( SpeculativeHead, get_linear, ) -from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.layernorm import ( FastLayerNorm, FastRMSNorm, @@ -238,20 +237,20 @@ class Starcoder2Attention(torch.nn.Module): if cu_seqlen_prefill is not None: # flash attention attn_output = attention( - query, - kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], - kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], - seqlen, - block_tables, - self.softmax_scale, + query=query, + key=kv_to_cache[:, 0], + value=kv_to_cache[:, 1], + kv_cache=kv_cache, + seqlen=seqlen, + block_tables=block_tables, + softmax_scale=self.softmax_scale, window_size_left=self.max_past, ) # Decode else: attn_output = paged_attention( query, - kv_cache.key, - kv_cache.value, + kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables,