Simplify the `attention` function (#2609)

* Simplify the `attention` function

- Use one definition rather than multiple.
- Add `key`/`value` arguments, so that we don't need the
  `PREFILL_IN_KVCACHE` constant.
- Make it kwargs-only (to avoid mixing up the various `Tensor` args).

* Fixup flashinfer support
This commit is contained in:
Daniël de Kok 2024-10-17 10:42:52 +02:00 committed by GitHub
parent 5bbe1ce028
commit 59ea38cbca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 313 additions and 463 deletions

View File

@ -8,7 +8,6 @@ if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
@ -16,7 +15,6 @@ if SYSTEM == "cuda":
)
elif SYSTEM == "rocm":
from .rocm import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
@ -24,7 +22,6 @@ elif SYSTEM == "rocm":
)
elif SYSTEM == "ipex":
from .ipex import (
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
attention,
paged_attention,
@ -40,7 +37,6 @@ __all__ = [
"attention",
"paged_attention",
"reshape_and_cache",
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"KVCache",
"Seqlen",

View File

@ -1,4 +1,5 @@
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import (
ATTENTION,
@ -38,8 +39,7 @@ def reshape_and_cache(
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
@ -80,7 +80,7 @@ def paged_attention(
return decode_state.get().forward(
query.contiguous(),
paged_kv_cache=(key_cache, value_cache),
paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
)
@ -98,8 +98,8 @@ def paged_attention(
softcap = 0.0
out = flash_attn_2_cuda.varlen_fwd(
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
None,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
@ -135,8 +135,8 @@ def paged_attention(
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
@ -168,8 +168,8 @@ def paged_attention(
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
@ -216,61 +216,63 @@ except ImportError:
) from e
if ATTENTION == "flashdecoding" and not V2:
raise ValueError("Flash decoding requires Flash Attention V2")
SUPPORTS_WINDOWING = V2
if ATTENTION == "flashinfer":
def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
def attention(
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
)
if softcap is None:
softcap = 0.0
return prefill_with_paged_kv_state.get().forward(
q.contiguous(),
query.contiguous(),
causal=causal,
paged_kv_cache=(key_cache, value_cache),
paged_kv_cache=(kv_cache.key, kv_cache.value),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
window_left=window_size_left,
)
elif ATTENTION == "flashdecoding":
if V2:
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
# If we are using flashdecoding or paged, we always use flash-attn for
# the prefill. We have to branch on whether we use flash-attn v1 or v2.
elif V2:
out = torch.empty_like(query)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
if softcap is None:
softcap = 0.0
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
query,
# flashdecoding: pass the KV caches, paged: pass the KV.
kv_cache.key if ATTENTION == "flashdecoding" else key,
kv_cache.value if ATTENTION == "flashdecoding" else value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
block_tables,
block_tables if ATTENTION == "flashdecoding" else None,
None,
seqlen.max_q,
seqlen.max_k,
@ -286,58 +288,44 @@ elif ATTENTION == "flashdecoding":
)[0]
else:
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
if softcap is not None:
raise NotImplementedError(
"softcap is only available with flash attn v2"
)
raise NotImplementedError("softcap is not available in flash attn v1")
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
if key.shape[1] != query.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
if key.shape[1] == 1:
key = key.expand(-1, query.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
original_shape = key.shape
key = (
key.unsqueeze(2)
.expand(-1, -1, query.shape[1] // key.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
if value.shape[1] != query.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
if value.shape[1] == 1:
value = value.expand(-1, query.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
original_shape = value.shape
value = (
value.unsqueeze(2)
.expand(-1, -1, query.shape[1] // value.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
out = torch.empty_like(q)
out = torch.empty_like(query)
flash_attn_cuda.fwd(
q,
k,
v,
query,
key,
value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
@ -353,126 +341,8 @@ elif ATTENTION == "flashdecoding":
)
return out
elif ATTENTION == "paged":
if V2:
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
None, # block_tables,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
else:
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
if softcap is not None:
raise NotImplementedError(
"softcap is only available with flash attn v2"
)
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
out = torch.empty_like(q)
flash_attn_cuda.fwd(
q,
k,
v,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
False,
0,
None,
)
return out
else:
raise RuntimeError(f"Unknwon attention {ATTENTION}")
# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)
__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",

View File

@ -1,31 +1,36 @@
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional
SUPPORTS_WINDOWING = False
PREFILL_IN_KV_CACHE = False
def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
out = torch.empty_like(q)
if softcap is not None:
raise NotImplementedError("softcap is not available in IPEX")
out = torch.empty_like(query)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
ipex.llm.functional.varlen_attention(
q.contiguous() if q.device.type == "xpu" else q,
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
query.contiguous() if query.device.type == "xpu" else query,
key.contiguous() if key.device.type == "xpu" else key,
value.contiguous() if value.device.type == "xpu" else value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
@ -56,8 +61,7 @@ def reshape_and_cache(
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
@ -65,13 +69,16 @@ def paged_attention(
max_s: int,
softcap: Optional[float] = None,
):
if softcap is not None:
raise NotImplementedError("softcap is not available in IPEX")
out = torch.empty_like(query)
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
@ -84,7 +91,6 @@ def paged_attention(
__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",

View File

@ -3,7 +3,6 @@ from typing import Tuple
import torch
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import reshape_and_cache
class KVCache:
@ -116,4 +115,6 @@ class KVCache:
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
from text_generation_server.layers.attention import reshape_and_cache
reshape_and_cache(key, value, key_cache, value_cache, slots)

View File

@ -1,6 +1,7 @@
import os
from typing import Optional
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
from text_generation_server.layers.attention import Seqlen
@ -16,8 +17,6 @@ _PARTITION_SIZE_CUSTOM = 256
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
PREFILL_IN_KV_CACHE = False
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
try:
if use_rocm_custom_paged_attn:
@ -54,8 +53,7 @@ def reshape_and_cache(
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
@ -84,10 +82,10 @@ def paged_attention(
raise RuntimeError("Paged attention doesn't support softcapping")
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
block_size = kv_cache.value.shape[3]
num_seqs, num_heads, head_size = query.shape
num_kv_heads = key_cache.shape[1]
num_kv_heads = kv_cache.key.shape[1]
gqa_ratio = num_heads // num_kv_heads
use_custom = (
use_rocm_custom_paged_attn
@ -124,8 +122,8 @@ def paged_attention(
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
@ -158,8 +156,8 @@ def paged_attention(
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
kv_head_mapping,
softmax_scale,
block_tables,
@ -177,8 +175,8 @@ def paged_attention(
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_cache.key,
kv_cache.value,
num_kv_heads,
softmax_scale,
block_tables,
@ -227,29 +225,35 @@ if ENGINE != "triton":
SUPPORTS_WINDOWING = False
if ENGINE == "ck":
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
def attention(
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: float = 0.0,
):
softcap: Optional[float] = None,
):
if ENGINE == "ck":
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
out = torch.empty_like(q)
out = torch.empty_like(query)
if softcap is None:
softcap = 0.0
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
query,
key,
value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
@ -270,30 +274,19 @@ if ENGINE == "ck":
None,
)[0]
elif ENGINE == "triton":
elif ENGINE == "triton":
from .flash_attn_triton import triton_attention
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
if softcap is not None:
raise NotImplementedError("softcap is only available with CK flash attn")
out = torch.empty_like(q)
out = torch.empty_like(query)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
output, _ = triton_attention(
q,
key_cache,
value_cache,
query,
key,
value,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
@ -304,11 +297,11 @@ elif ENGINE == "triton":
)
return output
else:
else:
raise RuntimeError(f"Unknown attention engine {ENGINE}")
__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",

View File

@ -38,7 +38,6 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -296,19 +295,19 @@ class FlashCohereAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else key,
kv_cache.value if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=key,
value=value,
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -29,7 +29,6 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
PREFILL_IN_KV_CACHE,
)
from text_generation_server.layers import (
FastLinear,
@ -335,19 +334,19 @@ class DbrxAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -34,7 +34,6 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
@ -326,19 +325,19 @@ class DeepseekV2Attention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else key,
kv_cache.value if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=key,
value=value,
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -39,7 +39,6 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -258,13 +257,13 @@ class FlashGemma2Attention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
causal=self.causal,
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
window_size_left=self.window_size,
softcap=self.softcap,
)
@ -272,8 +271,7 @@ class FlashGemma2Attention(torch.nn.Module):
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -29,7 +29,6 @@ from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
PREFILL_IN_KV_CACHE,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -229,20 +228,20 @@ class FlashGemmaAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
causal=self.causal,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -24,7 +24,6 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.attention import (
paged_attention,
attention,
@ -229,19 +228,19 @@ class FlashGPT2Attention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else key,
kv_cache.value if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=key,
value=value,
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -37,7 +37,6 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
@ -191,19 +190,19 @@ class FlashGPTJAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else key,
kv_cache.value if PREFILL_IN_KV_CACHE else value,
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=key,
value=value,
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -27,7 +27,7 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE, KVCache
from text_generation_server.layers.attention import KVCache
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
@ -227,19 +227,19 @@ class FlashLlamaAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -40,7 +40,6 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -215,20 +214,20 @@ class MistralAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -38,7 +38,6 @@ from text_generation_server.layers.attention import (
attention,
paged_attention,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding
@ -263,20 +262,20 @@ class MixtralAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -38,7 +38,6 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -170,19 +169,19 @@ class FlashNeoxAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
qkv[:, 0],
kv_cache.key if PREFILL_IN_KV_CACHE else qkv[:, 1],
kv_cache.value if PREFILL_IN_KV_CACHE else qkv[:, 2],
seqlen,
block_tables,
self.softmax_scale,
query=qkv[:, 0],
key=qkv[:, 1],
value=qkv[:, 2],
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
qkv[:, 0],
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -18,7 +18,6 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -192,19 +191,19 @@ class FlashPhiAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -16,7 +16,6 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
SpeculativeHead,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
@ -133,20 +132,20 @@ class Qwen2Attention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -12,7 +12,6 @@ from text_generation_server.layers import (
TensorParallelRowLinear,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import FastLayerNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.attention import (
@ -205,19 +204,19 @@ class FlashRWAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
@ -319,19 +318,19 @@ class FlashRWLargeAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=kv[:, :, 0],
value=kv[:, :, 1],
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -17,7 +17,6 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import (
FastLayerNorm,
@ -289,19 +288,19 @@ class FlashMQAttention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else key_value[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else key_value[:, 1],
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=key_value[:, 0],
value=key_value[:, 1],
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,

View File

@ -38,7 +38,6 @@ from text_generation_server.layers import (
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.layernorm import (
FastLayerNorm,
FastRMSNorm,
@ -238,20 +237,20 @@ class Starcoder2Attention(torch.nn.Module):
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query,
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
seqlen,
block_tables,
self.softmax_scale,
query=query,
key=kv_to_cache[:, 0],
value=kv_to_cache[:, 1],
kv_cache=kv_cache,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache.key,
kv_cache.value,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,