Simplify the `attention` function (#2609)
* Simplify the `attention` function - Use one definition rather than multiple. - Add `key`/`value` arguments, so that we don't need the `PREFILL_IN_KVCACHE` constant. - Make it kwargs-only (to avoid mixing up the various `Tensor` args). * Fixup flashinfer support
This commit is contained in:
parent
5bbe1ce028
commit
59ea38cbca
|
@ -8,7 +8,6 @@ if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||||
if SYSTEM == "cuda":
|
if SYSTEM == "cuda":
|
||||||
from .cuda import (
|
from .cuda import (
|
||||||
PREFILL_IN_KV_CACHE,
|
|
||||||
SUPPORTS_WINDOWING,
|
SUPPORTS_WINDOWING,
|
||||||
attention,
|
attention,
|
||||||
paged_attention,
|
paged_attention,
|
||||||
|
@ -16,7 +15,6 @@ if SYSTEM == "cuda":
|
||||||
)
|
)
|
||||||
elif SYSTEM == "rocm":
|
elif SYSTEM == "rocm":
|
||||||
from .rocm import (
|
from .rocm import (
|
||||||
PREFILL_IN_KV_CACHE,
|
|
||||||
SUPPORTS_WINDOWING,
|
SUPPORTS_WINDOWING,
|
||||||
attention,
|
attention,
|
||||||
paged_attention,
|
paged_attention,
|
||||||
|
@ -24,7 +22,6 @@ elif SYSTEM == "rocm":
|
||||||
)
|
)
|
||||||
elif SYSTEM == "ipex":
|
elif SYSTEM == "ipex":
|
||||||
from .ipex import (
|
from .ipex import (
|
||||||
PREFILL_IN_KV_CACHE,
|
|
||||||
SUPPORTS_WINDOWING,
|
SUPPORTS_WINDOWING,
|
||||||
attention,
|
attention,
|
||||||
paged_attention,
|
paged_attention,
|
||||||
|
@ -40,7 +37,6 @@ __all__ = [
|
||||||
"attention",
|
"attention",
|
||||||
"paged_attention",
|
"paged_attention",
|
||||||
"reshape_and_cache",
|
"reshape_and_cache",
|
||||||
"PREFILL_IN_KV_CACHE",
|
|
||||||
"SUPPORTS_WINDOWING",
|
"SUPPORTS_WINDOWING",
|
||||||
"KVCache",
|
"KVCache",
|
||||||
"Seqlen",
|
"Seqlen",
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
|
from text_generation_server.layers.attention.kv_cache import KVCache
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.models.globals import (
|
from text_generation_server.models.globals import (
|
||||||
ATTENTION,
|
ATTENTION,
|
||||||
|
@ -38,8 +39,7 @@ def reshape_and_cache(
|
||||||
|
|
||||||
def paged_attention(
|
def paged_attention(
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key_cache: torch.Tensor,
|
kv_cache: KVCache,
|
||||||
value_cache: torch.Tensor,
|
|
||||||
kv_head_mapping: torch.Tensor,
|
kv_head_mapping: torch.Tensor,
|
||||||
softmax_scale: float,
|
softmax_scale: float,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
|
@ -80,7 +80,7 @@ def paged_attention(
|
||||||
|
|
||||||
return decode_state.get().forward(
|
return decode_state.get().forward(
|
||||||
query.contiguous(),
|
query.contiguous(),
|
||||||
paged_kv_cache=(key_cache, value_cache),
|
paged_kv_cache=(kv_cache.key, kv_cache.value),
|
||||||
logits_soft_cap=softcap,
|
logits_soft_cap=softcap,
|
||||||
sm_scale=softmax_scale,
|
sm_scale=softmax_scale,
|
||||||
)
|
)
|
||||||
|
@ -98,8 +98,8 @@ def paged_attention(
|
||||||
softcap = 0.0
|
softcap = 0.0
|
||||||
out = flash_attn_2_cuda.varlen_fwd(
|
out = flash_attn_2_cuda.varlen_fwd(
|
||||||
query,
|
query,
|
||||||
key_cache,
|
kv_cache.key,
|
||||||
value_cache,
|
kv_cache.value,
|
||||||
None,
|
None,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
seqlen.cu_seqlen_k,
|
seqlen.cu_seqlen_k,
|
||||||
|
@ -135,8 +135,8 @@ def paged_attention(
|
||||||
ops.paged_attention_v1(
|
ops.paged_attention_v1(
|
||||||
out,
|
out,
|
||||||
query,
|
query,
|
||||||
key_cache,
|
kv_cache.key,
|
||||||
value_cache,
|
kv_cache.value,
|
||||||
kv_head_mapping,
|
kv_head_mapping,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
@ -168,8 +168,8 @@ def paged_attention(
|
||||||
max_logits,
|
max_logits,
|
||||||
tmp_output,
|
tmp_output,
|
||||||
query,
|
query,
|
||||||
key_cache,
|
kv_cache.key,
|
||||||
value_cache,
|
kv_cache.value,
|
||||||
kv_head_mapping,
|
kv_head_mapping,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
@ -216,61 +216,63 @@ except ImportError:
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
if ATTENTION == "flashdecoding" and not V2:
|
||||||
|
raise ValueError("Flash decoding requires Flash Attention V2")
|
||||||
|
|
||||||
SUPPORTS_WINDOWING = V2
|
SUPPORTS_WINDOWING = V2
|
||||||
|
|
||||||
if ATTENTION == "flashinfer":
|
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
q: torch.Tensor,
|
*,
|
||||||
key_cache: torch.Tensor,
|
query: torch.Tensor,
|
||||||
value_cache: torch.Tensor,
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
softmax_scale,
|
softmax_scale: float,
|
||||||
window_size_left=-1,
|
window_size_left: int = -1,
|
||||||
causal=True,
|
causal: bool = True,
|
||||||
softcap=0.0,
|
softcap: Optional[float] = None,
|
||||||
):
|
):
|
||||||
|
if ATTENTION == "flashinfer":
|
||||||
from text_generation_server.layers.attention.flashinfer import (
|
from text_generation_server.layers.attention.flashinfer import (
|
||||||
prefill_with_paged_kv_state,
|
prefill_with_paged_kv_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if softcap is None:
|
||||||
|
softcap = 0.0
|
||||||
|
|
||||||
return prefill_with_paged_kv_state.get().forward(
|
return prefill_with_paged_kv_state.get().forward(
|
||||||
q.contiguous(),
|
query.contiguous(),
|
||||||
causal=causal,
|
causal=causal,
|
||||||
paged_kv_cache=(key_cache, value_cache),
|
paged_kv_cache=(kv_cache.key, kv_cache.value),
|
||||||
logits_soft_cap=softcap,
|
logits_soft_cap=softcap,
|
||||||
sm_scale=softmax_scale,
|
sm_scale=softmax_scale,
|
||||||
window_left=window_size_left,
|
window_left=window_size_left,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif ATTENTION == "flashdecoding":
|
# If we are using flashdecoding or paged, we always use flash-attn for
|
||||||
if V2:
|
# the prefill. We have to branch on whether we use flash-attn v1 or v2.
|
||||||
|
elif V2:
|
||||||
def attention(
|
out = torch.empty_like(query)
|
||||||
q,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
causal=True,
|
|
||||||
softcap=0.0,
|
|
||||||
):
|
|
||||||
out = torch.empty_like(q)
|
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
|
|
||||||
|
if softcap is None:
|
||||||
|
softcap = 0.0
|
||||||
|
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
q,
|
query,
|
||||||
key_cache,
|
# flashdecoding: pass the KV caches, paged: pass the KV.
|
||||||
value_cache,
|
kv_cache.key if ATTENTION == "flashdecoding" else key,
|
||||||
|
kv_cache.value if ATTENTION == "flashdecoding" else value,
|
||||||
out,
|
out,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
seqlen.cu_seqlen_k,
|
seqlen.cu_seqlen_k,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
block_tables,
|
block_tables if ATTENTION == "flashdecoding" else None,
|
||||||
None,
|
None,
|
||||||
seqlen.max_q,
|
seqlen.max_q,
|
||||||
seqlen.max_k,
|
seqlen.max_k,
|
||||||
|
@ -286,58 +288,44 @@ elif ATTENTION == "flashdecoding":
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def attention(
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
softmax_scale: float,
|
|
||||||
window_size_left: int = -1,
|
|
||||||
causal: bool = True,
|
|
||||||
softcap=None,
|
|
||||||
):
|
|
||||||
if window_size_left != -1:
|
if window_size_left != -1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"window_size_left is only available with flash attn v2"
|
"window_size_left is only available with flash attn v2"
|
||||||
)
|
)
|
||||||
if softcap is not None:
|
if softcap is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError("softcap is not available in flash attn v1")
|
||||||
"softcap is only available with flash attn v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||||
if k.shape[1] != q.shape[1]:
|
if key.shape[1] != query.shape[1]:
|
||||||
# MQA expand
|
# MQA expand
|
||||||
if k.shape[1] == 1:
|
if key.shape[1] == 1:
|
||||||
k = k.expand(-1, q.shape[1], -1)
|
key = key.expand(-1, query.shape[1], -1)
|
||||||
# Grouped attention reshape
|
# Grouped attention reshape
|
||||||
else:
|
else:
|
||||||
original_shape = k.shape
|
original_shape = key.shape
|
||||||
k = (
|
key = (
|
||||||
k.unsqueeze(2)
|
key.unsqueeze(2)
|
||||||
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
.expand(-1, -1, query.shape[1] // key.shape[1], -1)
|
||||||
.reshape(original_shape[0], -1, original_shape[2])
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
)
|
)
|
||||||
if v.shape[1] != q.shape[1]:
|
if value.shape[1] != query.shape[1]:
|
||||||
# MQA expand
|
# MQA expand
|
||||||
if v.shape[1] == 1:
|
if value.shape[1] == 1:
|
||||||
v = v.expand(-1, q.shape[1], -1)
|
value = value.expand(-1, query.shape[1], -1)
|
||||||
# Grouped attention reshape
|
# Grouped attention reshape
|
||||||
else:
|
else:
|
||||||
original_shape = v.shape
|
original_shape = value.shape
|
||||||
v = (
|
value = (
|
||||||
v.unsqueeze(2)
|
value.unsqueeze(2)
|
||||||
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
.expand(-1, -1, query.shape[1] // value.shape[1], -1)
|
||||||
.reshape(original_shape[0], -1, original_shape[2])
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
)
|
)
|
||||||
|
|
||||||
out = torch.empty_like(q)
|
out = torch.empty_like(query)
|
||||||
flash_attn_cuda.fwd(
|
flash_attn_cuda.fwd(
|
||||||
q,
|
query,
|
||||||
k,
|
key,
|
||||||
v,
|
value,
|
||||||
out,
|
out,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
|
@ -353,126 +341,8 @@ elif ATTENTION == "flashdecoding":
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
elif ATTENTION == "paged":
|
|
||||||
if V2:
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
causal=True,
|
|
||||||
softcap=0.0,
|
|
||||||
):
|
|
||||||
out = torch.empty_like(q)
|
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
|
||||||
q,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
out,
|
|
||||||
seqlen.cu_seqlen_q,
|
|
||||||
seqlen.cu_seqlen_k,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None, # block_tables,
|
|
||||||
None,
|
|
||||||
seqlen.max_q,
|
|
||||||
seqlen.max_k,
|
|
||||||
0.0,
|
|
||||||
softmax_scale,
|
|
||||||
False,
|
|
||||||
causal,
|
|
||||||
window_size_left,
|
|
||||||
0,
|
|
||||||
softcap,
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
)[0]
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
softmax_scale: float,
|
|
||||||
window_size_left: int = -1,
|
|
||||||
causal: bool = True,
|
|
||||||
softcap=None,
|
|
||||||
):
|
|
||||||
if window_size_left != -1:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"window_size_left is only available with flash attn v2"
|
|
||||||
)
|
|
||||||
if softcap is not None:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"softcap is only available with flash attn v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
|
||||||
if k.shape[1] != q.shape[1]:
|
|
||||||
# MQA expand
|
|
||||||
if k.shape[1] == 1:
|
|
||||||
k = k.expand(-1, q.shape[1], -1)
|
|
||||||
# Grouped attention reshape
|
|
||||||
else:
|
|
||||||
original_shape = k.shape
|
|
||||||
k = (
|
|
||||||
k.unsqueeze(2)
|
|
||||||
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
|
||||||
.reshape(original_shape[0], -1, original_shape[2])
|
|
||||||
)
|
|
||||||
if v.shape[1] != q.shape[1]:
|
|
||||||
# MQA expand
|
|
||||||
if v.shape[1] == 1:
|
|
||||||
v = v.expand(-1, q.shape[1], -1)
|
|
||||||
# Grouped attention reshape
|
|
||||||
else:
|
|
||||||
original_shape = v.shape
|
|
||||||
v = (
|
|
||||||
v.unsqueeze(2)
|
|
||||||
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
|
||||||
.reshape(original_shape[0], -1, original_shape[2])
|
|
||||||
)
|
|
||||||
|
|
||||||
out = torch.empty_like(q)
|
|
||||||
flash_attn_cuda.fwd(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
seqlen.cu_seqlen_q,
|
|
||||||
seqlen.cu_seqlen_q,
|
|
||||||
seqlen.max_q,
|
|
||||||
seqlen.max_k,
|
|
||||||
0.0,
|
|
||||||
softmax_scale,
|
|
||||||
False,
|
|
||||||
causal,
|
|
||||||
False,
|
|
||||||
0,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
return out
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Unknwon attention {ATTENTION}")
|
|
||||||
|
|
||||||
|
|
||||||
# Prefill in the cache with every kind of attention, unless we
|
|
||||||
# have a configuration that requires flash-attention v1, which
|
|
||||||
# does not support block tables.
|
|
||||||
PREFILL_IN_KV_CACHE = ATTENTION == "flashinfer" or (ATTENTION == "flashdecoding" and V2)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PREFILL_IN_KV_CACHE",
|
|
||||||
"SUPPORTS_WINDOWING",
|
"SUPPORTS_WINDOWING",
|
||||||
"attention",
|
"attention",
|
||||||
"paged_attention",
|
"paged_attention",
|
||||||
|
|
|
@ -1,31 +1,36 @@
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
import torch
|
import torch
|
||||||
|
from text_generation_server.layers.attention.kv_cache import KVCache
|
||||||
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
|
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
|
||||||
from text_generation_server.layers.attention import Seqlen
|
from text_generation_server.layers.attention import Seqlen
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
SUPPORTS_WINDOWING = False
|
SUPPORTS_WINDOWING = False
|
||||||
PREFILL_IN_KV_CACHE = False
|
|
||||||
|
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
q: torch.Tensor,
|
*,
|
||||||
key_cache: torch.Tensor,
|
query: torch.Tensor,
|
||||||
value_cache: torch.Tensor,
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
softmax_scale,
|
softmax_scale: float,
|
||||||
window_size_left=-1,
|
window_size_left: int = -1,
|
||||||
causal=True,
|
causal: bool = True,
|
||||||
softcap: Optional[float] = None,
|
softcap: Optional[float] = None,
|
||||||
):
|
):
|
||||||
out = torch.empty_like(q)
|
if softcap is not None:
|
||||||
|
raise NotImplementedError("softcap is not available in IPEX")
|
||||||
|
|
||||||
|
out = torch.empty_like(query)
|
||||||
|
|
||||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||||
ipex.llm.functional.varlen_attention(
|
ipex.llm.functional.varlen_attention(
|
||||||
q.contiguous() if q.device.type == "xpu" else q,
|
query.contiguous() if query.device.type == "xpu" else query,
|
||||||
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
|
key.contiguous() if key.device.type == "xpu" else key,
|
||||||
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
|
value.contiguous() if value.device.type == "xpu" else value,
|
||||||
out,
|
out,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
|
@ -56,8 +61,7 @@ def reshape_and_cache(
|
||||||
|
|
||||||
def paged_attention(
|
def paged_attention(
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key_cache: torch.Tensor,
|
kv_cache: KVCache,
|
||||||
value_cache: torch.Tensor,
|
|
||||||
kv_head_mapping: torch.Tensor,
|
kv_head_mapping: torch.Tensor,
|
||||||
softmax_scale: float,
|
softmax_scale: float,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
|
@ -65,13 +69,16 @@ def paged_attention(
|
||||||
max_s: int,
|
max_s: int,
|
||||||
softcap: Optional[float] = None,
|
softcap: Optional[float] = None,
|
||||||
):
|
):
|
||||||
|
if softcap is not None:
|
||||||
|
raise NotImplementedError("softcap is not available in IPEX")
|
||||||
|
|
||||||
out = torch.empty_like(query)
|
out = torch.empty_like(query)
|
||||||
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
input_lengths = seqlen.input_lengths + seqlen.cache_lengths
|
||||||
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||||
out,
|
out,
|
||||||
query,
|
query,
|
||||||
key_cache,
|
kv_cache.key,
|
||||||
value_cache,
|
kv_cache.value,
|
||||||
kv_head_mapping,
|
kv_head_mapping,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
@ -84,7 +91,6 @@ def paged_attention(
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PREFILL_IN_KV_CACHE",
|
|
||||||
"SUPPORTS_WINDOWING",
|
"SUPPORTS_WINDOWING",
|
||||||
"attention",
|
"attention",
|
||||||
"paged_attention",
|
"paged_attention",
|
||||||
|
|
|
@ -3,7 +3,6 @@ from typing import Tuple
|
||||||
import torch
|
import torch
|
||||||
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
|
from text_generation_server.models.globals import ATTENTION, BLOCK_SIZE
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.layers.attention import reshape_and_cache
|
|
||||||
|
|
||||||
|
|
||||||
class KVCache:
|
class KVCache:
|
||||||
|
@ -116,4 +115,6 @@ class KVCache:
|
||||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||||
else:
|
else:
|
||||||
|
from text_generation_server.layers.attention import reshape_and_cache
|
||||||
|
|
||||||
reshape_and_cache(key, value, key_cache, value_cache, slots)
|
reshape_and_cache(key, value, key_cache, value_cache, slots)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import torch
|
import torch
|
||||||
|
from text_generation_server.layers.attention.kv_cache import KVCache
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.models.globals import ATTENTION
|
from text_generation_server.models.globals import ATTENTION
|
||||||
from text_generation_server.layers.attention import Seqlen
|
from text_generation_server.layers.attention import Seqlen
|
||||||
|
@ -16,8 +17,6 @@ _PARTITION_SIZE_CUSTOM = 256
|
||||||
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
|
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
|
||||||
ENGINE = "triton" if use_triton else "ck"
|
ENGINE = "triton" if use_triton else "ck"
|
||||||
|
|
||||||
PREFILL_IN_KV_CACHE = False
|
|
||||||
|
|
||||||
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
|
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
|
||||||
try:
|
try:
|
||||||
if use_rocm_custom_paged_attn:
|
if use_rocm_custom_paged_attn:
|
||||||
|
@ -54,8 +53,7 @@ def reshape_and_cache(
|
||||||
|
|
||||||
def paged_attention(
|
def paged_attention(
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key_cache: torch.Tensor,
|
kv_cache: KVCache,
|
||||||
value_cache: torch.Tensor,
|
|
||||||
kv_head_mapping: torch.Tensor,
|
kv_head_mapping: torch.Tensor,
|
||||||
softmax_scale: float,
|
softmax_scale: float,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
|
@ -84,10 +82,10 @@ def paged_attention(
|
||||||
raise RuntimeError("Paged attention doesn't support softcapping")
|
raise RuntimeError("Paged attention doesn't support softcapping")
|
||||||
|
|
||||||
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
||||||
block_size = value_cache.shape[3]
|
block_size = kv_cache.value.shape[3]
|
||||||
num_seqs, num_heads, head_size = query.shape
|
num_seqs, num_heads, head_size = query.shape
|
||||||
|
|
||||||
num_kv_heads = key_cache.shape[1]
|
num_kv_heads = kv_cache.key.shape[1]
|
||||||
gqa_ratio = num_heads // num_kv_heads
|
gqa_ratio = num_heads // num_kv_heads
|
||||||
use_custom = (
|
use_custom = (
|
||||||
use_rocm_custom_paged_attn
|
use_rocm_custom_paged_attn
|
||||||
|
@ -124,8 +122,8 @@ def paged_attention(
|
||||||
ops.paged_attention_v1(
|
ops.paged_attention_v1(
|
||||||
out,
|
out,
|
||||||
query,
|
query,
|
||||||
key_cache,
|
kv_cache.key,
|
||||||
value_cache,
|
kv_cache.value,
|
||||||
kv_head_mapping,
|
kv_head_mapping,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
@ -158,8 +156,8 @@ def paged_attention(
|
||||||
max_logits,
|
max_logits,
|
||||||
tmp_output,
|
tmp_output,
|
||||||
query,
|
query,
|
||||||
key_cache,
|
kv_cache.key,
|
||||||
value_cache,
|
kv_cache.value,
|
||||||
kv_head_mapping,
|
kv_head_mapping,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
@ -177,8 +175,8 @@ def paged_attention(
|
||||||
max_logits,
|
max_logits,
|
||||||
tmp_output,
|
tmp_output,
|
||||||
query,
|
query,
|
||||||
key_cache,
|
kv_cache.key,
|
||||||
value_cache,
|
kv_cache.value,
|
||||||
num_kv_heads,
|
num_kv_heads,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
@ -227,29 +225,35 @@ if ENGINE != "triton":
|
||||||
|
|
||||||
|
|
||||||
SUPPORTS_WINDOWING = False
|
SUPPORTS_WINDOWING = False
|
||||||
if ENGINE == "ck":
|
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
q,
|
*,
|
||||||
key_cache: torch.Tensor,
|
query: torch.Tensor,
|
||||||
value_cache: torch.Tensor,
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: KVCache,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
softmax_scale: float,
|
softmax_scale: float,
|
||||||
window_size_left: int = -1,
|
window_size_left: int = -1,
|
||||||
causal: bool = True,
|
causal: bool = True,
|
||||||
softcap: float = 0.0,
|
softcap: Optional[float] = None,
|
||||||
):
|
):
|
||||||
|
if ENGINE == "ck":
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
|
|
||||||
out = torch.empty_like(q)
|
out = torch.empty_like(query)
|
||||||
|
|
||||||
|
if softcap is None:
|
||||||
|
softcap = 0.0
|
||||||
|
|
||||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
q,
|
query,
|
||||||
key_cache,
|
key,
|
||||||
value_cache,
|
value,
|
||||||
out,
|
out,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
|
@ -273,27 +277,16 @@ if ENGINE == "ck":
|
||||||
elif ENGINE == "triton":
|
elif ENGINE == "triton":
|
||||||
from .flash_attn_triton import triton_attention
|
from .flash_attn_triton import triton_attention
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
seqlen: Seqlen,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
softmax_scale: float,
|
|
||||||
window_size_left: int = -1,
|
|
||||||
causal: bool = True,
|
|
||||||
softcap: Optional[float] = None,
|
|
||||||
):
|
|
||||||
if softcap is not None:
|
if softcap is not None:
|
||||||
raise NotImplementedError("softcap is only available with CK flash attn")
|
raise NotImplementedError("softcap is only available with CK flash attn")
|
||||||
|
|
||||||
out = torch.empty_like(q)
|
out = torch.empty_like(query)
|
||||||
|
|
||||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||||
output, _ = triton_attention(
|
output, _ = triton_attention(
|
||||||
q,
|
query,
|
||||||
key_cache,
|
key,
|
||||||
value_cache,
|
value,
|
||||||
out,
|
out,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
seqlen.cu_seqlen_q,
|
seqlen.cu_seqlen_q,
|
||||||
|
@ -307,8 +300,8 @@ elif ENGINE == "triton":
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Unknown attention engine {ENGINE}")
|
raise RuntimeError(f"Unknown attention engine {ENGINE}")
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PREFILL_IN_KV_CACHE",
|
|
||||||
"SUPPORTS_WINDOWING",
|
"SUPPORTS_WINDOWING",
|
||||||
"attention",
|
"attention",
|
||||||
"paged_attention",
|
"paged_attention",
|
||||||
|
|
|
@ -38,7 +38,6 @@ from text_generation_server.layers import (
|
||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastLayerNorm,
|
FastLayerNorm,
|
||||||
)
|
)
|
||||||
|
@ -296,19 +295,19 @@ class FlashCohereAttention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else key,
|
key=key,
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else value,
|
value=value,
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
|
@ -29,7 +29,6 @@ from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
Seqlen,
|
Seqlen,
|
||||||
PREFILL_IN_KV_CACHE,
|
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
FastLinear,
|
FastLinear,
|
||||||
|
@ -335,19 +334,19 @@ class DbrxAttention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
|
@ -34,7 +34,6 @@ from text_generation_server.layers.attention import (
|
||||||
attention,
|
attention,
|
||||||
paged_attention,
|
paged_attention,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
||||||
|
@ -326,19 +325,19 @@ class DeepseekV2Attention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else key,
|
key=key,
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else value,
|
value=value,
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
|
@ -39,7 +39,6 @@ from text_generation_server.layers import (
|
||||||
TensorParallelMultiAdapterLinear,
|
TensorParallelMultiAdapterLinear,
|
||||||
TensorParallelAdapterRowLinear,
|
TensorParallelAdapterRowLinear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastRMSNorm,
|
FastRMSNorm,
|
||||||
|
@ -258,13 +257,13 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
causal=self.causal,
|
softmax_scale=self.softmax_scale,
|
||||||
window_size_left=self.window_size,
|
window_size_left=self.window_size,
|
||||||
softcap=self.softcap,
|
softcap=self.softcap,
|
||||||
)
|
)
|
||||||
|
@ -272,8 +271,7 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
|
@ -29,7 +29,6 @@ from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
Seqlen,
|
Seqlen,
|
||||||
PREFILL_IN_KV_CACHE,
|
|
||||||
)
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
|
@ -229,20 +228,20 @@ class FlashGemmaAttention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
causal=self.causal,
|
causal=self.causal,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
|
@ -24,7 +24,6 @@ import torch.distributed
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
paged_attention,
|
paged_attention,
|
||||||
attention,
|
attention,
|
||||||
|
@ -229,19 +228,19 @@ class FlashGPT2Attention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else key,
|
key=key,
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else value,
|
value=value,
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
|
@ -37,7 +37,6 @@ from text_generation_server.layers import (
|
||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.rotary import (
|
from text_generation_server.layers.rotary import (
|
||||||
PositionRotaryEmbedding,
|
PositionRotaryEmbedding,
|
||||||
)
|
)
|
||||||
|
@ -191,19 +190,19 @@ class FlashGPTJAttention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else key,
|
key=key,
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else value,
|
value=value,
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
|
@ -27,7 +27,7 @@ import torch.distributed
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
|
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE, KVCache
|
from text_generation_server.layers.attention import KVCache
|
||||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
|
@ -227,19 +227,19 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
|
@ -40,7 +40,6 @@ from text_generation_server.layers import (
|
||||||
TensorParallelMultiAdapterLinear,
|
TensorParallelMultiAdapterLinear,
|
||||||
TensorParallelAdapterRowLinear,
|
TensorParallelAdapterRowLinear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastRMSNorm,
|
FastRMSNorm,
|
||||||
|
@ -215,20 +214,20 @@ class MistralAttention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
key=kv_to_cache[:, 0],
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
value=kv_to_cache[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
window_size_left=self.max_past,
|
window_size_left=self.max_past,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
|
@ -38,7 +38,6 @@ from text_generation_server.layers.attention import (
|
||||||
attention,
|
attention,
|
||||||
paged_attention,
|
paged_attention,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
|
@ -263,20 +262,20 @@ class MixtralAttention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
key=kv_to_cache[:, 0],
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
value=kv_to_cache[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
window_size_left=self.max_past,
|
window_size_left=self.max_past,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
|
@ -38,7 +38,6 @@ from text_generation_server.layers import (
|
||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastLayerNorm,
|
FastLayerNorm,
|
||||||
)
|
)
|
||||||
|
@ -170,19 +169,19 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
qkv[:, 0],
|
query=qkv[:, 0],
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else qkv[:, 1],
|
key=qkv[:, 1],
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else qkv[:, 2],
|
value=qkv[:, 2],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
qkv[:, 0],
|
qkv[:, 0],
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
|
@ -18,7 +18,6 @@ from text_generation_server.layers import (
|
||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastLayerNorm,
|
FastLayerNorm,
|
||||||
)
|
)
|
||||||
|
@ -192,19 +191,19 @@ class FlashPhiAttention(torch.nn.Module):
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
|
@ -16,7 +16,6 @@ from text_generation_server.layers import (
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastRMSNorm,
|
FastRMSNorm,
|
||||||
|
@ -133,20 +132,20 @@ class Qwen2Attention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
key=kv_to_cache[:, 0],
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
value=kv_to_cache[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
window_size_left=self.max_past,
|
window_size_left=self.max_past,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
|
@ -12,7 +12,6 @@ from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.layernorm import FastLayerNorm
|
from text_generation_server.layers.layernorm import FastLayerNorm
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
|
@ -205,19 +204,19 @@ class FlashRWAttention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, 0],
|
key=kv[:, 0],
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, 1],
|
value=kv[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
@ -319,19 +318,19 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
|
key=kv[:, :, 0],
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
|
value=kv[:, :, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
|
@ -17,7 +17,6 @@ from text_generation_server.layers import (
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.gptq import GPTQWeightsLoader
|
from text_generation_server.layers.gptq import GPTQWeightsLoader
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastLayerNorm,
|
FastLayerNorm,
|
||||||
|
@ -289,19 +288,19 @@ class FlashMQAttention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else key_value[:, 0],
|
key=key_value[:, 0],
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else key_value[:, 1],
|
value=key_value[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
|
@ -38,7 +38,6 @@ from text_generation_server.layers import (
|
||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
FastLayerNorm,
|
FastLayerNorm,
|
||||||
FastRMSNorm,
|
FastRMSNorm,
|
||||||
|
@ -238,20 +237,20 @@ class Starcoder2Attention(torch.nn.Module):
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attn_output = attention(
|
attn_output = attention(
|
||||||
query,
|
query=query,
|
||||||
kv_cache.key if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
key=kv_to_cache[:, 0],
|
||||||
kv_cache.value if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
value=kv_to_cache[:, 1],
|
||||||
seqlen,
|
kv_cache=kv_cache,
|
||||||
block_tables,
|
seqlen=seqlen,
|
||||||
self.softmax_scale,
|
block_tables=block_tables,
|
||||||
|
softmax_scale=self.softmax_scale,
|
||||||
window_size_left=self.max_past,
|
window_size_left=self.max_past,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
attn_output = paged_attention(
|
attn_output = paged_attention(
|
||||||
query,
|
query,
|
||||||
kv_cache.key,
|
kv_cache,
|
||||||
kv_cache.value,
|
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
|
|
Loading…
Reference in New Issue