Unify attention output handling (#2343)
- Always return the hidden states.
- Create the output tensor inside the `attention` and `paged_attention` functions.

This removes the difference in how the output is handled between attention (output parameter) and paged attention (return value). It also removes the assumption that the attention implementation can write to an output tensor (in preparation for FlashInfer).
parent 22fb1be588
commit 47447ef017
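The contract after this change, as a minimal runnable sketch: plain PyTorch stands in for the real flash-attention / vLLM kernels, and the wrapper, shapes, and scale below are illustrative assumptions, not the repository code.

import torch


def attention(q, k, v, softmax_scale):
    # New contract: the wrapper allocates its own output and returns it,
    # instead of writing into a caller-provided `out` tensor.
    out = torch.empty_like(q)
    # Stand-in for the backend kernel: plain scaled dot-product attention.
    scores = torch.softmax((q @ k.transpose(-2, -1)) * softmax_scale, dim=-1)
    out.copy_(scores @ v)
    return out


# Call sites simply take the return value; no pre-allocated output argument.
q = k = v = torch.randn(2, 4, 8)
attn_output = attention(q, k, v, softmax_scale=8**-0.5)

The same contract applies to `paged_attention`, which previously took the output tensor as its first parameter; both paths now look identical from the caller's point of view.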
@@ -34,7 +34,6 @@ def reshape_and_cache(


 def paged_attention(
-    out: torch.Tensor,
     query: torch.Tensor,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
@@ -85,7 +84,7 @@ def paged_attention(
         # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
         if softcap is None:
             softcap = 0.0
-        out2 = flash_attn_2_cuda.varlen_fwd(
+        out = flash_attn_2_cuda.varlen_fwd(
             query,
             key_cache,
             value_cache,
@@ -108,13 +107,15 @@ def paged_attention(
             False, # return softmax
             None, # generator
         )
-        return out2[0]
+        return out[0]
     else:
         if softcap is not None:
             raise RuntimeError("Paged attention doesn't support softcapping")
         input_lengths = seqlen.input_lengths
         from vllm._C import ops

+        out = torch.empty_like(query)
+
         use_v1 = max_s <= 8192 and (
             max_num_partitions == 1 or num_seqs * num_heads > 512
         )
@@ -200,13 +201,13 @@ except ImportError:

 SUPPORTS_WINDOWING = V2

 if V2:

     def attention(
         q,
         k,
         v,
-        out,
         cu_seqlens,
         max_s,
         softmax_scale,
@@ -214,6 +215,7 @@ if V2:
         causal=True,
         softcap=0.0,
     ):
+        out = torch.empty_like(q)
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
         return flash_attn_2_cuda.varlen_fwd(
@@ -238,7 +240,7 @@ if V2:
             softcap,
             False,
             None,
-        )
+        )[0]

 else:

@@ -246,7 +248,6 @@ else:
         q,
         k,
         v,
-        out,
         cu_seqlens,
         max_s,
         softmax_scale,
@@ -286,6 +287,8 @@ else:
             .reshape(original_shape[0], -1, original_shape[2])
         )

+        out = torch.empty_like(q)
+
         return flash_attn_cuda.fwd(
             q,
             k,
@@ -302,4 +305,4 @@ else:
             False,
             0,
             None,
-        )
+        )[0]
@@ -10,13 +10,14 @@ def attention(
     q,
     k,
     v,
-    out,
     cu_seqlens,
     max_s,
     softmax_scale,
     window_size_left=-1,
     causal=True,
 ):
+    out = torch.empty_like(q)
+
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     return ipex.llm.functional.varlen_attention(
         q,
@@ -49,7 +50,6 @@ def reshape_and_cache(


 def paged_attention(
-    out: torch.Tensor,
     query: torch.Tensor,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
@@ -59,6 +59,7 @@ def paged_attention(
     seqlen: Seqlen,
     max_s: int,
 ):
+    out = torch.empty_like(query)
     ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
         out,
         query,
@@ -39,7 +39,6 @@ def reshape_and_cache(


 def paged_attention(
-    out: torch.Tensor,
     query: torch.Tensor,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
@@ -72,6 +71,8 @@ def paged_attention(
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
     input_lengths = input_lengths.input_lengths

+    out = torch.empty_like(query)
+
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
     # V1 to avoid the overhead of reduction. Also, if the number of
@@ -174,7 +175,6 @@ if ENGINE == "ck":
         q,
         k,
         v,
-        out,
         cu_seqlens,
         max_s,
         softmax_scale,
@@ -184,6 +184,8 @@ if ENGINE == "ck":
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")

+        out = torch.empty_like(q)
+
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             q,
@@ -209,13 +211,14 @@ elif ENGINE == "triton":
         q,
         k,
         v,
-        out,
         cu_seqlens,
         max_s,
         softmax_scale,
         window_size_left=-1,
         causal=True,
     ):
+        out = torch.empty_like(q)
+
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         output, _ = triton_attention(
             q,
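The modeling changes below are mechanical: every call site stops pre-allocating `attn_output` and instead takes the tensor returned by `attention` or `paged_attention`. A condensed sketch of the resulting prefill/decode dispatch; the stub kernels and the `attention_forward` helper are hypothetical stand-ins, not code from this repository.

import torch


def attention(q, k, v, cu_seqlen_prefill, max_s, softmax_scale):
    # Hypothetical stub with the new contract: output allocated inside, then returned.
    return torch.empty_like(q)


def paged_attention(query, key_cache, value_cache, max_s, softmax_scale):
    # Hypothetical stub for the decode path; same contract.
    return torch.empty_like(query)


def attention_forward(query, key, value, kv_cache, cu_seqlen_prefill, max_s, softmax_scale):
    # Prefill
    if cu_seqlen_prefill is not None:
        attn_output = attention(query, key, value, cu_seqlen_prefill, max_s, softmax_scale)
    # Decode
    else:
        attn_output = paged_attention(query, kv_cache[0], kv_cache[1], max_s, softmax_scale)
    # The hidden states are always returned; no caller-side torch.empty_like is needed.
    return attn_output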
@@ -291,17 +291,13 @@ class FlashCohereAttention(torch.nn.Module):

         reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 key,
                 value,
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -309,7 +305,6 @@ class FlashCohereAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -330,17 +330,13 @@ class DbrxAttention(torch.nn.Module):

         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -348,7 +344,6 @@ class DbrxAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -358,25 +358,20 @@ class DeepseekV2Attention(torch.nn.Module):

         reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)

-        # Output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 key,
                 value,
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
             )
         # Decode
         else:
-            paged_attention(
-                attn_output,
+            attn_output = paged_attention(
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -231,17 +231,13 @@ class FlashGemma2Attention(torch.nn.Module):

         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -252,7 +248,6 @@ class FlashGemma2Attention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -225,17 +225,13 @@ class FlashGemmaAttention(torch.nn.Module):

         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -244,7 +240,6 @@ class FlashGemmaAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -225,17 +225,13 @@ class FlashGPT2Attention(torch.nn.Module):

         reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 key,
                 value,
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -243,7 +239,6 @@ class FlashGPT2Attention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -213,17 +213,13 @@ class FlashLlamaAttention(torch.nn.Module):

         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -231,7 +227,6 @@ class FlashLlamaAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -212,17 +212,13 @@ class MistralAttention(torch.nn.Module):
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -231,7 +227,6 @@ class MistralAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -269,17 +269,13 @@ class MixtralAttention(torch.nn.Module):
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -288,7 +284,6 @@ class MixtralAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -158,17 +158,13 @@ class FlashNeoxAttention(torch.nn.Module):

         reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(qkv[:, 0])
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -176,7 +172,6 @@ class FlashNeoxAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 qkv[:, 0],
                 kv_cache[0],
                 kv_cache[1],
@@ -188,16 +188,12 @@ class FlashPhiAttention(torch.nn.Module):
         # Reshape key and value and cache
         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -205,7 +201,6 @@ class FlashPhiAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -130,17 +130,13 @@ class Qwen2Attention(torch.nn.Module):
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -149,7 +145,6 @@ class Qwen2Attention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -201,17 +201,13 @@ class FlashRWAttention(torch.nn.Module):

         reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

-        # output
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -219,7 +215,6 @@ class FlashRWAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -324,17 +319,13 @@ class FlashRWLargeAttention(torch.nn.Module):
             slots,
         )

-        # output
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -342,7 +333,6 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -286,17 +286,13 @@ class FlashMQAttention(torch.nn.Module):
             key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
         )

-        # output
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -304,7 +300,6 @@ class FlashMQAttention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],
@@ -235,17 +235,13 @@ class Starcoder2Attention(torch.nn.Module):
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            attn_output = attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
-                attn_output,
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
@@ -254,7 +250,6 @@ class Starcoder2Attention(torch.nn.Module):
         # Decode
         else:
             attn_output = paged_attention(
-                attn_output,
                 query,
                 kv_cache[0],
                 kv_cache[1],