hotfix: fix regression from the attention API change on the Intel platform (#2439)

Fix a regression caused by the attention API change: ipex.varlen_attention currently does not support KV input in paged-cache format, so on the IPEX platform the prefill attention call must receive the plain key/value tensors rather than the paged KV cache.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Authored by Wang, Yi on 2024-09-05 23:41:39 +08:00, committed by GitHub
parent e279b38aca
commit 5cd8025f18
18 changed files with 60 additions and 56 deletions
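Every modeling file below applies the same guard: during prefill the attention kernel receives the paged KV cache on CUDA/ROCm, but the freshly computed key/value tensors when SYSTEM == "ipex", since varlen_attention cannot read the paged layout. A minimal sketch of the recurring pattern (the exact tensor names, such as key/value, kv[:, 0], or kv_to_cache[:, 0], vary per model, as the hunks show):

    from text_generation_server.utils.import_utils import SYSTEM

    # Prefill path: IPEX's varlen_attention cannot consume the paged KV cache,
    # so it is given the plain key/value tensors computed for this forward pass.
    attn_output = attention(
        query,
        kv_cache[0] if SYSTEM != "ipex" else key,
        kv_cache[1] if SYSTEM != "ipex" else value,
        seqlen,
        block_tables,
        self.softmax_scale,
    )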


@@ -171,5 +171,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher

 FROM ${PLATFORM} AS final
+ENV ATTENTION=paged
+ENV USE_PREFIX_CACHING=0
+ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]


@@ -8,11 +8,11 @@ SUPPORTS_WINDOWING = False


 def attention(
-    q,
-    k,
-    v,
-    cu_seqlens,
-    max_s,
+    q: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    seqlen: Seqlen,
+    block_tables: torch.Tensor,
     softmax_scale,
     window_size_left=-1,
     causal=True,
@@ -23,13 +23,13 @@ def attention(
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     ipex.llm.functional.varlen_attention(
         q,
-        k,
-        v,
+        key_cache,
+        value_cache,
         out,
-        cu_seqlens,
-        cu_seqlens,
-        max_s,
-        max_s,
+        seqlen.cu_seqlen_q,
+        seqlen.cu_seqlen_q,
+        seqlen.max_q,
+        seqlen.max_q,
         0.0,
         softmax_scale,
         False,
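Putting the two hunks together, the reworked IPEX prefill wrapper looks roughly like the sketch below. The Seqlen stand-in, the output allocation, and the trailing varlen_attention arguments are reconstructed from the surrounding context and should be read as assumptions; only the argument order visible in the diff is taken directly from the source.

    import torch
    import intel_extension_for_pytorch as ipex
    from dataclasses import dataclass

    @dataclass
    class Seqlen:                    # stand-in for TGI's Seqlen metadata object
        cu_seqlen_q: torch.Tensor    # cumulative query lengths, shape (batch + 1,)
        max_q: int                   # longest query sequence in the batch

    def attention(
        q: torch.Tensor,
        key_cache: torch.Tensor,     # on IPEX: plain K tensor in varlen layout, not a paged cache
        value_cache: torch.Tensor,   # on IPEX: plain V tensor
        seqlen: Seqlen,
        block_tables: torch.Tensor,  # unused on this prefill path, kept for API parity
        softmax_scale,
        window_size_left=-1,
        causal=True,
    ):
        out = torch.empty_like(q)
        # window_size_left is not supported and is checked ahead of time at model load.
        ipex.llm.functional.varlen_attention(
            q,
            key_cache,
            value_cache,
            out,
            seqlen.cu_seqlen_q,      # the same offsets are used for Q and K/V during prefill
            seqlen.cu_seqlen_q,
            seqlen.max_q,
            seqlen.max_q,
            0.0,                     # dropout probability
            softmax_scale,
            False,
            causal,
            False,
            None,
        )
        return out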


@@ -297,8 +297,8 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -336,8 +336,8 @@ class DbrxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -363,8 +363,8 @@ class DeepseekV2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -25,7 +25,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -25,7 +25,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -44,7 +44,6 @@ from text_generation_server.layers.rotary import (
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 def load_attention(config, prefix: str, weights):
@@ -193,8 +192,8 @@ class FlashGPTJAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -220,8 +220,8 @@ class FlashLlamaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -218,8 +218,8 @@ class MistralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -275,8 +275,8 @@ class MixtralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
+                kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -25,6 +25,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.import_utils import SYSTEM


 class PhiConfig(PretrainedConfig):
@@ -193,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -21,6 +21,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM


 def load_attention(config, prefix, weights):
@@ -136,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -5,7 +5,7 @@ import torch.distributed
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     SpeculativeHead,
     TensorParallelColumnLinear,
@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -325,12 +325,10 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=2, index=0),
-                torch.select(kv, dim=2, index=1),
-                kv_cache[0],
-                kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
         # Decode
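The extra .contiguous() calls here are presumably needed because slicing the grouped kv tensor along an inner dimension yields non-contiguous views, which the IPEX varlen kernel cannot consume directly.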


@@ -22,6 +22,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM


 def load_multi_mqa(
@@ -292,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,


@@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+from text_generation_server.utils.import_utils import SYSTEM


 class Starcoder2Config(PretrainedConfig):
@@ -241,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,