Less intrusive.
This commit is contained in:
parent
cacba5f21f
commit
63e72033b7
|
@ -28,7 +28,6 @@ from transformers.activations import ACT2FN
|
|||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
|
@ -155,72 +154,34 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
if FLASH_DECODING:
|
||||
# Prefill
|
||||
kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
|
||||
:, 0
|
||||
]
|
||||
kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
|
||||
:, 1
|
||||
]
|
||||
paged_attention.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
# torch.select(kv, dim=1, index=0),
|
||||
# torch.select(kv, dim=1, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
block_tables,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
else:
|
||||
paged_attention.reshape_and_cache(
|
||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
None,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
|
|
@ -27,7 +27,6 @@ from transformers.configuration_utils import PretrainedConfig
|
|||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import FLASH_DECODING
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
|
@ -217,77 +216,39 @@ class MistralAttention(torch.nn.Module):
|
|||
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
if FLASH_DECODING:
|
||||
# Prefill
|
||||
kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
|
||||
:, 0
|
||||
]
|
||||
kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
|
||||
:, 1
|
||||
]
|
||||
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
# torch.select(kv, dim=1, index=0),
|
||||
# torch.select(kv, dim=1, index=1),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
block_tables,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = kv[prefill_cache_indices]
|
||||
else:
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = kv[prefill_cache_indices]
|
||||
else:
|
||||
kv_to_cache = kv
|
||||
kv_to_cache = kv
|
||||
|
||||
paged_attention.reshape_and_cache(
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
paged_attention.reshape_and_cache(
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
None,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
|
|
@ -134,7 +134,6 @@ elif HAS_FLASH_ATTN_V2_CUDA:
|
|||
v,
|
||||
out,
|
||||
cu_seqlens,
|
||||
block_tables,
|
||||
max_s,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
|
@ -150,7 +149,7 @@ elif HAS_FLASH_ATTN_V2_CUDA:
|
|||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
None,
|
||||
block_tables,
|
||||
None,
|
||||
None,
|
||||
max_s,
|
||||
max_s,
|
||||
|
|
|
@ -28,9 +28,14 @@ def reshape_and_cache(
|
|||
key, value, key_cache, value_cache, slots
|
||||
)
|
||||
else:
|
||||
cache_ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
||||
)
|
||||
if FLASH_DECODING:
|
||||
shape = key_cache.shape
|
||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||
else:
|
||||
cache_ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
||||
)
|
||||
|
||||
|
||||
def attention(
|
||||
|
|
Loading…
Reference in New Issue