fix FlashDecoding change's regression in intel platform (#2161)
install triton because GPTQParams needs it. Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
022f6515a4
commit
5d97e0c4a3
|
@ -62,6 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||||
|
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
|
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
|
||||||
|
RUN pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||||
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed
|
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed
|
||||||
|
|
||||||
# Install server
|
# Install server
|
||||||
|
@ -132,6 +133,7 @@ RUN conda install -c conda-forge gperftools mkl
|
||||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
||||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
||||||
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
||||||
|
RUN pip install triton
|
||||||
|
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
import torch
|
import torch
|
||||||
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
|
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
|
||||||
|
from text_generation_server.layers.attention import Seqlen
|
||||||
|
|
||||||
SUPPORTS_WINDOWING = False
|
SUPPORTS_WINDOWING = False
|
||||||
|
|
||||||
|
@ -55,11 +56,10 @@ def paged_attention(
|
||||||
kv_head_mapping: torch.Tensor,
|
kv_head_mapping: torch.Tensor,
|
||||||
softmax_scale: float,
|
softmax_scale: float,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
cu_seqlen_q: torch.Tensor,
|
seqlen: Seqlen,
|
||||||
cu_seqlen_k: torch.Tensor,
|
|
||||||
max_s: int,
|
max_s: int,
|
||||||
):
|
):
|
||||||
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||||
out,
|
out,
|
||||||
query,
|
query,
|
||||||
key_cache,
|
key_cache,
|
||||||
|
@ -67,8 +67,9 @@ def paged_attention(
|
||||||
kv_head_mapping,
|
kv_head_mapping,
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
cu_seqlen_q,
|
seqlen.input_lengths,
|
||||||
BLOCK_SIZE,
|
BLOCK_SIZE,
|
||||||
max_s,
|
max_s,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
return out
|
||||||
|
|
Loading…
Reference in New Issue