hotfix: fix regression of attention api change in intel platform (#2439)

Fix a regression caused by the attention API change: ipex.llm.functional.varlen_attention does not currently support KV input in paged-cache format, so on the Intel (ipex) platform the original key/value tensors are passed to attention() instead of the paged KV cache.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent e279b38aca
commit 5cd8025f18
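The diff below applies one pattern at every model's prefill call site: CUDA/ROCm backends read keys and values from the paged KV cache, while the Intel (ipex) path gets the freshly computed tensors, because ipex.llm.functional.varlen_attention cannot consume the paged-cache layout. A minimal sketch of that pattern follows; the helper name and argument list are illustrative only, not actual TGI code.

```python
# Hedged sketch of the call-site pattern this commit applies in each model's
# prefill path; prefill_attention and its argument list are illustrative only.
from text_generation_server.layers.attention import attention
from text_generation_server.utils.import_utils import SYSTEM


def prefill_attention(query, key, value, kv_cache, seqlen, block_tables, softmax_scale):
    # ipex.llm.functional.varlen_attention cannot read the paged KV cache,
    # so on the Intel platform the freshly computed key/value tensors are
    # passed instead of kv_cache[0] / kv_cache[1].
    return attention(
        query,
        kv_cache[0] if SYSTEM != "ipex" else key,
        kv_cache[1] if SYSTEM != "ipex" else value,
        seqlen,
        block_tables,
        softmax_scale,
    )
```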
@@ -171,5 +171,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 
 FROM ${PLATFORM} AS final
+ENV ATTENTION=paged
+ENV USE_PREFIX_CACHING=0
+ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
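The Dockerfile change bakes Intel-friendly defaults into the final image: paged attention, prefix caching disabled, CUDA graphs disabled. A rough sketch of how such environment defaults might be read on the Python side is shown below; the variable handling is an assumption for illustration, not the actual TGI code.

```python
import os

# Hypothetical sketch: consuming the image defaults set in the Dockerfile above.
# The real text-generation-inference code may parse these differently.
ATTENTION = os.getenv("ATTENTION", "paged")                   # attention implementation
USE_PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0") == "1"
CUDA_GRAPHS = os.getenv("CUDA_GRAPHS", "0")                   # "0" disables CUDA graphs

print(ATTENTION, USE_PREFIX_CACHING, CUDA_GRAPHS)
```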
@@ -8,11 +8,11 @@ SUPPORTS_WINDOWING = False
 
 
 def attention(
-    q,
-    k,
-    v,
-    cu_seqlens,
-    max_s,
+    q: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    seqlen: Seqlen,
+    block_tables: torch.Tensor,
     softmax_scale,
     window_size_left=-1,
     causal=True,

@@ -23,13 +23,13 @@ def attention(
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     ipex.llm.functional.varlen_attention(
         q,
-        k,
-        v,
+        key_cache,
+        value_cache,
         out,
-        cu_seqlens,
-        cu_seqlens,
-        max_s,
-        max_s,
+        seqlen.cu_seqlen_q,
+        seqlen.cu_seqlen_q,
+        seqlen.max_q,
+        seqlen.max_q,
         0.0,
         softmax_scale,
         False,
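Taken together, the two hunks above give the ipex backend the unified attention() signature while still feeding plain varlen tensors to ipex.llm.functional.varlen_attention. Below is a hedged reconstruction of the resulting function; the output allocation, the arguments after False, and the return are assumptions based on surrounding code, not lines shown in this diff.

```python
from __future__ import annotations  # so the Seqlen annotation need not be imported here

import intel_extension_for_pytorch as ipex
import torch


def attention(
    q: torch.Tensor,
    key_cache: torch.Tensor,      # on ipex, callers pass the raw key tensor here
    value_cache: torch.Tensor,    # on ipex, callers pass the raw value tensor here
    seqlen: Seqlen,               # TGI metadata object carrying cu_seqlen_q / max_q
    block_tables: torch.Tensor,   # unused by this varlen prefill path
    softmax_scale,
    window_size_left=-1,
    causal=True,
):
    out = torch.empty_like(q)  # assumed: output buffer allocated like q
    # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
    ipex.llm.functional.varlen_attention(
        q,
        key_cache,
        value_cache,
        out,
        seqlen.cu_seqlen_q,
        seqlen.cu_seqlen_q,
        seqlen.max_q,
        seqlen.max_q,
        0.0,
        softmax_scale,
        False,
        causal,   # assumed: remaining arguments, not shown in the diff
        False,
        None,
    )
    return out
```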
@@ -297,8 +297,8 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -336,8 +336,8 @@ class DbrxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -363,8 +363,8 @@ class DeepseekV2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -25,7 +25,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,

@@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -25,7 +25,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,

@@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,

@@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,

@@ -44,7 +44,6 @@ from text_generation_server.layers.rotary import (
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM
 
 
 def load_attention(config, prefix: str, weights):

@@ -193,8 +192,8 @@ class FlashGPTJAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -220,8 +220,8 @@ class FlashLlamaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -218,8 +218,8 @@ class MistralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -275,8 +275,8 @@ class MixtralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,

@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
+                kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -25,6 +25,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 class PhiConfig(PretrainedConfig):

@@ -193,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -21,6 +21,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 def load_attention(config, prefix, weights):

@@ -136,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -5,7 +5,7 @@ import torch.distributed
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     SpeculativeHead,
     TensorParallelColumnLinear,

@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -325,12 +325,10 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=2, index=0),
-                torch.select(kv, dim=2, index=1),
-                kv_cache[0],
-                kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
             # Decode
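The FlashRWLargeAttention hunk is the one call site that adds .contiguous(): slicing the fused kv tensor along its middle dimension yields a strided view, and the ipex varlen kernel presumably expects packed inputs. A small illustration under an assumed layout:

```python
import torch

# Assumed layout for illustration: [tokens, kv_groups, k_or_v, head_dim].
kv = torch.randn(8, 4, 2, 64)

key_view = kv[:, :, 0]                        # strided view over the fused tensor
print(key_view.is_contiguous())               # False
print(key_view.contiguous().is_contiguous())  # True: packed copy handed to the kernel
```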
@@ -22,6 +22,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 def load_multi_mqa(

@@ -292,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 class Starcoder2Config(PretrainedConfig):

@@ -241,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,