Simplify the `attention` function (#2609)

* Simplify the `attention` function

- Use one definition rather than multiple.
- Add `key`/`value` arguments, so that we don't need the
  `PREFILL_IN_KV_CACHE` constant (see the sketch below this list).
- Make it kwargs-only (to avoid mixing up the various `Tensor` args).

* Fixup flashinfer support
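
For reference, a minimal sketch of the prefill call shape after this change. The keyword names (`query`, `key`, `value`, `kv_cache`, `seqlen`, `block_tables`, `softmax_scale`) come from this diff; the surrounding variables (`kv`, `self.softmax_scale`, and so on) are illustrative only:

    # Kwargs-only call: the backend decides internally whether the prefill
    # reads K/V from `kv_cache` or from the freshly projected `key`/`value`,
    # so callers no longer branch on PREFILL_IN_KV_CACHE.
    attn_output = attention(
        query=query,
        key=kv[:, 0],
        value=kv[:, 1],
        kv_cache=kv_cache,
        seqlen=seqlen,
        block_tables=block_tables,
        softmax_scale=self.softmax_scale,
    )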
Daniël de Kok 2024-10-17 10:42:52 +02:00 committed by GitHub
parent 5bbe1ce028
commit 59ea38cbca
21 changed files with 313 additions and 463 deletions

View File

@@ -8,7 +8,6 @@ if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 if SYSTEM == "cuda":
     from .cuda import (
-        PREFILL_IN_KV_CACHE,
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
@@ -16,7 +15,6 @@ if SYSTEM == "cuda":
     )
 elif SYSTEM == "rocm":
     from .rocm import (
-        PREFILL_IN_KV_CACHE,
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
@@ -24,7 +22,6 @@ elif SYSTEM == "rocm":
     )
 elif SYSTEM == "ipex":
     from .ipex import (
-        PREFILL_IN_KV_CACHE,
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
@@ -40,7 +37,6 @@ __all__ = [
     "attention",
     "paged_attention",
     "reshape_and_cache",
-    "PREFILL_IN_KV_CACHE",
     "SUPPORTS_WINDOWING",
     "KVCache",
     "Seqlen",

View File

@@ -1,4 +1,5 @@
 import torch
+from text_generation_server.layers.attention.kv_cache import KVCache
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import (
     ATTENTION,
@@ -38,8 +39,7 @@ def reshape_and_cache(
 def paged_attention(
     query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
+    kv_cache: KVCache,
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
@@ -80,7 +80,7 @@ def paged_attention(
         return decode_state.get().forward(
             query.contiguous(),
-            paged_kv_cache=(key_cache, value_cache),
+            paged_kv_cache=(kv_cache.key, kv_cache.value),
             logits_soft_cap=softcap,
             sm_scale=softmax_scale,
         )
@@ -98,8 +98,8 @@ def paged_attention(
             softcap = 0.0
         out = flash_attn_2_cuda.varlen_fwd(
             query,
-            key_cache,
-            value_cache,
+            kv_cache.key,
+            kv_cache.value,
             None,
             seqlen.cu_seqlen_q,
             seqlen.cu_seqlen_k,
@@ -135,8 +135,8 @@ def paged_attention(
             ops.paged_attention_v1(
                 out,
                 query,
-                key_cache,
-                value_cache,
+                kv_cache.key,
+                kv_cache.value,
                 kv_head_mapping,
                 softmax_scale,
                 block_tables,
@@ -168,8 +168,8 @@ def paged_attention(
                 max_logits,
                 tmp_output,
                 query,
-                key_cache,
-                value_cache,
+                kv_cache.key,
+                kv_cache.value,
                 kv_head_mapping,
                 softmax_scale,
                 block_tables,
@@ -216,61 +216,63 @@ except ImportError:
     ) from e
 
+if ATTENTION == "flashdecoding" and not V2:
+    raise ValueError("Flash decoding requires Flash Attention V2")
+
 SUPPORTS_WINDOWING = V2
 
-if ATTENTION == "flashinfer":
 
-    def attention(
-        q: torch.Tensor,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        seqlen: Seqlen,
-        block_tables: torch.Tensor,
-        softmax_scale,
-        window_size_left=-1,
-        causal=True,
-        softcap=0.0,
-    ):
+def attention(
+    *,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: KVCache,
+    seqlen: Seqlen,
+    block_tables: torch.Tensor,
+    softmax_scale: float,
+    window_size_left: int = -1,
+    causal: bool = True,
+    softcap: Optional[float] = None,
+):
+    if ATTENTION == "flashinfer":
         from text_generation_server.layers.attention.flashinfer import (
             prefill_with_paged_kv_state,
         )
 
+        if softcap is None:
+            softcap = 0.0
+
         return prefill_with_paged_kv_state.get().forward(
-            q.contiguous(),
+            query.contiguous(),
             causal=causal,
-            paged_kv_cache=(key_cache, value_cache),
+            paged_kv_cache=(kv_cache.key, kv_cache.value),
             logits_soft_cap=softcap,
             sm_scale=softmax_scale,
             window_left=window_size_left,
         )
 
-elif ATTENTION == "flashdecoding":
-    if V2:
-
-        def attention(
-            q,
-            key_cache: torch.Tensor,
-            value_cache: torch.Tensor,
-            seqlen: Seqlen,
-            block_tables: torch.Tensor,
-            softmax_scale,
-            window_size_left=-1,
-            causal=True,
-            softcap=0.0,
-        ):
-            out = torch.empty_like(q)
+    # If we are using flashdecoding or paged, we always use flash-attn for
+    # the prefill. We have to branch on whether we use flash-attn v1 or v2.
+    elif V2:
+        out = torch.empty_like(query)
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
+
+        if softcap is None:
+            softcap = 0.0
+
         return flash_attn_2_cuda.varlen_fwd(
-            q,
-            key_cache,
-            value_cache,
+            query,
+            # flashdecoding: pass the KV caches, paged: pass the KV.
+            kv_cache.key if ATTENTION == "flashdecoding" else key,
+            kv_cache.value if ATTENTION == "flashdecoding" else value,
             out,
             seqlen.cu_seqlen_q,
             seqlen.cu_seqlen_k,
             None,
             None,
-            block_tables,
+            block_tables if ATTENTION == "flashdecoding" else None,
             None,
             seqlen.max_q,
             seqlen.max_k,
@@ -286,58 +288,44 @@ elif ATTENTION == "flashdecoding":
         )[0]
 
     else:
-
-        def attention(
-            q: torch.Tensor,
-            k: torch.Tensor,
-            v: torch.Tensor,
-            seqlen: Seqlen,
-            block_tables: torch.Tensor,
-            softmax_scale: float,
-            window_size_left: int = -1,
-            causal: bool = True,
-            softcap=None,
-        ):
         if window_size_left != -1:
             raise NotImplementedError(
                 "window_size_left is only available with flash attn v2"
             )
         if softcap is not None:
-            raise NotImplementedError(
-                "softcap is only available with flash attn v2"
-            )
+            raise NotImplementedError("softcap is not available in flash attn v1")
 
         # Flash attention v1 requires q, k and v to have the same number of heads
-        if k.shape[1] != q.shape[1]:
+        if key.shape[1] != query.shape[1]:
             # MQA expand
-            if k.shape[1] == 1:
-                k = k.expand(-1, q.shape[1], -1)
+            if key.shape[1] == 1:
+                key = key.expand(-1, query.shape[1], -1)
             # Grouped attention reshape
             else:
-                original_shape = k.shape
-                k = (
-                    k.unsqueeze(2)
-                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                original_shape = key.shape
+                key = (
+                    key.unsqueeze(2)
+                    .expand(-1, -1, query.shape[1] // key.shape[1], -1)
                     .reshape(original_shape[0], -1, original_shape[2])
                 )
-        if v.shape[1] != q.shape[1]:
+        if value.shape[1] != query.shape[1]:
             # MQA expand
-            if v.shape[1] == 1:
-                v = v.expand(-1, q.shape[1], -1)
+            if value.shape[1] == 1:
+                value = value.expand(-1, query.shape[1], -1)
             # Grouped attention reshape
            else:
-                original_shape = v.shape
-                v = (
-                    v.unsqueeze(2)
-                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                original_shape = value.shape
+                value = (
+                    value.unsqueeze(2)
+                    .expand(-1, -1, query.shape[1] // value.shape[1], -1)
                     .reshape(original_shape[0], -1, original_shape[2])
                 )
-        out = torch.empty_like(q)
+
+        out = torch.empty_like(query)
         flash_attn_cuda.fwd(
-            q,
-            k,
-            v,
+            query,
+            key,
+            value,
             out,
             seqlen.cu_seqlen_q,
             seqlen.cu_seqlen_q,
@@ -353,126 +341,8 @@ elif ATTENTION == "flashdecoding":
         )
         return out
 
-elif ATTENTION == "paged":
-    if V2:
-
-        def attention(
-            q,
-            key_cache: torch.Tensor,
-            value_cache: torch.Tensor,
-            seqlen: Seqlen,
-            block_tables: torch.Tensor,
-            softmax_scale,
-            window_size_left=-1,
-            causal=True,
-            softcap=0.0,
-        ):
-            out = torch.empty_like(q)
-            if window_size_left <= 0 and window_size_left != -1:
-                raise ValueError("`window_size_left` must be > 0 or -1")
-            return flash_attn_2_cuda.varlen_fwd(
-                q,
-                key_cache,
-                value_cache,
-                out,
-                seqlen.cu_seqlen_q,
-                seqlen.cu_seqlen_k,
-                None,
-                None,
-                None,  # block_tables,
-                None,
-                seqlen.max_q,
-                seqlen.max_k,
-                0.0,
-                softmax_scale,
-                False,
-                causal,
-                window_size_left,
-                0,
-                softcap,
-                False,
-                None,
-            )[0]
-
-    else:
-
-        def attention(
-            q: torch.Tensor,
-            k: torch.Tensor,
-            v: torch.Tensor,
-            seqlen: Seqlen,
-            block_tables: torch.Tensor,
-            softmax_scale: float,
-            window_size_left: int = -1,
-            causal: bool = True,
-            softcap=None,
-        ):
-            if window_size_left != -1:
-                raise NotImplementedError(
-                    "window_size_left is only available with flash attn v2"
-                )
-            if softcap is not None:
-                raise NotImplementedError(
-                    "softcap is only available with flash attn v2"
-                )
-
-            # Flash attention v1 requires q, k and v to have the same number of heads
-            if k.shape[1] != q.shape[1]:
-                # MQA expand
-                if k.shape[1] == 1:
-                    k = k.expand(-1, q.shape[1], -1)
-                # Grouped attention reshape
-                else:
-                    original_shape = k.shape
-                    k = (
-                        k.unsqueeze(2)
-                        .expand(-1, -1, q.shape[1] // k.shape[1], -1)
-                        .reshape(original_shape[0], -1, original_shape[2])
-                    )
-            if v.shape[1] != q.shape[1]:
-                # MQA expand
-                if v.shape[1] == 1:
-                    v = v.expand(-1, q.shape[1], -1)
-                # Grouped attention reshape
-                else:
-                    original_shape = v.shape
-                    v = (
-                        v.unsqueeze(2)
-                        .expand(-1, -1, q.shape[1] // v.shape[1], -1)
-                        .reshape(original_shape[0], -1, original_shape[2])
-                    )
-
-            out = torch.empty_like(q)
-            flash_attn_cuda.fwd(
-                q,
-                k,
-                v,
-                out,
-                seqlen.cu_seqlen_q,
-                seqlen.cu_seqlen_q,
-                seqlen.max_q,
-                seqlen.max_k,
-                0.0,
-                softmax_scale,
-                False,
-                causal,
-                False,
-                0,
-                None,
-            )
-            return out
-
-else:
-    raise RuntimeError(f"Unknwon attention {ATTENTION}")
-
-
-# Prefill in the cache with every kind of attention, unless we
-# have a configuration that requires flash-attention v1, which
-# does not support block tables.
-PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)
-
 __all__ = [
-    "PREFILL_IN_KV_CACHE",
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",

View File

@@ -1,31 +1,36 @@
 import intel_extension_for_pytorch as ipex
 import torch
+from text_generation_server.layers.attention.kv_cache import KVCache
 from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
 from text_generation_server.layers.attention import Seqlen
 from typing import Optional
 
 SUPPORTS_WINDOWING = False
-PREFILL_IN_KV_CACHE = False
 
 
 def attention(
-    q: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
+    *,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: KVCache,
     seqlen: Seqlen,
     block_tables: torch.Tensor,
-    softmax_scale,
-    window_size_left=-1,
-    causal=True,
+    softmax_scale: float,
+    window_size_left: int = -1,
+    causal: bool = True,
     softcap: Optional[float] = None,
 ):
-    out = torch.empty_like(q)
+    if softcap is not None:
+        raise NotImplementedError("softcap is not available in IPEX")
+
+    out = torch.empty_like(query)
 
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     ipex.llm.functional.varlen_attention(
-        q.contiguous() if q.device.type == "xpu" else q,
-        key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
-        value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
+        query.contiguous() if query.device.type == "xpu" else query,
+        key.contiguous() if key.device.type == "xpu" else key,
+        value.contiguous() if value.device.type == "xpu" else value,
         out,
         seqlen.cu_seqlen_q,
         seqlen.cu_seqlen_q,
@@ -56,8 +61,7 @@ def reshape_and_cache(
 def paged_attention(
     query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
+    kv_cache: KVCache,
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
@@ -65,13 +69,16 @@ def paged_attention(
     max_s: int,
     softcap: Optional[float] = None,
 ):
+    if softcap is not None:
+        raise NotImplementedError("softcap is not available in IPEX")
+
     out = torch.empty_like(query)
     input_lengths = seqlen.input_lengths + seqlen.cache_lengths
     ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
         out,
         query,
-        key_cache,
-        value_cache,
+        kv_cache.key,
+        kv_cache.value,
         kv_head_mapping,
         softmax_scale,
         block_tables,
@@ -84,7 +91,6 @@ def paged_attention(
 __all__ = [
-    "PREFILL_IN_KV_CACHE",
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",

View File

@@ -3,7 +3,6 @@ from typing import Tuple
 import torch
 from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.layers.attention import reshape_and_cache
 
 
 class KVCache:
@@ -116,4 +115,6 @@ class KVCache:
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
         else:
+            from text_generation_server.layers.attention import reshape_and_cache
+
             reshape_and_cache(key, value, key_cache, value_cache, slots)

View File

@@ -1,6 +1,7 @@
 import os
 from typing import Optional
 import torch
+from text_generation_server.layers.attention.kv_cache import KVCache
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import ATTENTION
 from text_generation_server.layers.attention import Seqlen
@@ -16,8 +17,6 @@ _PARTITION_SIZE_CUSTOM = 256
 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
 ENGINE = "triton" if use_triton else "ck"
 
-PREFILL_IN_KV_CACHE = False
-
 use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
 try:
     if use_rocm_custom_paged_attn:
@@ -54,8 +53,7 @@ def reshape_and_cache(
 def paged_attention(
     query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
+    kv_cache: KVCache,
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
@@ -84,10 +82,10 @@ def paged_attention(
         raise RuntimeError("Paged attention doesn't support softcapping")
 
     # value_cache => [num_blocks, num_heads, head_size, block_size]
-    block_size = value_cache.shape[3]
+    block_size = kv_cache.value.shape[3]
     num_seqs, num_heads, head_size = query.shape
 
-    num_kv_heads = key_cache.shape[1]
+    num_kv_heads = kv_cache.key.shape[1]
     gqa_ratio = num_heads // num_kv_heads
     use_custom = (
         use_rocm_custom_paged_attn
@@ -124,8 +122,8 @@ def paged_attention(
         ops.paged_attention_v1(
             out,
             query,
-            key_cache,
-            value_cache,
+            kv_cache.key,
+            kv_cache.value,
             kv_head_mapping,
             softmax_scale,
             block_tables,
@@ -158,8 +156,8 @@ def paged_attention(
             max_logits,
             tmp_output,
             query,
-            key_cache,
-            value_cache,
+            kv_cache.key,
+            kv_cache.value,
             kv_head_mapping,
             softmax_scale,
             block_tables,
@@ -177,8 +175,8 @@ def paged_attention(
             max_logits,
             tmp_output,
             query,
-            key_cache,
-            value_cache,
+            kv_cache.key,
+            kv_cache.value,
             num_kv_heads,
             softmax_scale,
             block_tables,
@@ -227,29 +225,35 @@ if ENGINE != "triton":
 SUPPORTS_WINDOWING = False
 
-if ENGINE == "ck":
 
-    def attention(
-        q,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
+def attention(
+    *,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: KVCache,
     seqlen: Seqlen,
     block_tables: torch.Tensor,
     softmax_scale: float,
     window_size_left: int = -1,
     causal: bool = True,
-        softcap: float = 0.0,
-    ):
+    softcap: Optional[float] = None,
+):
+    if ENGINE == "ck":
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
-        out = torch.empty_like(q)
+
+        out = torch.empty_like(query)
+
+        if softcap is None:
+            softcap = 0.0
 
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
-            q,
-            key_cache,
-            value_cache,
+            query,
+            key,
+            value,
             out,
             seqlen.cu_seqlen_q,
             seqlen.cu_seqlen_q,
@@ -270,30 +274,19 @@ if ENGINE == "ck":
             None,
         )[0]
 
     elif ENGINE == "triton":
         from .flash_attn_triton import triton_attention
 
-    def attention(
-        q,
-        key_cache: torch.Tensor,
-        value_cache: torch.Tensor,
-        seqlen: Seqlen,
-        block_tables: torch.Tensor,
-        softmax_scale: float,
-        window_size_left: int = -1,
-        causal: bool = True,
-        softcap: Optional[float] = None,
-    ):
         if softcap is not None:
             raise NotImplementedError("softcap is only available with CK flash attn")
 
-        out = torch.empty_like(q)
+        out = torch.empty_like(query)
+
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         output, _ = triton_attention(
-            q,
-            key_cache,
-            value_cache,
+            query,
+            key,
+            value,
             out,
             seqlen.cu_seqlen_q,
             seqlen.cu_seqlen_q,
@@ -304,11 +297,11 @@ elif ENGINE == "triton":
         )
         return output
 
     else:
         raise RuntimeError(f"Unknown attention engine {ENGINE}")
 
 __all__ = [
-    "PREFILL_IN_KV_CACHE",
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",

View File

@@ -38,7 +38,6 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
@@ -296,19 +295,19 @@ class FlashCohereAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else key,
-                kv_cache.value if PREFILL_IN_KV_CACHE else value,
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=key,
+                value=value,
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,

View File

@@ -29,7 +29,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     Seqlen,
-    PREFILL_IN_KV_CACHE,
 )
 from text_generation_server.layers import (
     FastLinear,
@@ -335,19 +334,19 @@ class DbrxAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=kv[:, 0],
+                value=kv[:, 1],
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,

View File

@@ -34,7 +34,6 @@ from text_generation_server.layers.attention import (
     attention,
     paged_attention,
 )
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
@@ -326,19 +325,19 @@ class DeepseekV2Attention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else key,
-                kv_cache.value if PREFILL_IN_KV_CACHE else value,
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=key,
+                value=value,
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,

View File

@@ -39,7 +39,6 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
@@ -258,13 +257,13 @@ class FlashGemma2Attention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
-                seqlen,
-                block_tables,
-                self.softmax_scale,
-                causal=self.causal,
+                query=query,
+                key=kv[:, 0],
+                value=kv[:, 1],
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
                 window_size_left=self.window_size,
                 softcap=self.softcap,
             )
@@ -272,8 +271,8 @@ class FlashGemma2Attention(torch.nn.Module):
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,

View File

@@ -29,7 +29,6 @@ from text_generation_server.layers.attention import (
     paged_attention,
     attention,
     Seqlen,
-    PREFILL_IN_KV_CACHE,
 )
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -229,20 +228,20 @@ class FlashGemmaAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=kv[:, 0],
+                value=kv[:, 1],
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
                 causal=self.causal,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,

View File

@@ -24,7 +24,6 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -229,19 +228,19 @@ class FlashGPT2Attention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else key,
-                kv_cache.value if PREFILL_IN_KV_CACHE else value,
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=key,
+                value=value,
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,

View File

@@ -37,7 +37,6 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
@@ -191,19 +190,19 @@ class FlashGPTJAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else key,
-                kv_cache.value if PREFILL_IN_KV_CACHE else value,
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=key,
+                value=value,
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,

View File

@@ -27,7 +27,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE, KVCache
+from text_generation_server.layers.attention import KVCache
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
@@ -227,19 +227,19 @@ class FlashLlamaAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=kv[:, 0],
+                value=kv[:, 1],
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,

View File

@@ -40,7 +40,6 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
@@ -215,20 +214,20 @@ class MistralAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
-                kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=kv_to_cache[:, 0],
+                value=kv_to_cache[:, 1],
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
                 window_size_left=self.max_past,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,

View File

@@ -38,7 +38,6 @@ from text_generation_server.layers.attention import (
     attention,
     paged_attention,
 )
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastRMSNorm
 from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
@@ -263,20 +262,20 @@ class MixtralAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
-                kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=kv_to_cache[:, 0],
+                value=kv_to_cache[:, 1],
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
                 window_size_left=self.max_past,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,

View File

@@ -38,7 +38,6 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
@@ -170,19 +169,19 @@ class FlashNeoxAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                qkv[:, 0],
-                kv_cache.key if PREFILL_IN_KV_CACHE else qkv[:, 1],
-                kv_cache.value if PREFILL_IN_KV_CACHE else qkv[:, 2],
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=qkv[:, 0],
+                key=qkv[:, 1],
+                value=qkv[:, 2],
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 qkv[:, 0],
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,

View File

@@ -18,7 +18,6 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
@@ -192,19 +191,19 @@ class FlashPhiAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=kv[:, 0],
+                value=kv[:, 1],
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,

View File

@@ -16,7 +16,6 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     SpeculativeHead,
 )
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
@@ -133,20 +132,20 @@ class Qwen2Attention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
-                kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=kv_to_cache[:, 0],
+                value=kv_to_cache[:, 1],
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
                 window_size_left=self.max_past,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,

View File

@@ -12,7 +12,6 @@ from text_generation_server.layers import (
     TensorParallelRowLinear,
     get_linear,
 )
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import FastLayerNorm
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.attention import (
@@ -205,19 +204,19 @@ class FlashRWAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
-                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=kv[:, 0],
+                value=kv[:, 1],
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -319,19 +318,19 @@ class FlashRWLargeAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
-                kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=kv[:, :, 0],
+                value=kv[:, :, 1],
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,

View File

@@ -17,7 +17,6 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     get_linear,
 )
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
@@ -289,19 +288,19 @@ class FlashMQAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else key_value[:, 0],
-                kv_cache.value if PREFILL_IN_KV_CACHE else key_value[:, 1],
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=key_value[:, 0],
+                value=key_value[:, 1],
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,

View File

@@ -38,7 +38,6 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
     FastRMSNorm,
@@ -238,20 +237,20 @@ class Starcoder2Attention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             # flash attention
             attn_output = attention(
-                query,
-                kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
-                kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=kv_to_cache[:, 0],
+                value=kv_to_cache[:, 1],
+                kv_cache=kv_cache,
+                seqlen=seqlen,
+                block_tables=block_tables,
+                softmax_scale=self.softmax_scale,
                 window_size_left=self.max_past,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache.key,
-                kv_cache.value,
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,