hotfix: fix regression of attention api change in intel platform (#2439)
Fix the regression introduced by the attention API change: ipex.varlen_attention does not currently support KV input in paged-cache format, so the Intel (ipex) path must receive the raw key/value tensors instead.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
commit 5cd8025f18 (parent e279b38aca)
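All of the model-level changes below apply the same guard: on the prefill path, non-Intel platforms keep passing the paged KV-cache blocks into the shared attention() helper, while the ipex (Intel) path passes the freshly computed key/value tensors instead, because ipex.llm.functional.varlen_attention cannot consume the paged-cache layout. A minimal, self-contained sketch of that selection (tensor shapes and the hard-coded SYSTEM value are illustrative assumptions, not the server's real code):

import torch

# Normally provided by text_generation_server.utils.import_utils;
# hard-coded here only to make the sketch runnable.
SYSTEM = "ipex"

num_tokens, num_kv_heads, head_dim = 8, 4, 64
num_blocks, block_size = 2, 16

# Freshly computed prefill keys/values (varlen layout: one row per token).
key = torch.randn(num_tokens, num_kv_heads, head_dim)
value = torch.randn(num_tokens, num_kv_heads, head_dim)

# Paged KV-cache blocks (layout illustrative only).
kv_cache = (
    torch.randn(num_blocks, num_kv_heads, block_size, head_dim),
    torch.randn(num_blocks, num_kv_heads, block_size, head_dim),
)

# The guard added at every prefill call site in this commit:
# ipex.varlen_attention cannot read the paged layout, so the ipex path
# bypasses the cache and hands over the contiguous varlen tensors.
k_arg = kv_cache[0] if SYSTEM != "ipex" else key
v_arg = kv_cache[1] if SYSTEM != "ipex" else value
print(k_arg.shape)  # torch.Size([8, 4, 64]) when SYSTEM == "ipex"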
@@ -171,5 +171,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

 FROM ${PLATFORM} AS final
+ENV ATTENTION=paged
+ENV USE_PREFIX_CACHING=0
+ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
@@ -8,11 +8,11 @@ SUPPORTS_WINDOWING = False


 def attention(
-    q,
-    k,
-    v,
-    cu_seqlens,
-    max_s,
+    q: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    seqlen: Seqlen,
+    block_tables: torch.Tensor,
     softmax_scale,
     window_size_left=-1,
     causal=True,

@@ -23,13 +23,13 @@ def attention(
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     ipex.llm.functional.varlen_attention(
         q,
-        k,
-        v,
+        key_cache,
+        value_cache,
         out,
-        cu_seqlens,
-        cu_seqlens,
-        max_s,
-        max_s,
+        seqlen.cu_seqlen_q,
+        seqlen.cu_seqlen_q,
+        seqlen.max_q,
+        seqlen.max_q,
         0.0,
         softmax_scale,
         False,
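For reference, a hedged sketch of how a caller might populate the two Seqlen fields the ipex wrapper reads (cu_seqlen_q and max_q); the Seqlen dataclass below is a hypothetical stand-in for the server's own class, shown only to make the mapping onto varlen_attention's cumulative-length arguments concrete:

import torch
from dataclasses import dataclass


# Hypothetical stand-in for the server's Seqlen; only the two fields the
# ipex wrapper above actually uses are modeled here.
@dataclass
class Seqlen:
    cu_seqlen_q: torch.Tensor  # cumulative query lengths, shape [batch + 1]
    max_q: int                 # longest query sequence in the batch


prompt_lengths = torch.tensor([5, 3, 7], dtype=torch.int32)
cu_seqlen_q = torch.nn.functional.pad(
    torch.cumsum(prompt_lengths, dim=0), (1, 0)
).to(torch.int32)
seqlen = Seqlen(cu_seqlen_q=cu_seqlen_q, max_q=int(prompt_lengths.max()))

# During prefill, queries and keys share the same varlen layout, which is
# why the hunk above passes seqlen.cu_seqlen_q and seqlen.max_q for both
# sides of ipex.llm.functional.varlen_attention.
print(seqlen.cu_seqlen_q)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
print(seqlen.max_q)        # 7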
@@ -297,8 +297,8 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -336,8 +336,8 @@ class DbrxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -363,8 +363,8 @@ class DeepseekV2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -25,7 +25,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,

@@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -25,7 +25,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,

@@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,

@@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,

@@ -44,7 +44,6 @@ from text_generation_server.layers.rotary import (
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 def load_attention(config, prefix: str, weights):

@@ -193,8 +192,8 @@ class FlashGPTJAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -220,8 +220,8 @@ class FlashLlamaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -218,8 +218,8 @@ class MistralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -275,8 +275,8 @@ class MixtralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,

@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
+                kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -25,6 +25,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.import_utils import SYSTEM


 class PhiConfig(PretrainedConfig):

@@ -193,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -21,6 +21,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM


 def load_attention(config, prefix, weights):

@@ -136,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -5,7 +5,7 @@ import torch.distributed
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     SpeculativeHead,
     TensorParallelColumnLinear,

@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -325,12 +325,10 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=2, index=0),
-                torch.select(kv, dim=2, index=1),
-                kv_cache[0],
-                kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
             # Decode
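The FlashRWLargeAttention hunk above is the only call site that adds .contiguous(): slicing the key or value out of the packed kv tensor along an inner dimension produces a non-contiguous view, which the varlen kernel presumably cannot accept directly. A tiny illustration with made-up shapes:

import torch

# Made-up shape mirroring the packed layout: [tokens, kv_head_groups, k_or_v, head_dim].
kv = torch.randn(16, 4, 2, 64)

k = kv[:, :, 0]                         # select the "key" slot -> non-contiguous view
print(k.is_contiguous())                # False
print(k.contiguous().is_contiguous())   # True: what the ipex path passes on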
@@ -22,6 +22,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM


 def load_multi_mqa(

@@ -292,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+from text_generation_server.utils.import_utils import SYSTEM


 class Starcoder2Config(PretrainedConfig):

@@ -241,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,