hotfix: fix regression caused by attention API change on the Intel platform (#2439)

Fix a regression caused by the attention API change: ipex.varlen_attention does not currently support
KV input in paged-cache format, so on the ipex backend the modeling code now passes the prefill
key/value tensors directly instead of the paged KV cache (a sketch of the pattern follows the change summary below).

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi (committed via GitHub)
Date: 2024-09-05 23:41:39 +08:00
Commit: 5cd8025f18 (parent: e279b38aca)
18 changed files with 60 additions and 56 deletions
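
Every modeling change below follows one pattern: when SYSTEM == "ipex", attention() receives the contiguous key/value tensors of the current prefill batch instead of the paged KV cache, because ipex.llm.functional.varlen_attention cannot read paged blocks. The following is a minimal sketch of that dispatch, assuming the SYSTEM flag, attention() helper and kv_cache layout shown in this diff; the prefill_attention() wrapper itself is purely illustrative and not part of the codebase:

import torch

from text_generation_server.layers.attention import attention
from text_generation_server.utils.import_utils import SYSTEM


def prefill_attention(
    query: torch.Tensor,
    key: torch.Tensor,          # un-paged K of the current prefill tokens
    value: torch.Tensor,        # un-paged V of the current prefill tokens
    kv_cache,                   # (key_cache, value_cache) in paged-block layout
    seqlen,                     # Seqlen carrying cu_seqlen_q / max_q
    block_tables: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    # ipex.llm.functional.varlen_attention cannot read the paged KV cache,
    # so on the ipex backend we hand it the per-token K/V of this batch instead.
    return attention(
        query,
        kv_cache[0] if SYSTEM != "ipex" else key,
        kv_cache[1] if SYSTEM != "ipex" else value,
        seqlen,
        block_tables,
        softmax_scale,
    )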


@@ -171,5 +171,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 
 FROM ${PLATFORM} AS final
+ENV ATTENTION=paged
+ENV USE_PREFIX_CACHING=0
+ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]


@@ -8,11 +8,11 @@ SUPPORTS_WINDOWING = False
 
 
 def attention(
-    q,
-    k,
-    v,
-    cu_seqlens,
-    max_s,
+    q: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    seqlen: Seqlen,
+    block_tables: torch.Tensor,
     softmax_scale,
     window_size_left=-1,
     causal=True,
@@ -23,13 +23,13 @@ def attention(
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     ipex.llm.functional.varlen_attention(
         q,
-        k,
-        v,
+        key_cache,
+        value_cache,
         out,
-        cu_seqlens,
-        cu_seqlens,
-        max_s,
-        max_s,
+        seqlen.cu_seqlen_q,
+        seqlen.cu_seqlen_q,
+        seqlen.max_q,
+        seqlen.max_q,
         0.0,
         softmax_scale,
         False,
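
On the ipex path the rewritten attention() above consumes only two fields of the Seqlen argument, cu_seqlen_q and max_q. The real Seqlen class lives in text_generation_server.layers.attention; the stand-in below is only a sketch of how such cumulative query offsets are typically built for a varlen kernel:

from dataclasses import dataclass

import torch


@dataclass
class SeqlenSketch:
    cu_seqlen_q: torch.Tensor  # cumulative query lengths, shape (batch + 1,), int32
    max_q: int                 # length of the longest query sequence in the batch


def make_seqlen(input_lengths: torch.Tensor) -> SeqlenSketch:
    # Varlen attention packs ragged sequences into one tensor and uses the
    # cumulative offsets to mark where each sequence starts and ends.
    cu_seqlen_q = torch.zeros(input_lengths.numel() + 1, dtype=torch.int32)
    cu_seqlen_q[1:] = torch.cumsum(input_lengths, dim=0)
    return SeqlenSketch(cu_seqlen_q=cu_seqlen_q, max_q=int(input_lengths.max()))


# Three prefill requests of lengths 5, 3 and 7 packed together:
# make_seqlen(torch.tensor([5, 3, 7])) -> cu_seqlen_q = [0, 5, 8, 15], max_q = 7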


@@ -297,8 +297,8 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -336,8 +336,8 @@ class DbrxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -363,8 +363,8 @@ class DeepseekV2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -25,7 +25,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -25,7 +25,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -44,7 +44,6 @@ from text_generation_server.layers.rotary import (
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM
 
 
 def load_attention(config, prefix: str, weights):
@@ -193,8 +192,8 @@ class FlashGPTJAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -220,8 +220,8 @@ class FlashLlamaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -218,8 +218,8 @@ class MistralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -275,8 +275,8 @@ class MixtralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
+                kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -25,6 +25,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 class PhiConfig(PretrainedConfig):
@@ -193,8 +194,8 @@
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -21,6 +21,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 def load_attention(config, prefix, weights):
@@ -136,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -5,7 +5,7 @@ import torch.distributed
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     SpeculativeHead,
     TensorParallelColumnLinear,
@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -325,12 +325,10 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=2, index=0),
-                torch.select(kv, dim=2, index=1),
-                kv_cache[0],
-                kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
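
Unlike the other call sites, the FlashRWLargeAttention change adds .contiguous() on the ipex branch. A plausible reason, not stated in the commit, is that selecting an inner dimension of the fused kv tensor yields a strided view rather than densely packed memory, which varlen-style kernels generally require. A quick illustration with made-up (tokens, heads, 2, head_dim) shapes:

import torch

kv = torch.randn(10, 8, 2, 64)

key_view = kv[:, :, 0]           # strided view: dim 2 leaves a gap in memory
print(key_view.is_contiguous())  # False

key = kv[:, :, 0].contiguous()   # materializes a densely packed copy
print(key.is_contiguous())       # True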


@@ -22,6 +22,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 def load_multi_mqa(
@@ -292,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 class Starcoder2Config(PretrainedConfig):
@@ -241,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,