Less intrusive.

Nicolas Patry 2024-05-24 14:15:33 +00:00
parent cacba5f21f
commit 63e72033b7
4 changed files with 65 additions and 139 deletions


@@ -28,7 +28,6 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -155,72 +154,34 @@ class FlashLlamaAttention(torch.nn.Module):
         # output tensor
         attn_output = torch.empty_like(query)
-        if FLASH_DECODING:
-            # Prefill
-            kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 0
-            ]
-            kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 1
-            ]
-            if cu_seqlen_prefill is not None:
-                # flash attention
-                flash_attn.attention(
-                    query,
-                    # torch.select(kv, dim=1, index=0),
-                    # torch.select(kv, dim=1, index=1),
-                    kv_cache[0],
-                    kv_cache[1],
-                    attn_output,
-                    cu_seqlen_prefill,
-                    block_tables,
-                    max_s,
-                    self.softmax_scale,
-                )
-            # Decode
-            else:
-                paged_attention.attention(
-                    attn_output,
-                    query,
-                    kv_cache[0],
-                    kv_cache[1],
-                    self.kv_head_mapping,
-                    self.softmax_scale,
-                    block_tables,
-                    input_lengths,
-                    max_s,
-                )
-        else:
-            paged_attention.reshape_and_cache(
-                kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
-            )
-            # Prefill
-            if cu_seqlen_prefill is not None:
-                # flash attention
-                flash_attn.attention(
-                    query,
-                    torch.select(kv, dim=1, index=0),
-                    torch.select(kv, dim=1, index=1),
-                    attn_output,
-                    cu_seqlen_prefill,
-                    None,
-                    max_s,
-                    self.softmax_scale,
-                )
-            # Decode
-            else:
-                paged_attention.attention(
-                    attn_output,
-                    query,
-                    kv_cache[0],
-                    kv_cache[1],
-                    self.kv_head_mapping,
-                    self.softmax_scale,
-                    block_tables,
-                    input_lengths,
-                    max_s,
-                )
+        paged_attention.reshape_and_cache(
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+        )
+        # Prefill
+        if cu_seqlen_prefill is not None:
+            # flash attention
+            flash_attn.attention(
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
+                attn_output,
+                cu_seqlen_prefill,
+                max_s,
+                self.softmax_scale,
+            )
+        # Decode
+        else:
+            paged_attention.attention(
+                attn_output,
+                query,
+                kv_cache[0],
+                kv_cache[1],
+                self.kv_head_mapping,
+                self.softmax_scale,
+                block_tables,
+                input_lengths,
+                max_s,
+            )
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


@@ -27,7 +27,6 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -217,77 +216,39 @@ class MistralAttention(torch.nn.Module):
         attn_output = torch.empty_like(query)
-        if FLASH_DECODING:
-            # Prefill
-            kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 0
-            ]
-            kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 1
-            ]
-            if cu_seqlen_prefill is not None:
-                # flash attention
-                flash_attn.attention(
-                    query,
-                    # torch.select(kv, dim=1, index=0),
-                    # torch.select(kv, dim=1, index=1),
-                    kv_cache[0],
-                    kv_cache[1],
-                    attn_output,
-                    cu_seqlen_prefill,
-                    block_tables,
-                    max_s,
-                    self.softmax_scale,
-                )
-            # Decode
-            else:
-                paged_attention.attention(
-                    attn_output,
-                    query,
-                    kv_cache[0],
-                    kv_cache[1],
-                    self.kv_head_mapping,
-                    self.softmax_scale,
-                    block_tables,
-                    input_lengths,
-                    max_s,
-                )
-        else:
-            if prefill_cache_indices is not None:
-                kv_to_cache = kv[prefill_cache_indices]
-            else:
-                kv_to_cache = kv
-            paged_attention.reshape_and_cache(
-                kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
-            )
-            # Prefill
-            if cu_seqlen_prefill is not None:
-                # flash attention
-                flash_attn.attention(
-                    query,
-                    torch.select(kv, dim=1, index=0),
-                    torch.select(kv, dim=1, index=1),
-                    attn_output,
-                    cu_seqlen_prefill,
-                    None,
-                    max_s,
-                    self.softmax_scale,
-                )
-            # Decode
-            else:
-                paged_attention.attention(
-                    attn_output,
-                    query,
-                    kv_cache[0],
-                    kv_cache[1],
-                    self.kv_head_mapping,
-                    self.softmax_scale,
-                    block_tables,
-                    input_lengths,
-                    max_s,
-                )
+        if prefill_cache_indices is not None:
+            kv_to_cache = kv[prefill_cache_indices]
+        else:
+            kv_to_cache = kv
+        paged_attention.reshape_and_cache(
+            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
+        )
+        # Prefill
+        if cu_seqlen_prefill is not None:
+            # flash attention
+            flash_attn.attention(
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
+                attn_output,
+                cu_seqlen_prefill,
+                max_s,
+                self.softmax_scale,
+            )
+        # Decode
+        else:
+            paged_attention.attention(
+                attn_output,
+                query,
+                kv_cache[0],
+                kv_cache[1],
+                self.kv_head_mapping,
+                self.softmax_scale,
+                block_tables,
+                input_lengths,
+                max_s,
+            )
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
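
As a cross-check of the control flow the two attention modules above return to (one unconditional KV-cache write, then flash attention for prefill and paged attention for decode), here is a self-contained toy in plain PyTorch. It is only a sketch: the naive softmax attention below (no causal mask, no paging, no GQA head mapping) stands in for flash_attn.attention and paged_attention.attention, and every name in it is illustrative rather than part of the TGI API.

import math

import torch


def naive_attention(q, k, v, softmax_scale):
    # q: [q_len, heads, head_size], k/v: [kv_len, heads, head_size]
    scores = torch.einsum("qhd,khd->hqk", q, k) * softmax_scale
    return torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)


def toy_forward(query, kv, key_cache, value_cache, slots, cu_seqlen_prefill, seq_len):
    # Unconditional cache write: the stand-in for paged_attention.reshape_and_cache.
    key_cache[slots] = kv[:, 0]
    value_cache[slots] = kv[:, 1]
    softmax_scale = 1.0 / math.sqrt(query.shape[-1])
    if cu_seqlen_prefill is not None:
        # Prefill: attend over the prompt's freshly computed K/V.
        return naive_attention(query, kv[:, 0], kv[:, 1], softmax_scale)
    # Decode: attend over everything cached so far for this sequence.
    return naive_attention(query, key_cache[:seq_len], value_cache[:seq_len], softmax_scale)


heads, head_size, num_slots = 2, 8, 16
key_cache = torch.zeros(num_slots, heads, head_size)
value_cache = torch.zeros(num_slots, heads, head_size)

# Prefill a 4-token prompt into slots 0..3.
prompt_q = torch.randn(4, heads, head_size)
prompt_kv = torch.randn(4, 2, heads, head_size)
toy_forward(prompt_q, prompt_kv, key_cache, value_cache,
            slots=torch.arange(4), cu_seqlen_prefill=torch.tensor([0, 4]), seq_len=4)

# Decode one token into slot 4, attending over the 5 cached positions.
step_q = torch.randn(1, heads, head_size)
step_kv = torch.randn(1, 2, heads, head_size)
out = toy_forward(step_q, step_kv, key_cache, value_cache,
                  slots=torch.tensor([4]), cu_seqlen_prefill=None, seq_len=5)
print(out.shape)  # torch.Size([1, 2, 8])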


@@ -134,7 +134,6 @@ elif HAS_FLASH_ATTN_V2_CUDA:
         v,
         out,
         cu_seqlens,
-        block_tables,
         max_s,
         softmax_scale,
         window_size_left=-1,
@@ -150,7 +149,7 @@ elif HAS_FLASH_ATTN_V2_CUDA:
             cu_seqlens,
             cu_seqlens,
             None,
-            block_tables,
+            None,
             None,
             max_s,
             max_s,


@@ -28,9 +28,14 @@ def reshape_and_cache(
             key, value, key_cache, value_cache, slots
         )
     else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
+        if FLASH_DECODING:
+            shape = key_cache.shape
+            key_cache.view(-1, shape[-2], shape[-1])[slots] = key
+            value_cache.view(-1, shape[-2], shape[-1])[slots] = value
+        else:
+            cache_ops.reshape_and_cache(
+                key, value, key_cache, value_cache, slots, "auto", 1.0
+            )
 
 
 def attention(
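
The flattened-view write introduced in the FLASH_DECODING branch above can be checked in isolation. The snippet below is a standalone sketch, not TGI code; it assumes a paged cache laid out as [num_blocks, block_size, num_kv_heads, head_size], so flattening the leading dimensions yields one row per slot, which is what key_cache.view(-1, shape[-2], shape[-1])[slots] = key relies on.

import torch

# Hypothetical cache dimensions, for illustration only.
num_blocks, block_size, num_kv_heads, head_size = 4, 8, 2, 16
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

key = torch.randn(3, num_kv_heads, head_size)  # K vectors for 3 new tokens
slots = torch.tensor([0, 1, 9])                # flat slot index = block * block_size + offset

# Same indexing as the new branch: view the cache as [num_slots, num_kv_heads, head_size]
# and scatter the rows; the view writes through to the original cache tensor.
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key

# Slot 9 corresponds to block 1, offset 1 in the unflattened layout.
assert torch.equal(key_cache[1, 1], key[2])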