From 5cd8025f1849bd4c13edcf9eb4f72e199e6a5c37 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Thu, 5 Sep 2024 23:41:39 +0800
Subject: [PATCH] hotfix: fix regression of attention api change in intel
 platform (#2439)

Fix a regression caused by the attention API change: ipex.varlen_attention
does not currently support KV input in the paged-cache format.

Signed-off-by: Wang, Yi A
---
 Dockerfile_intel                              |  3 +++
 .../layers/attention/ipex.py                  | 22 +++++++++----------
 .../custom_modeling/flash_cohere_modeling.py  |  4 ++--
 .../custom_modeling/flash_dbrx_modeling.py    |  4 ++--
 .../flash_deepseek_v2_modeling.py             |  4 ++--
 .../custom_modeling/flash_gemma2_modeling.py  |  6 ++---
 .../custom_modeling/flash_gemma_modeling.py   |  6 ++---
 .../custom_modeling/flash_gpt2_modeling.py    |  6 ++---
 .../custom_modeling/flash_gptj_modeling.py    |  7 +++---
 .../custom_modeling/flash_llama_modeling.py   |  4 ++--
 .../custom_modeling/flash_mistral_modeling.py |  4 ++--
 .../custom_modeling/flash_mixtral_modeling.py |  4 ++--
 .../custom_modeling/flash_neox_modeling.py    |  6 ++---
 .../custom_modeling/flash_phi_modeling.py     |  5 +++--
 .../custom_modeling/flash_qwen2_modeling.py   |  5 +++--
 .../custom_modeling/flash_rw_modeling.py      | 16 ++++++--------
 .../flash_santacoder_modeling.py              |  5 +++--
 .../flash_starcoder2_modeling.py              |  5 +++--
 18 files changed, 60 insertions(+), 56 deletions(-)

diff --git a/Dockerfile_intel b/Dockerfile_intel
index 9af6422c..0cda4d4b 100644
--- a/Dockerfile_intel
+++ b/Dockerfile_intel
@@ -171,5 +171,8 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
 COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
 
 FROM ${PLATFORM} AS final
+ENV ATTENTION=paged
+ENV USE_PREFIX_CACHING=0
+ENV CUDA_GRAPHS=0
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]
diff --git a/server/text_generation_server/layers/attention/ipex.py b/server/text_generation_server/layers/attention/ipex.py
index d7cf780a..2d1427ae 100644
--- a/server/text_generation_server/layers/attention/ipex.py
+++ b/server/text_generation_server/layers/attention/ipex.py
@@ -8,11 +8,11 @@ SUPPORTS_WINDOWING = False
 
 
 def attention(
-    q,
-    k,
-    v,
-    cu_seqlens,
-    max_s,
+    q: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    seqlen: Seqlen,
+    block_tables: torch.Tensor,
     softmax_scale,
     window_size_left=-1,
     causal=True,
@@ -23,13 +23,13 @@
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     ipex.llm.functional.varlen_attention(
         q,
-        k,
-        v,
+        key_cache,
+        value_cache,
         out,
-        cu_seqlens,
-        cu_seqlens,
-        max_s,
-        max_s,
+        seqlen.cu_seqlen_q,
+        seqlen.cu_seqlen_q,
+        seqlen.max_q,
+        seqlen.max_q,
         0.0,
         softmax_scale,
         False,
diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index fe19180a..374ccb10 100644
--- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -297,8 +297,8 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index b82b5473..0dc88098 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -336,8 +336,8 @@ class DbrxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
index 0585b40e..f62dfe66 100644
--- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
@@ -363,8 +363,8 @@ class DeepseekV2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
index d16e805f..e12bff00 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -25,7 +25,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index 34be4cb8..77ae4b35 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -25,7 +25,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index 403fa908..411c4ce1 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
index 35ab2791..ef071d46 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py
@@ -24,7 +24,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -44,7 +44,6 @@ from text_generation_server.layers.rotary import (
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM
 
 
 def load_attention(config, prefix: str, weights):
@@ -193,8 +192,8 @@ class FlashGPTJAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key,
+                kv_cache[1] if SYSTEM != "ipex" else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index ae981c9a..7d639e35 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -220,8 +220,8 @@ class FlashLlamaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 30ca3faf..cdd23796 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -218,8 +218,8 @@ class MistralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index c5d60af1..c36e97f6 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -275,8 +275,8 @@ class MixtralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index fda648f9..454e45eb 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -26,7 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
+                kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index 37adb8be..e2d9bbbc 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -25,6 +25,7 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 class PhiConfig(PretrainedConfig):
@@ -193,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index 5aac28a3..999b72e7 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -21,6 +21,7 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 def load_attention(config, prefix, weights):
@@ -136,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 1c55dd91..edc54c09 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -5,7 +5,7 @@ import torch.distributed
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
-
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     SpeculativeHead,
     TensorParallelColumnLinear,
@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -325,12 +325,10 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                torch.select(kv, dim=2, index=0),
-                torch.select(kv, dim=2, index=1),
-                kv_cache[0],
-                kv_cache[1],
-                cu_seqlen_prefill,
-                max_s,
+                kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
+                kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
+                seqlen,
+                block_tables,
                 self.softmax_scale,
             )
             # Decode
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 19025c4c..f97b4409 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -22,6 +22,7 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 def load_multi_mqa(
@@ -292,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index 2f9ecd0d..6aa7fa21 100644
--- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -47,6 +47,7 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 class Starcoder2Config(PretrainedConfig):
@@ -241,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
+                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
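
For context, a minimal Python sketch (illustrative only, not part of the patch) of the call-site pattern the diff applies: on the ipex backend, prefill attention receives the freshly computed variable-length key/value tensors, while other backends keep passing the paged KV cache. The helper name `prefill_kv` is hypothetical; `SYSTEM`, `kv_cache`, `key`, and `value` mirror identifiers used in the diff.

# Illustrative sketch only -- mirrors the `kv_cache[0] if SYSTEM != "ipex" else key`
# expression used at every call site in this patch. `prefill_kv` is a hypothetical
# helper, not a function in text-generation-inference.
from typing import Tuple

import torch

SYSTEM = "ipex"  # in TGI this value comes from text_generation_server.utils.import_utils


def prefill_kv(
    kv_cache: Tuple[torch.Tensor, torch.Tensor],
    key: torch.Tensor,
    value: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pick the K/V tensors handed to the prefill attention kernel.

    ipex.llm.functional.varlen_attention consumes the raw variable-length
    key/value tensors, so the ipex backend bypasses the paged KV cache;
    every other backend keeps passing the cache as before.
    """
    if SYSTEM == "ipex":
        return key, value
    return kv_cache[0], kv_cache[1]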