Simplify the `attention` function (#2609)
* Simplify the `attention` function - Use one definition rather than multiple. - Add `key`/`value` arguments, so that we don't need the `PREFILL_IN_KVCACHE` constant. - Make it kwargs-only (to avoid mixing up the various `Tensor` args). * Fixup flashinfer support
This commit is contained in:
parent
5bbe1ce028
commit
59ea38cbca
|
@ -8,7 +8,6 @@ if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
|||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||
if SYSTEM == "cuda":
|
||||
from .cuda import (
|
||||
PREFILL_IN_KV_CACHE,
|
||||
SUPPORTS_WINDOWING,
|
||||
attention,
|
||||
paged_attention,
|
||||
|
@ -16,7 +15,6 @@ if SYSTEM == "cuda":
|
|||
)
|
||||
elif SYSTEM == "rocm":
|
||||
from .rocm import (
|
||||
PREFILL_IN_KV_CACHE,
|
||||
SUPPORTS_WINDOWING,
|
||||
attention,
|
||||
paged_attention,
|
||||
|
@ -24,7 +22,6 @@ elif SYSTEM == "rocm":
|
|||
)
|
||||
elif SYSTEM == "ipex":
|
||||
from .ipex import (
|
||||
PREFILL_IN_KV_CACHE,
|
||||
SUPPORTS_WINDOWING,
|
||||
attention,
|
||||
paged_attention,
|
||||
|
@ -40,7 +37,6 @@ __all__ = [
|
|||
"attention",
|
||||
"paged_attention",
|
||||
"reshape_and_cache",
|
||||
"PREFILL_IN_KV_CACHE",
|
||||
"SUPPORTS_WINDOWING",
|
||||
"KVCache",
|
||||
"Seqlen",
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
from text_generation_server.layers.attention.kv_cache import KVCache
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import (
|
||||
ATTENTION,
|
||||
|
@ -38,8 +39,7 @@ def reshape_and_cache(
|
|||
|
||||
def paged_attention(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
|
@ -80,7 +80,7 @@ def paged_attention(
|
|||
|
||||
return decode_state.get().forward(
|
||||
query.contiguous(),
|
||||
paged_kv_cache=(key_cache, value_cache),
|
||||
paged_kv_cache=(kv_cache.key, kv_cache.value),
|
||||
logits_soft_cap=softcap,
|
||||
sm_scale=softmax_scale,
|
||||
)
|
||||
|
@ -98,8 +98,8 @@ def paged_attention(
|
|||
softcap = 0.0
|
||||
out = flash_attn_2_cuda.varlen_fwd(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
None,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_k,
|
||||
|
@ -135,8 +135,8 @@ def paged_attention(
|
|||
ops.paged_attention_v1(
|
||||
out,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
|
@ -168,8 +168,8 @@ def paged_attention(
|
|||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
|
@ -216,61 +216,63 @@ except ImportError:
|
|||
) from e
|
||||
|
||||
|
||||
if ATTENTION == "flashdecoding" and not V2:
|
||||
raise ValueError("Flash decoding requires Flash Attention V2")
|
||||
|
||||
SUPPORTS_WINDOWING = V2
|
||||
|
||||
if ATTENTION == "flashinfer":
|
||||
|
||||
def attention(
|
||||
q: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
def attention(
|
||||
*,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=True,
|
||||
softcap=0.0,
|
||||
):
|
||||
softmax_scale: float,
|
||||
window_size_left: int = -1,
|
||||
causal: bool = True,
|
||||
softcap: Optional[float] = None,
|
||||
):
|
||||
if ATTENTION == "flashinfer":
|
||||
from text_generation_server.layers.attention.flashinfer import (
|
||||
prefill_with_paged_kv_state,
|
||||
)
|
||||
|
||||
if softcap is None:
|
||||
softcap = 0.0
|
||||
|
||||
return prefill_with_paged_kv_state.get().forward(
|
||||
q.contiguous(),
|
||||
query.contiguous(),
|
||||
causal=causal,
|
||||
paged_kv_cache=(key_cache, value_cache),
|
||||
paged_kv_cache=(kv_cache.key, kv_cache.value),
|
||||
logits_soft_cap=softcap,
|
||||
sm_scale=softmax_scale,
|
||||
window_left=window_size_left,
|
||||
)
|
||||
|
||||
elif ATTENTION == "flashdecoding":
|
||||
if V2:
|
||||
|
||||
def attention(
|
||||
q,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=True,
|
||||
softcap=0.0,
|
||||
):
|
||||
out = torch.empty_like(q)
|
||||
# If we are using flashdecoding or paged, we always use flash-attn for
|
||||
# the prefill. We have to branch on whether we use flash-attn v1 or v2.
|
||||
elif V2:
|
||||
out = torch.empty_like(query)
|
||||
if window_size_left <= 0 and window_size_left != -1:
|
||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||
|
||||
if softcap is None:
|
||||
softcap = 0.0
|
||||
|
||||
return flash_attn_2_cuda.varlen_fwd(
|
||||
q,
|
||||
key_cache,
|
||||
value_cache,
|
||||
query,
|
||||
# flashdecoding: pass the KV caches, paged: pass the KV.
|
||||
kv_cache.key if ATTENTION == "flashdecoding" else key,
|
||||
kv_cache.value if ATTENTION == "flashdecoding" else value,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_k,
|
||||
None,
|
||||
None,
|
||||
block_tables,
|
||||
block_tables if ATTENTION == "flashdecoding" else None,
|
||||
None,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
|
@ -286,58 +288,44 @@ elif ATTENTION == "flashdecoding":
|
|||
)[0]
|
||||
|
||||
else:
|
||||
|
||||
def attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
window_size_left: int = -1,
|
||||
causal: bool = True,
|
||||
softcap=None,
|
||||
):
|
||||
if window_size_left != -1:
|
||||
raise NotImplementedError(
|
||||
"window_size_left is only available with flash attn v2"
|
||||
)
|
||||
if softcap is not None:
|
||||
raise NotImplementedError(
|
||||
"softcap is only available with flash attn v2"
|
||||
)
|
||||
raise NotImplementedError("softcap is not available in flash attn v1")
|
||||
|
||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||
if k.shape[1] != q.shape[1]:
|
||||
if key.shape[1] != query.shape[1]:
|
||||
# MQA expand
|
||||
if k.shape[1] == 1:
|
||||
k = k.expand(-1, q.shape[1], -1)
|
||||
if key.shape[1] == 1:
|
||||
key = key.expand(-1, query.shape[1], -1)
|
||||
# Grouped attention reshape
|
||||
else:
|
||||
original_shape = k.shape
|
||||
k = (
|
||||
k.unsqueeze(2)
|
||||
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
||||
original_shape = key.shape
|
||||
key = (
|
||||
key.unsqueeze(2)
|
||||
.expand(-1, -1, query.shape[1] // key.shape[1], -1)
|
||||
.reshape(original_shape[0], -1, original_shape[2])
|
||||
)
|
||||
if v.shape[1] != q.shape[1]:
|
||||
if value.shape[1] != query.shape[1]:
|
||||
# MQA expand
|
||||
if v.shape[1] == 1:
|
||||
v = v.expand(-1, q.shape[1], -1)
|
||||
if value.shape[1] == 1:
|
||||
value = value.expand(-1, query.shape[1], -1)
|
||||
# Grouped attention reshape
|
||||
else:
|
||||
original_shape = v.shape
|
||||
v = (
|
||||
v.unsqueeze(2)
|
||||
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
||||
original_shape = value.shape
|
||||
value = (
|
||||
value.unsqueeze(2)
|
||||
.expand(-1, -1, query.shape[1] // value.shape[1], -1)
|
||||
.reshape(original_shape[0], -1, original_shape[2])
|
||||
)
|
||||
|
||||
out = torch.empty_like(q)
|
||||
out = torch.empty_like(query)
|
||||
flash_attn_cuda.fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
|
@ -353,126 +341,8 @@ elif ATTENTION == "flashdecoding":
|
|||
)
|
||||
return out
|
||||
|
||||
elif ATTENTION == "paged":
|
||||
if V2:
|
||||
|
||||
def attention(
|
||||
q,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=True,
|
||||
softcap=0.0,
|
||||
):
|
||||
out = torch.empty_like(q)
|
||||
if window_size_left <= 0 and window_size_left != -1:
|
||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||
return flash_attn_2_cuda.varlen_fwd(
|
||||
q,
|
||||
key_cache,
|
||||
value_cache,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_k,
|
||||
None,
|
||||
None,
|
||||
None, # block_tables,
|
||||
None,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
window_size_left,
|
||||
0,
|
||||
softcap,
|
||||
False,
|
||||
None,
|
||||
)[0]
|
||||
|
||||
else:
|
||||
|
||||
def attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
window_size_left: int = -1,
|
||||
causal: bool = True,
|
||||
softcap=None,
|
||||
):
|
||||
if window_size_left != -1:
|
||||
raise NotImplementedError(
|
||||
"window_size_left is only available with flash attn v2"
|
||||
)
|
||||
if softcap is not None:
|
||||
raise NotImplementedError(
|
||||
"softcap is only available with flash attn v2"
|
||||
)
|
||||
|
||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||
if k.shape[1] != q.shape[1]:
|
||||
# MQA expand
|
||||
if k.shape[1] == 1:
|
||||
k = k.expand(-1, q.shape[1], -1)
|
||||
# Grouped attention reshape
|
||||
else:
|
||||
original_shape = k.shape
|
||||
k = (
|
||||
k.unsqueeze(2)
|
||||
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
||||
.reshape(original_shape[0], -1, original_shape[2])
|
||||
)
|
||||
if v.shape[1] != q.shape[1]:
|
||||
# MQA expand
|
||||
if v.shape[1] == 1:
|
||||
v = v.expand(-1, q.shape[1], -1)
|
||||
# Grouped attention reshape
|
||||
else:
|
||||
original_shape = v.shape
|
||||
v = (
|
||||
v.unsqueeze(2)
|
||||
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
||||
.reshape(original_shape[0], -1, original_shape[2])
|
||||
)
|
||||
|
||||
out = torch.empty_like(q)
|
||||
flash_attn_cuda.fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
return out
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Unknwon attention {ATTENTION}")
|
||||
|
||||
|
||||
# Prefill in the cache with every kind of attention, unless we
|
||||
# have a configuration that requires flash-attention v1, which
|
||||
# does not support block tables.
|
||||
PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)
|
||||
|
||||
__all__ = [
|
||||
"PREFILL_IN_KV_CACHE",
|
||||
"SUPPORTS_WINDOWING",
|
||||
"attention",
|
||||
"paged_attention",
|
||||
|
|
|
@ -1,31 +1,36 @@
|
|||
import intel_extension_for_pytorch as ipex
|
||||
import torch
|
||||
from text_generation_server.layers.attention.kv_cache import KVCache
|
||||
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from typing import Optional
|
||||
|
||||
SUPPORTS_WINDOWING = False
|
||||
PREFILL_IN_KV_CACHE = False
|
||||
|
||||
|
||||
def attention(
|
||||
q: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
*,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=True,
|
||||
softmax_scale: float,
|
||||
window_size_left: int = -1,
|
||||
causal: bool = True,
|
||||
softcap: Optional[float] = None,
|
||||
):
|
||||
out = torch.empty_like(q)
|
||||
if softcap is not None:
|
||||
raise NotImplementedError("softcap is not available in IPEX")
|
||||
|
||||
out = torch.empty_like(query)
|
||||
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
ipex.llm.functional.varlen_attention(
|
||||
q.contiguous() if q.device.type == "xpu" else q,
|
||||
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
|
||||
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
|
||||
query.contiguous() if query.device.type == "xpu" else query,
|
||||
key.contiguous() if key.device.type == "xpu" else key,
|
||||
value.contiguous() if value.device.type == "xpu" else value,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
|
@ -56,8 +61,7 @@ def reshape_and_cache(
|
|||
|
||||
def paged_attention(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
|
@ -65,13 +69,16 @@ def paged_attention(
|
|||
max_s: int,
|
||||
softcap: Optional[float] = None,
|
||||
):
|
||||
if softcap is not None:
|
||||
raise NotImplementedError("softcap is not available in IPEX")
|
||||
|
||||
out = torch.empty_like(query)
|
||||
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||
out,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
|
@ -84,7 +91,6 @@ def paged_attention(
|
|||
|
||||
|
||||
__all__ = [
|
||||
"PREFILL_IN_KV_CACHE",
|
||||
"SUPPORTS_WINDOWING",
|
||||
"attention",
|
||||
"paged_attention",
|
||||
|
|
|
@ -3,7 +3,6 @@ from typing import Tuple
|
|||
import torch
|
||||
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention import reshape_and_cache
|
||||
|
||||
|
||||
class KVCache:
|
||||
|
@ -116,4 +115,6 @@ class KVCache:
|
|||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||
else:
|
||||
from text_generation_server.layers.attention import reshape_and_cache
|
||||
|
||||
reshape_and_cache(key, value, key_cache, value_cache, slots)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
from typing import Optional
|
||||
import torch
|
||||
from text_generation_server.layers.attention.kv_cache import KVCache
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import ATTENTION
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
|
@ -16,8 +17,6 @@ _PARTITION_SIZE_CUSTOM = 256
|
|||
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
|
||||
ENGINE = "triton" if use_triton else "ck"
|
||||
|
||||
PREFILL_IN_KV_CACHE = False
|
||||
|
||||
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
|
||||
try:
|
||||
if use_rocm_custom_paged_attn:
|
||||
|
@ -54,8 +53,7 @@ def reshape_and_cache(
|
|||
|
||||
def paged_attention(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
|
@ -84,10 +82,10 @@ def paged_attention(
|
|||
raise RuntimeError("Paged attention doesn't support softcapping")
|
||||
|
||||
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
||||
block_size = value_cache.shape[3]
|
||||
block_size = kv_cache.value.shape[3]
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
|
||||
num_kv_heads = key_cache.shape[1]
|
||||
num_kv_heads = kv_cache.key.shape[1]
|
||||
gqa_ratio = num_heads // num_kv_heads
|
||||
use_custom = (
|
||||
use_rocm_custom_paged_attn
|
||||
|
@ -124,8 +122,8 @@ def paged_attention(
|
|||
ops.paged_attention_v1(
|
||||
out,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
|
@ -158,8 +156,8 @@ def paged_attention(
|
|||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
|
@ -177,8 +175,8 @@ def paged_attention(
|
|||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
num_kv_heads,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
|
@ -227,29 +225,35 @@ if ENGINE != "triton":
|
|||
|
||||
|
||||
SUPPORTS_WINDOWING = False
|
||||
if ENGINE == "ck":
|
||||
|
||||
def attention(
|
||||
q,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
|
||||
def attention(
|
||||
*,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
window_size_left: int = -1,
|
||||
causal: bool = True,
|
||||
softcap: float = 0.0,
|
||||
):
|
||||
softcap: Optional[float] = None,
|
||||
):
|
||||
if ENGINE == "ck":
|
||||
if window_size_left <= 0 and window_size_left != -1:
|
||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||
|
||||
out = torch.empty_like(q)
|
||||
out = torch.empty_like(query)
|
||||
|
||||
if softcap is None:
|
||||
softcap = 0.0
|
||||
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
return flash_attn_2_cuda.varlen_fwd(
|
||||
q,
|
||||
key_cache,
|
||||
value_cache,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
|
@ -270,30 +274,19 @@ if ENGINE == "ck":
|
|||
None,
|
||||
)[0]
|
||||
|
||||
elif ENGINE == "triton":
|
||||
elif ENGINE == "triton":
|
||||
from .flash_attn_triton import triton_attention
|
||||
|
||||
def attention(
|
||||
q,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
window_size_left: int = -1,
|
||||
causal: bool = True,
|
||||
softcap: Optional[float] = None,
|
||||
):
|
||||
if softcap is not None:
|
||||
raise NotImplementedError("softcap is only available with CK flash attn")
|
||||
|
||||
out = torch.empty_like(q)
|
||||
out = torch.empty_like(query)
|
||||
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
output, _ = triton_attention(
|
||||
q,
|
||||
key_cache,
|
||||
value_cache,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
|
@ -304,11 +297,11 @@ elif ENGINE == "triton":
|
|||
)
|
||||
return output
|
||||
|
||||
else:
|
||||
else:
|
||||
raise RuntimeError(f"Unknown attention engine {ENGINE}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PREFILL_IN_KV_CACHE",
|
||||
"SUPPORTS_WINDOWING",
|
||||
"attention",
|
||||
"paged_attention",
|
||||
|
|
|
@ -38,7 +38,6 @@ from text_generation_server.layers import (
|
|||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
|
@ -296,19 +295,19 @@ class FlashCohereAttention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -29,7 +29,6 @@ from text_generation_server.layers.attention import (
|
|||
paged_attention,
|
||||
attention,
|
||||
Seqlen,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
FastLinear,
|
||||
|
@ -335,19 +334,19 @@ class DbrxAttention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -34,7 +34,6 @@ from text_generation_server.layers.attention import (
|
|||
attention,
|
||||
paged_attention,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
||||
|
@ -326,19 +325,19 @@ class DeepseekV2Attention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -39,7 +39,6 @@ from text_generation_server.layers import (
|
|||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
|
@ -258,13 +257,13 @@ class FlashGemma2Attention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
causal=self.causal,
|
||||
query=query,
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
window_size_left=self.window_size,
|
||||
softcap=self.softcap,
|
||||
)
|
||||
|
@ -272,8 +271,7 @@ class FlashGemma2Attention(torch.nn.Module):
|
|||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -29,7 +29,6 @@ from text_generation_server.layers.attention import (
|
|||
paged_attention,
|
||||
attention,
|
||||
Seqlen,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
|
@ -229,20 +228,20 @@ class FlashGemmaAttention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=self.causal,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -24,7 +24,6 @@ import torch.distributed
|
|||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
|
@ -229,19 +228,19 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -37,7 +37,6 @@ from text_generation_server.layers import (
|
|||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
|
@ -191,19 +190,19 @@ class FlashGPTJAttention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -27,7 +27,7 @@ import torch.distributed
|
|||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE, KVCache
|
||||
from text_generation_server.layers.attention import KVCache
|
||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention import (
|
||||
|
@ -227,19 +227,19 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -40,7 +40,6 @@ from text_generation_server.layers import (
|
|||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
|
@ -215,20 +214,20 @@ class MistralAttention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv_to_cache[:, 0],
|
||||
value=kv_to_cache[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
window_size_left=self.max_past,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -38,7 +38,6 @@ from text_generation_server.layers.attention import (
|
|||
attention,
|
||||
paged_attention,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
|
@ -263,20 +262,20 @@ class MixtralAttention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv_to_cache[:, 0],
|
||||
value=kv_to_cache[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
window_size_left=self.max_past,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -38,7 +38,6 @@ from text_generation_server.layers import (
|
|||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
|
@ -170,19 +169,19 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
qkv[:, 0],
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else qkv[:, 1],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else qkv[:, 2],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=qkv[:, 0],
|
||||
key=qkv[:, 1],
|
||||
value=qkv[:, 2],
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
qkv[:, 0],
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -18,7 +18,6 @@ from text_generation_server.layers import (
|
|||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
|
@ -192,19 +191,19 @@ class FlashPhiAttention(torch.nn.Module):
|
|||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -16,7 +16,6 @@ from text_generation_server.layers import (
|
|||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
|
@ -133,20 +132,20 @@ class Qwen2Attention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv_to_cache[:, 0],
|
||||
value=kv_to_cache[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
window_size_left=self.max_past,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -12,7 +12,6 @@ from text_generation_server.layers import (
|
|||
TensorParallelRowLinear,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import FastLayerNorm
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.attention import (
|
||||
|
@ -205,19 +204,19 @@ class FlashRWAttention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
@ -319,19 +318,19 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv[:, :, 0],
|
||||
value=kv[:, :, 1],
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -17,7 +17,6 @@ from text_generation_server.layers import (
|
|||
TensorParallelEmbedding,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.gptq import GPTQWeightsLoader
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
|
@ -289,19 +288,19 @@ class FlashMQAttention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else key_value[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else key_value[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=key_value[:, 0],
|
||||
value=key_value[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
|
@ -38,7 +38,6 @@ from text_generation_server.layers import (
|
|||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
FastRMSNorm,
|
||||
|
@ -238,20 +237,20 @@ class Starcoder2Attention(torch.nn.Module):
|
|||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv_to_cache[:, 0],
|
||||
value=kv_to_cache[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
window_size_left=self.max_past,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache.key,
|
||||
kv_cache.value,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
|
|
Loading…
Reference in New Issue