added custom PA

parent e557855558
commit ff0505e7f9
@@ -39,7 +39,7 @@ COPY launcher launcher
 RUN cargo build --profile release-opt

 # Text Generation Inference base image for RoCm
-FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
+FROM rocm/dev-ubuntu-22.04:6.2 AS base

 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     build-essential \
@@ -48,23 +48,25 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     curl \
     git \
     make \
     libmsgpack-dev \
     libssl-dev \
     llvm-dev \
     g++ \
     # Needed to build VLLM & flash.
     rocthrust-dev \
     hipsparse-dev \
     hipblas-dev \
     hipblaslt-dev \
     hipcub-dev \
     rocblas-dev \
     hiprand-dev \
     hipfft-dev \
     rocrand-dev \
     miopen-hip-dev \
     hipfft-dev \
     hipcub-dev \
     hipsolver-dev \
     rccl-dev \
     cmake \
-    python3-dev && \
+    python3-dev \
+    python3-venv && \
     rm -rf /var/lib/apt/lists/*

 # Keep in sync with `server/pyproject.toml
@@ -74,7 +76,30 @@ ARG ROCM_VERSION='6.0.2'
 ARG PYTHON_VERSION='3.10.10'
 # Automatically set by buildx
 ARG TARGETPLATFORM
-ENV PATH /opt/conda/bin:$PATH
+ENV PATH=/opt/conda/bin:$PATH
+
+ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+
+RUN curl -fsSL -v -o cmake-3.30.2-linux-x86_64.sh https://github.com/Kitware/CMake/releases/download/v3.30.2/cmake-3.30.2-linux-x86_64.sh \
+    && chmod +x cmake-3.30.2-linux-x86_64.sh \
+    && ./cmake-3.30.2-linux-x86_64.sh --skip-license --prefix=/usr/local \
+    && rm cmake-3.30.2-linux-x86_64.sh
+
+RUN pip install joblib msgpack
+
+# Install HIPBLASLt
+ARG HIPBLASLT_BRANCH="6f65c6e"
+RUN git clone https://github.com/ROCm/hipBLASLt \
+    && cd hipBLASLt \
+    && git checkout ${HIPBLASLT_BRANCH} \
+    && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
+    && cd build/release \
+    && make package
+RUN dpkg -i hipBLASLt/build/release/*.deb \
+    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status;
+    # && cd .. \
+    # && rm -rf hipBLASLt

 # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
 # Install mamba
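As a quick sanity check (a minimal sketch, not part of the commit; the library name libhipblaslt.so is an assumption about ROCm packaging), the locally built hipBLASLt package installed above can be confirmed loadable from Python:

# Editor's sketch: confirm the hipBLASLt shared library resolves after the dpkg install.
import ctypes

hipblaslt = ctypes.CDLL("libhipblaslt.so")  # raises OSError if the library is not on the search path
print("hipBLASLt loaded:", hipblaslt._name)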
@@ -98,15 +123,15 @@ RUN pip uninstall -y triton && \
     cd triton/python && \
     pip install .

+ARG PYTORCH_COMMIT="da320214e66b5af0f7db8fd18a64dbb519d17b27"
 RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \
-    cd pytorch && git fetch --depth 1 origin da320214e66b5af0f7db8fd18a64dbb519d17b27 && \
-    git checkout da320214e66b5af0f7db8fd18a64dbb519d17b27 && \
+    cd pytorch && git fetch --depth 1 origin ${PYTORCH_COMMIT} && \
+    git checkout ${PYTORCH_COMMIT} && \
     git submodule update --init --recursive && \
     pip install -r requirements.txt --no-cache-dir


 ARG _GLIBCXX_USE_CXX11_ABI="1"
 ARG CMAKE_PREFIX_PATH="/opt/conda"
 ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
 ARG BUILD_CAFFE2="0" \
     BUILD_CAFFE2_OPS="0" \
     USE_CUDA="0" \
@@ -221,4 +246,3 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh

 ENTRYPOINT ["/tgi-entrypoint.sh"]
 CMD ["--json-output"]
@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := v2.6.1
-flash_att_v2_commit_rocm := d83c4129a92e4258081f92dfafd34345b3b06130
+flash_att_v2_commit_rocm := 3cea2fb6ee54fb7e1aad9db6ac6c9331184b8647 # (Aug28)

 build-flash-attention-v2-cuda:
 	pip install -U packaging wheel
@@ -1,5 +1,5 @@
 commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
-commit_rocm := c06ccbf90a213688a2c6a85d2e7af3da7bc4b41b
+commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247

 build-vllm-cuda:
 	if [ ! -d 'vllm' ]; then \
 		pip install -U ninja packaging --no-cache-dir && \
@@ -42,6 +42,7 @@ def paged_attention(
     block_tables: torch.Tensor,
     seqlen: Seqlen,
     max_s: int,
+    num_kv_heads: int,
     softcap: Optional[float] = None,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -58,6 +58,7 @@ def paged_attention(
     block_tables: torch.Tensor,
     seqlen: Seqlen,
     max_s: int,
+    num_kv_heads: int,
 ):
     out = torch.empty_like(query)
     ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
@@ -8,11 +8,17 @@ from loguru import logger

 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
-_PARTITION_SIZE = 512
+
+_PARTITION_SIZE_V1V2 = 512
+_PARTITION_SIZE_CUSTOM = 256

 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
 ENGINE = "triton" if use_triton else "ck"

+custom_attn_available = os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0"
+if custom_attn_available:
+    from vllm._custom_C import paged_attention_custom
+
 try:
     import vllm._custom_ops as ops
 except Exception as e:
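Because the import above is gated on the VLLM_USE_ROCM_CUSTOM_PAGED_ATTN environment variable, the custom kernel can be switched off without code changes. A minimal sketch (the variable name and semantics come from the check above):

import os

# Setting the variable to "0" before the attention module is imported makes
# custom_attn_available False, so the paged_attention_v1/v2 paths are used instead.
os.environ["VLLM_USE_ROCM_CUSTOM_PAGED_ATTN"] = "0"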
@@ -45,6 +51,7 @@ def paged_attention(
     block_tables: torch.Tensor,
     input_lengths: Seqlen,
     max_s: int,
+    num_kv_heads: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
     # Copyright 2023 The vLLM team. All rights
@@ -66,6 +73,22 @@ def paged_attention(
     # value_cache => [num_blocks, num_heads, head_size, block_size]
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
+
+    gqa_ratio = num_heads // num_kv_heads
+    use_custom = (
+        custom_attn_available
+        and (query.dtype == torch.half or query.dtype == torch.bfloat16)
+        and (head_size == 128 or head_size == 64)
+        and (block_size == 16 or block_size == 32)
+        and (gqa_ratio >= 1 and gqa_ratio <= 16)
+        and max_s <= 32768
+    )
+
+    if not use_custom:
+        _PARTITION_SIZE = _PARTITION_SIZE_V1V2
+    else:
+        _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
+
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
     input_lengths = input_lengths.input_lengths

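A short worked example of the partition arithmetic above, with an illustrative max_s (the sequence length is made up for the example):

# max_num_partitions = ceil(max_s / _PARTITION_SIZE), as computed in the hunk above.
max_s = 4096
print((max_s + 512 - 1) // 512)   # 8 partitions with _PARTITION_SIZE_V1V2
print((max_s + 256 - 1) // 256)   # 16 partitions with _PARTITION_SIZE_CUSTOM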
@@ -78,7 +101,11 @@ def paged_attention(
     # to parallelize.
     import vllm._custom_ops as ops

-    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
+    use_v1 = (
+        max_s <= 8192
+        and (max_num_partitions == 1 or num_seqs * num_heads > 512)
+        and not use_custom
+    )
     if use_v1:
         ops.paged_attention_v1(
             out,
@@ -110,24 +137,44 @@ def paged_attention(
         )
         max_logits = torch.empty_like(exp_sums)

-        ops.paged_attention_v2(
-            out,
-            exp_sums,
-            max_logits,
-            tmp_output,
-            query,
-            key_cache,
-            value_cache,
-            kv_head_mapping,
-            softmax_scale,
-            block_tables,
-            input_lengths,
-            block_size,
-            max_s,
-            None,
-            "auto",
-            1.0,
-        )
+        if not use_custom:
+            ops.paged_attention_v2(
+                out,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                kv_head_mapping,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+                "auto",
+                1.0,
+            )
+        else:
+            paged_attention_custom(
+                out,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+                "auto",
+            )

     return out
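Taken together, the hunks above give a three-way kernel dispatch. A minimal sketch that mirrors those conditions (the returned strings are just labels for the calls in the diff, not a real API):

# Editor's sketch mirroring the use_v1 / use_custom checks introduced above.
def select_paged_attention_kernel(use_custom, max_s, max_num_partitions, num_seqs, num_heads):
    use_v1 = (
        max_s <= 8192
        and (max_num_partitions == 1 or num_seqs * num_heads > 512)
        and not use_custom
    )
    if use_v1:
        return "ops.paged_attention_v1"
    if not use_custom:
        return "ops.paged_attention_v2"
    return "paged_attention_custom"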
@@ -313,6 +313,7 @@ class FlashCohereAttention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
             )

         return self.o_proj(
@@ -352,6 +352,7 @@ class DbrxAttention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -380,6 +380,7 @@ class DeepseekV2Attention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
             )

         # Remove padding.
@@ -424,6 +425,7 @@ class DeepseekV2MLP(nn.Module):
     def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
         if (
             SYSTEM == "rocm"
             and hidden_states.dtype == torch.float16
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
             and not self.quantize
@@ -256,6 +256,7 @@ class FlashGemma2Attention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
                 softcap=self.softcap,
             )

@@ -248,6 +248,7 @@ class FlashGemmaAttention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -247,6 +247,7 @@ class FlashGPT2Attention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -235,6 +235,7 @@ class FlashLlamaAttention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
             )

         return self.o_proj(
@@ -318,6 +319,7 @@ class LlamaMLP(nn.Module):
     def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
             and hidden_states.dtype == torch.float16
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
             and self.hidden_size
@@ -557,6 +559,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.embed_tokens(input_ids)

         hidden_states = self.model(
             inputs_embeds,
             position_ids,
@@ -235,6 +235,7 @@ class MistralAttention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
             )

         return self.o_proj(
@@ -300,6 +301,7 @@ class MistralMLP(nn.Module):
     def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
             and hidden_states.dtype == torch.float16
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
             and not self.quantize
@@ -292,6 +292,7 @@ class MixtralAttention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -180,6 +180,7 @@ class FlashNeoxAttention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
             )

         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -209,6 +209,7 @@ class FlashPhiAttention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
             )

         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -153,6 +153,7 @@ class Qwen2Attention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -223,6 +223,7 @@ class FlashRWAttention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
             )

         return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@@ -341,6 +342,7 @@ class FlashRWLargeAttention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
             )

         return self.dense(
@@ -308,6 +308,7 @@ class FlashMQAttention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
             )

         return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -258,6 +258,7 @@ class Starcoder2Attention(torch.nn.Module):
                 block_tables,
                 input_lengths,
                 max_s,
+                self.num_key_value_heads,
             )

         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -1007,12 +1007,12 @@ class FlashCausalLM(Model):
         else:
             self.kv_cache = [
                 (
-                    torch.empty(
+                    torch.zeros(
                         (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
                         dtype=dtype,
                         device=device,
                     ),
-                    torch.empty(
+                    torch.zeros(
                         (num_blocks, num_heads, head_size, BLOCK_SIZE),
                         dtype=dtype,
                         device=device,
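For reference, a minimal sketch of the zero-initialized cache allocation introduced above, with illustrative sizes (the BLOCK_SIZE and x values are assumptions for the example, not taken from the commit):

# Editor's sketch: same allocation pattern as above, with made-up dimensions.
import torch

num_blocks, num_heads, head_size, BLOCK_SIZE, x = 4, 8, 128, 16, 8
key_cache = torch.zeros((num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), dtype=torch.float16)
value_cache = torch.zeros((num_blocks, num_heads, head_size, BLOCK_SIZE), dtype=torch.float16)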