Unify attention output handling (#2343)

- Always return the hidden states.
- Create the output tensor inside the `attention` and `paged_attention`
  functions.

This removes the difference between how the output is handled between
attention (output parameter) and paged attention (return value). This
also removes the assumption that the attention implementation can
write to an output tensor (in preparation of FlashInfer).
This commit is contained in:
Daniël de Kok 2024-08-01 17:03:28 +02:00 committed by GitHub
parent 22fb1be588
commit 47447ef017
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 36 additions and 109 deletions

View File

@ -34,7 +34,6 @@ def reshape_and_cache(
def paged_attention( def paged_attention(
out: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
@ -85,7 +84,7 @@ def paged_attention(
# This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied. # This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
if softcap is None: if softcap is None:
softcap = 0.0 softcap = 0.0
out2 = flash_attn_2_cuda.varlen_fwd( out = flash_attn_2_cuda.varlen_fwd(
query, query,
key_cache, key_cache,
value_cache, value_cache,
@ -108,13 +107,15 @@ def paged_attention(
False, # return softmax False, # return softmax
None, # generator None, # generator
) )
return out2[0] return out[0]
else: else:
if softcap is not None: if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping") raise RuntimeError("Paged attention doesn't support softcapping")
input_lengths = seqlen.input_lengths input_lengths = seqlen.input_lengths
from vllm._C import ops from vllm._C import ops
out = torch.empty_like(query)
use_v1 = max_s <= 8192 and ( use_v1 = max_s <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512 max_num_partitions == 1 or num_seqs * num_heads > 512
) )
@ -200,13 +201,13 @@ except ImportError:
SUPPORTS_WINDOWING = V2 SUPPORTS_WINDOWING = V2
if V2: if V2:
def attention( def attention(
q, q,
k, k,
v, v,
out,
cu_seqlens, cu_seqlens,
max_s, max_s,
softmax_scale, softmax_scale,
@ -214,6 +215,7 @@ if V2:
causal=True, causal=True,
softcap=0.0, softcap=0.0,
): ):
out = torch.empty_like(q)
if window_size_left <= 0 and window_size_left != -1: if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1") raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd( return flash_attn_2_cuda.varlen_fwd(
@ -238,7 +240,7 @@ if V2:
softcap, softcap,
False, False,
None, None,
) )[0]
else: else:
@ -246,7 +248,6 @@ else:
q, q,
k, k,
v, v,
out,
cu_seqlens, cu_seqlens,
max_s, max_s,
softmax_scale, softmax_scale,
@ -286,6 +287,8 @@ else:
.reshape(original_shape[0], -1, original_shape[2]) .reshape(original_shape[0], -1, original_shape[2])
) )
out = torch.empty_like(q)
return flash_attn_cuda.fwd( return flash_attn_cuda.fwd(
q, q,
k, k,
@ -302,4 +305,4 @@ else:
False, False,
0, 0,
None, None,
) )[0]

View File

@ -10,13 +10,14 @@ def attention(
q, q,
k, k,
v, v,
out,
cu_seqlens, cu_seqlens,
max_s, max_s,
softmax_scale, softmax_scale,
window_size_left=-1, window_size_left=-1,
causal=True, causal=True,
): ):
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return ipex.llm.functional.varlen_attention( return ipex.llm.functional.varlen_attention(
q, q,
@ -49,7 +50,6 @@ def reshape_and_cache(
def paged_attention( def paged_attention(
out: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
@ -59,6 +59,7 @@ def paged_attention(
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, max_s: int,
): ):
out = torch.empty_like(query)
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out, out,
query, query,

View File

@ -39,7 +39,6 @@ def reshape_and_cache(
def paged_attention( def paged_attention(
out: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
value_cache: torch.Tensor, value_cache: torch.Tensor,
@ -72,6 +71,8 @@ def paged_attention(
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = input_lengths.input_lengths input_lengths = input_lengths.input_lengths
out = torch.empty_like(query)
# NOTE(woosuk): We use a simple heuristic to decide whether to use # NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use # PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of # V1 to avoid the overhead of reduction. Also, if the number of
@ -174,7 +175,6 @@ if ENGINE == "ck":
q, q,
k, k,
v, v,
out,
cu_seqlens, cu_seqlens,
max_s, max_s,
softmax_scale, softmax_scale,
@ -184,6 +184,8 @@ if ENGINE == "ck":
if window_size_left <= 0 and window_size_left != -1: if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1") raise ValueError("`window_size_left` must be > 0 or -1")
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd( return flash_attn_2_cuda.varlen_fwd(
q, q,
@ -209,13 +211,14 @@ elif ENGINE == "triton":
q, q,
k, k,
v, v,
out,
cu_seqlens, cu_seqlens,
max_s, max_s,
softmax_scale, softmax_scale,
window_size_left=-1, window_size_left=-1,
causal=True, causal=True,
): ):
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load. # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
output, _ = triton_attention( output, _ = triton_attention(
q, q,

View File

@ -291,17 +291,13 @@ class FlashCohereAttention(torch.nn.Module):
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
key, key,
value, value,
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -309,7 +305,6 @@ class FlashCohereAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -330,17 +330,13 @@ class DbrxAttention(torch.nn.Module):
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -348,7 +344,6 @@ class DbrxAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -358,25 +358,20 @@ class DeepseekV2Attention(torch.nn.Module):
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# Output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
key, key,
value, value,
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
) )
# Decode # Decode
else: else:
paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -231,17 +231,13 @@ class FlashGemma2Attention(torch.nn.Module):
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -252,7 +248,6 @@ class FlashGemma2Attention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -225,17 +225,13 @@ class FlashGemmaAttention(torch.nn.Module):
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -244,7 +240,6 @@ class FlashGemmaAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -225,17 +225,13 @@ class FlashGPT2Attention(torch.nn.Module):
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
key, key,
value, value,
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -243,7 +239,6 @@ class FlashGPT2Attention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -213,17 +213,13 @@ class FlashLlamaAttention(torch.nn.Module):
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -231,7 +227,6 @@ class FlashLlamaAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -212,17 +212,13 @@ class MistralAttention(torch.nn.Module):
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
) )
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -231,7 +227,6 @@ class MistralAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -269,17 +269,13 @@ class MixtralAttention(torch.nn.Module):
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
) )
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -288,7 +284,6 @@ class MixtralAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -158,17 +158,13 @@ class FlashNeoxAttention(torch.nn.Module):
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(qkv[:, 0])
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
qkv[:, 0], qkv[:, 0],
qkv[:, 1], qkv[:, 1],
qkv[:, 2], qkv[:, 2],
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -176,7 +172,6 @@ class FlashNeoxAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
qkv[:, 0], qkv[:, 0],
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -188,16 +188,12 @@ class FlashPhiAttention(torch.nn.Module):
# Reshape key and value and cache # Reshape key and value and cache
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -205,7 +201,6 @@ class FlashPhiAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -130,17 +130,13 @@ class Qwen2Attention(torch.nn.Module):
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
) )
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -149,7 +145,6 @@ class Qwen2Attention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -201,17 +201,13 @@ class FlashRWAttention(torch.nn.Module):
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -219,7 +215,6 @@ class FlashRWAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],
@ -324,17 +319,13 @@ class FlashRWLargeAttention(torch.nn.Module):
slots, slots,
) )
# output
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1), torch.select(kv, dim=2, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -342,7 +333,6 @@ class FlashRWLargeAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -286,17 +286,13 @@ class FlashMQAttention(torch.nn.Module):
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
) )
# output
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1), torch.select(key_value, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -304,7 +300,6 @@ class FlashMQAttention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],

View File

@ -235,17 +235,13 @@ class Starcoder2Attention(torch.nn.Module):
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
) )
# output tensor
attn_output = torch.empty_like(query)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attention( attn_output = attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
max_s, max_s,
self.softmax_scale, self.softmax_scale,
@ -254,7 +250,6 @@ class Starcoder2Attention(torch.nn.Module):
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
attn_output,
query, query,
kv_cache[0], kv_cache[0],
kv_cache[1], kv_cache[1],