Less intrusive.

Nicolas Patry 2024-05-24 14:15:33 +00:00
parent cacba5f21f
commit 63e72033b7
4 changed files with 65 additions and 139 deletions

View File

@@ -28,7 +28,6 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -155,47 +154,10 @@ class FlashLlamaAttention(torch.nn.Module):
         # output tensor
         attn_output = torch.empty_like(query)
-        if FLASH_DECODING:
-            # Prefill
-            kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 0
-            ]
-            kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 1
-            ]
-            if cu_seqlen_prefill is not None:
-                # flash attention
-                flash_attn.attention(
-                    query,
-                    # torch.select(kv, dim=1, index=0),
-                    # torch.select(kv, dim=1, index=1),
-                    kv_cache[0],
-                    kv_cache[1],
-                    attn_output,
-                    cu_seqlen_prefill,
-                    block_tables,
-                    max_s,
-                    self.softmax_scale,
-                )
-            # Decode
-            else:
-                paged_attention.attention(
-                    attn_output,
-                    query,
-                    kv_cache[0],
-                    kv_cache[1],
-                    self.kv_head_mapping,
-                    self.softmax_scale,
-                    block_tables,
-                    input_lengths,
-                    max_s,
-                )
-        else:
         paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             flash_attn.attention(
@@ -204,7 +166,6 @@ class FlashLlamaAttention(torch.nn.Module):
                 torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                None,
                 max_s,
                 self.softmax_scale,
             )
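For orientation, the Llama attention forward after this change reads roughly as follows. This is a reconstruction from the context lines above, not a verbatim excerpt of the file: the decode branch is inferred from the deleted FLASH_DECODING block, which called paged_attention.attention with the same arguments.

        attn_output = torch.empty_like(query)

        paged_attention.reshape_and_cache(
            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            flash_attn.attention(
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
                attn_output,
                cu_seqlen_prefill,
                # the extra None (block_tables) argument was removed by this commit
                max_s,
                self.softmax_scale,
            )
        # Decode
        else:
            paged_attention.attention(
                attn_output,
                query,
                kv_cache[0],
                kv_cache[1],
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                input_lengths,
                max_s,
            )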

View File

@@ -27,7 +27,6 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import FLASH_DECODING
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -217,43 +216,6 @@ class MistralAttention(torch.nn.Module):
         attn_output = torch.empty_like(query)
-        if FLASH_DECODING:
-            # Prefill
-            kv_cache[0].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 0
-            ]
-            kv_cache[1].view(-1, self.num_key_value_heads, self.head_size)[slots] = kv[
-                :, 1
-            ]
-            if cu_seqlen_prefill is not None:
-                # flash attention
-                flash_attn.attention(
-                    query,
-                    # torch.select(kv, dim=1, index=0),
-                    # torch.select(kv, dim=1, index=1),
-                    kv_cache[0],
-                    kv_cache[1],
-                    attn_output,
-                    cu_seqlen_prefill,
-                    block_tables,
-                    max_s,
-                    self.softmax_scale,
-                )
-            # Decode
-            else:
-                paged_attention.attention(
-                    attn_output,
-                    query,
-                    kv_cache[0],
-                    kv_cache[1],
-                    self.kv_head_mapping,
-                    self.softmax_scale,
-                    block_tables,
-                    input_lengths,
-                    max_s,
-                )
-        else:
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
         else:
@@ -271,7 +233,6 @@ class MistralAttention(torch.nn.Module):
                 torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                None,
                 max_s,
                 self.softmax_scale,
             )

View File

@@ -134,7 +134,6 @@ elif HAS_FLASH_ATTN_V2_CUDA:
         v,
         out,
         cu_seqlens,
-        block_tables,
         max_s,
         softmax_scale,
         window_size_left=-1,
@@ -150,7 +149,7 @@ elif HAS_FLASH_ATTN_V2_CUDA:
             cu_seqlens,
             cu_seqlens,
             None,
-            block_tables,
+            None,
             None,
             max_s,
             max_s,
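Net effect in this file: the CUDA flash-attention v2 wrapper drops its block_tables parameter and pins the corresponding kernel argument to None. A signature-level sketch of the wrapper after the change, defined under elif HAS_FLASH_ATTN_V2_CUDA: (the q and k parameter names are assumptions inferred from the call sites in the model files above; only the parameters visible in the first hunk are confirmed, and the body is elided):

def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
):
    # Forwards to the flash-attn v2 varlen kernel; after this commit the
    # argument that used to receive block_tables is simply None (second hunk).
    ...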

View File

@@ -27,6 +27,11 @@ def reshape_and_cache(
         ipex.llm.modules.PagedAttention.reshape_and_cache(
             key, value, key_cache, value_cache, slots
         )
+    else:
+        if FLASH_DECODING:
+            shape = key_cache.shape
+            key_cache.view(-1, shape[-2], shape[-1])[slots] = key
+            value_cache.view(-1, shape[-2], shape[-1])[slots] = value
         else:
             cache_ops.reshape_and_cache(
                 key, value, key_cache, value_cache, slots, "auto", 1.0
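Taken together, the commit moves flash-decoding support out of the individual model files and into this single helper: when FLASH_DECODING is set, reshape_and_cache itself scatters the new keys and values into a flat view of the cache instead of calling cache_ops. A minimal standalone illustration of what the added branch does, with made-up shapes (the [num_blocks, block_size, num_kv_heads, head_size] layout is an assumption; only the view-and-scatter pattern comes from the diff):

import torch

# Toy dimensions (assumptions, not taken from the commit): 4 blocks of 16 slots,
# 2 KV heads, head size 8.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8

# Assumed flash-decoding cache layout: [num_blocks, block_size, num_kv_heads, head_size].
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
value_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

# Three new tokens to cache and the flat slot each one is assigned to.
key = torch.randn(3, num_kv_heads, head_size)
value = torch.randn(3, num_kv_heads, head_size)
slots = torch.tensor([0, 1, 17])  # block 0 / offsets 0-1, block 1 / offset 1

# The added FLASH_DECODING branch: view the cache as
# [total_slots, num_kv_heads, head_size] and scatter by slot index.
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value

# Slot 17 lands in block 1, offset 1.
assert torch.equal(key_cache[1, 1], key[2])

This is the same view-and-scatter pattern that the deleted FLASH_DECODING branches in FlashLlamaAttention and MistralAttention performed inline, which is what makes the change "less intrusive": the model code keeps calling reshape_and_cache exactly as before.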