Less intrusive.

Nicolas Patry 2024-05-24 14:15:33 +00:00
parent cacba5f21f
commit 63e72033b7
4 changed files with 65 additions and 139 deletions

View File

@@ -28,7 +28,6 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -155,72 +154,34 @@ class FlashLlamaAttention(torch.nn.Module):
         # output tensor
         attn_output = torch.empty_like(query)
-        if FLASH_DECODING:
-            # Prefill
-            kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 0
-            ]
-            kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 1
-            ]
+        paged_attention.reshape_and_cache(
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+        )
-            if cu_seqlen_prefill is not None:
-                # flash attention
-                flash_attn.attention(
-                    query,
-                    # torch.select(kv, dim=1, index=0),
-                    # torch.select(kv, dim=1, index=1),
-                    kv_cache[0],
-                    kv_cache[1],
-                    attn_output,
-                    cu_seqlen_prefill,
-                    block_tables,
-                    max_s,
-                    self.softmax_scale,
-                )
-            # Decode
-            else:
-                paged_attention.attention(
-                    attn_output,
-                    query,
-                    kv_cache[0],
-                    kv_cache[1],
-                    self.kv_head_mapping,
-                    self.softmax_scale,
-                    block_tables,
-                    input_lengths,
-                    max_s,
-                )
-        else:
-            paged_attention.reshape_and_cache(
-                kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-            )
+        if cu_seqlen_prefill is not None:
+            # flash attention
+            flash_attn.attention(
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
+                attn_output,
+                cu_seqlen_prefill,
+                max_s,
+                self.softmax_scale,
+            )
+        # Decode
+        else:
+            paged_attention.attention(
+                attn_output,
+                query,
+                kv_cache[0],
+                kv_cache[1],
+                self.kv_head_mapping,
+                self.softmax_scale,
+                block_tables,
+                input_lengths,
+                max_s,
+            )
-            # Prefill
-            if cu_seqlen_prefill is not None:
-                # flash attention
-                flash_attn.attention(
-                    query,
-                    torch.select(kv, dim=1, index=0),
-                    torch.select(kv, dim=1, index=1),
-                    attn_output,
-                    cu_seqlen_prefill,
-                    None,
-                    max_s,
-                    self.softmax_scale,
-                )
-            # Decode
-            else:
-                paged_attention.attention(
-                    attn_output,
-                    query,
-                    kv_cache[0],
-                    kv_cache[1],
-                    self.kv_head_mapping,
-                    self.softmax_scale,
-                    block_tables,
-                    input_lengths,
-                    max_s,
-                )
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
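The FLASH_DECODING branch removed here wrote the fresh key/value rows into a flattened view of the cache, indexed by slot; that write now lives behind paged_attention.reshape_and_cache, so the model makes a single call either way. Below is a minimal, self-contained sketch of that slot-indexed write; the cache layout, shapes and slot values are illustrative assumptions, not TGI code.

import torch

# Assumed contiguous cache layout: [num_blocks, block_size, kv_heads, head_size]
num_blocks, block_size, kv_heads, head_size = 4, 16, 2, 8
key_cache = torch.zeros(num_blocks, block_size, kv_heads, head_size)

# Three new tokens land in arbitrary slots; slot = block_id * block_size + offset
key = torch.randn(3, kv_heads, head_size)
slots = torch.tensor([5, 17, 40])

# Same pattern as kv_cache[0].view(-1, heads, head_size)[slots] = kv[:, 0] above:
# flatten the block dimensions and scatter the rows by their slot index.
key_cache.view(-1, kv_heads, head_size)[slots] = key

# A slot maps back to (block, row) = (slot // block_size, slot % block_size).
assert torch.equal(key_cache[0, 5], key[0])
assert torch.equal(key_cache[1, 1], key[1])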

View File

@@ -27,7 +27,6 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -217,77 +216,39 @@ class MistralAttention(torch.nn.Module):
         attn_output = torch.empty_like(query)
-        if FLASH_DECODING:
-            # Prefill
-            kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 0
-            ]
-            kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 1
-            ]
-            if cu_seqlen_prefill is not None:
-                # flash attention
-                flash_attn.attention(
-                    query,
-                    # torch.select(kv, dim=1, index=0),
-                    # torch.select(kv, dim=1, index=1),
-                    kv_cache[0],
-                    kv_cache[1],
-                    attn_output,
-                    cu_seqlen_prefill,
-                    block_tables,
-                    max_s,
-                    self.softmax_scale,
-                )
-            # Decode
-            else:
-                paged_attention.attention(
-                    attn_output,
-                    query,
-                    kv_cache[0],
-                    kv_cache[1],
-                    self.kv_head_mapping,
-                    self.softmax_scale,
-                    block_tables,
-                    input_lengths,
-                    max_s,
-                )
-        else:
-            if prefill_cache_indices is not None:
-                kv_to_cache = kv[prefill_cache_indices]
-            else:
-                kv_to_cache = kv
+        if prefill_cache_indices is not None:
+            kv_to_cache = kv[prefill_cache_indices]
+        else:
+            kv_to_cache = kv
-            paged_attention.reshape_and_cache(
-                kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
-            )
+        paged_attention.reshape_and_cache(
+            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
+        )
+        # Prefill
+        if cu_seqlen_prefill is not None:
+            # flash attention
+            flash_attn.attention(
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
+                attn_output,
+                cu_seqlen_prefill,
+                max_s,
+                self.softmax_scale,
+            )
+        # Decode
+        else:
+            paged_attention.attention(
+                attn_output,
+                query,
+                kv_cache[0],
+                kv_cache[1],
+                self.kv_head_mapping,
+                self.softmax_scale,
+                block_tables,
+                input_lengths,
+                max_s,
+            )
-            # Prefill
-            if cu_seqlen_prefill is not None:
-                # flash attention
-                flash_attn.attention(
-                    query,
-                    torch.select(kv, dim=1, index=0),
-                    torch.select(kv, dim=1, index=1),
-                    attn_output,
-                    cu_seqlen_prefill,
-                    None,
-                    max_s,
-                    self.softmax_scale,
-                )
-            # Decode
-            else:
-                paged_attention.attention(
-                    attn_output,
-                    query,
-                    kv_cache[0],
-                    kv_cache[1],
-                    self.kv_head_mapping,
-                    self.softmax_scale,
-                    block_tables,
-                    input_lengths,
-                    max_s,
-                )
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
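The Mistral variant keeps one extra step that the diff above preserves: with a sliding window, only the positions picked out by prefill_cache_indices are written to the cache. A small illustrative sketch of that selection; the window size and tensor shapes are assumptions for demonstration only.

import torch

seq_len, kv_heads, head_size = 6, 2, 4
# kv packs keys and values along dim=1: [tokens, 2 (k/v), kv_heads, head_size]
kv = torch.randn(seq_len, 2, kv_heads, head_size)

# With a window of 4, only the last 4 prefill positions need to be cached.
prefill_cache_indices = torch.tensor([2, 3, 4, 5])

kv_to_cache = kv[prefill_cache_indices] if prefill_cache_indices is not None else kv

# kv_to_cache[:, 0] (keys) and kv_to_cache[:, 1] (values) are what the model
# hands to paged_attention.reshape_and_cache in the new code path.
print(kv_to_cache.shape)  # torch.Size([4, 2, 2, 4])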

View File

@@ -134,7 +134,6 @@ elif HAS_FLASH_ATTN_V2_CUDA:
         v,
         out,
         cu_seqlens,
-        block_tables,
         max_s,
         softmax_scale,
         window_size_left=-1,
@@ -150,7 +149,7 @@ elif HAS_FLASH_ATTN_V2_CUDA:
             cu_seqlens,
             cu_seqlens,
             None,
-            block_tables,
+            None,
             None,
             max_s,
             max_s,
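With block_tables dropped from the wrapper, the underlying variable-length kernel again receives None for block tables and is driven only by the cumulative sequence lengths (note cu_seqlens is passed twice, once for the queries and once for the keys). A hedged sketch of how such a cumulative-lengths tensor is typically built for a ragged prefill batch; the lengths are made up for illustration.

import torch

# Three prefill requests of different lengths packed into one batch.
input_lengths = torch.tensor([5, 3, 7], dtype=torch.int32)

cu_seqlens = torch.zeros(input_lengths.numel() + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(input_lengths, dim=0)
print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)

# query/key/value are packed as [total_tokens, heads, head_size]; the kernel
# reads sequence i as tokens cu_seqlens[i]:cu_seqlens[i + 1].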

View File

@@ -28,9 +28,14 @@ def reshape_and_cache(
             key, value, key_cache, value_cache, slots
         )
     else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
+        if FLASH_DECODING:
+            shape = key_cache.shape
+            key_cache.view(-1, shape[-2], shape[-1])[slots] = key
+            value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+        else:
+            cache_ops.reshape_and_cache(
+                key, value, key_cache, value_cache, slots, "auto", 1.0
+            )


 def attention(
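After this change the FLASH_DECODING check sits in exactly one place: callers always go through reshape_and_cache, and the helper picks between the contiguous slot write and the vLLM cache kernel. A self-contained sketch of that dispatch, with the flag inlined and the kernel path stubbed out (these are stand-ins for the real module, not TGI code):

import torch

FLASH_DECODING = True  # in TGI this comes from text_generation_server.models.globals


def reshape_and_cache_sketch(key, value, key_cache, value_cache, slots):
    if FLASH_DECODING:
        # Contiguous layout: flatten the block dims and write rows at their slot.
        shape = key_cache.shape
        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
    else:
        # The real helper delegates to the vLLM kernel here:
        # cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
        raise NotImplementedError("kernel path not reproduced in this sketch")


key_cache = torch.zeros(2, 16, 2, 8)      # [blocks, block_size, kv_heads, head_size]
value_cache = torch.zeros_like(key_cache)
key = torch.randn(1, 2, 8)
value = torch.randn(1, 2, 8)
reshape_and_cache_sketch(key, value, key_cache, value_cache, torch.tensor([3]))
assert torch.equal(key_cache[0, 3], key[0])  # slot 3 -> block 0, row 3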