added custom PA

This commit is contained in:
Mohit Sharma 2024-09-04 05:46:28 +00:00
parent e557855558
commit ff0505e7f9
22 changed files with 128 additions and 35 deletions

View File

@ -39,7 +39,7 @@ COPY launcher launcher
RUN cargo build --profile release-opt
# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
FROM rocm/dev-ubuntu-22.04:6.2 AS base
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
@ -48,23 +48,25 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \
git \
make \
libmsgpack-dev \
libssl-dev \
llvm-dev \
g++ \
# Needed to build VLLM & flash.
rocthrust-dev \
hipsparse-dev \
hipblas-dev \
hipblaslt-dev \
hipcub-dev \
rocblas-dev \
hiprand-dev \
hipfft-dev \
rocrand-dev \
miopen-hip-dev \
hipfft-dev \
hipcub-dev \
hipsolver-dev \
rccl-dev \
cmake \
python3-dev && \
python3-dev \
python3-venv && \
rm -rf /var/lib/apt/lists/*
# Keep in sync with `server/pyproject.toml
@ -74,7 +76,30 @@ ARG ROCM_VERSION='6.0.2'
ARG PYTHON_VERSION='3.10.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
ENV PATH=/opt/conda/bin:$PATH
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
RUN curl -fsSL -v -o cmake-3.30.2-linux-x86_64.sh https://github.com/Kitware/CMake/releases/download/v3.30.2/cmake-3.30.2-linux-x86_64.sh \
&& chmod +x cmake-3.30.2-linux-x86_64.sh \
&& ./cmake-3.30.2-linux-x86_64.sh --skip-license --prefix=/usr/local \
&& rm cmake-3.30.2-linux-x86_64.sh
RUN pip install joblib msgpack
# Install HIPBLASLt
ARG HIPBLASLT_BRANCH="6f65c6e"
RUN git clone https://github.com/ROCm/hipBLASLt \
&& cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
&& cd build/release \
&& make package
RUN dpkg -i hipBLASLt/build/release/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status;
# && cd .. \
# && rm -rf hipBLASLt
# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
@ -98,15 +123,15 @@ RUN pip uninstall -y triton && \
cd triton/python && \
pip install .
ARG PYTORCH_COMMIT="da320214e66b5af0f7db8fd18a64dbb519d17b27"
RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \
cd pytorch && git fetch --depth 1 origin da320214e66b5af0f7db8fd18a64dbb519d17b27 && \
git checkout da320214e66b5af0f7db8fd18a64dbb519d17b27 && \
cd pytorch && git fetch --depth 1 origin ${PYTORCH_COMMIT} && \
git checkout ${PYTORCH_COMMIT} && \
git submodule update --init --recursive && \
pip install -r requirements.txt --no-cache-dir
ARG _GLIBCXX_USE_CXX11_ABI="1"
ARG CMAKE_PREFIX_PATH="/opt/conda"
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
ARG BUILD_CAFFE2="0" \
BUILD_CAFFE2_OPS="0" \
USE_CUDA="0" \
@ -221,4 +246,3 @@ COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
ENTRYPOINT ["/tgi-entrypoint.sh"]
CMD ["--json-output"]

View File

@ -1,5 +1,5 @@
flash_att_v2_commit_cuda := v2.6.1
flash_att_v2_commit_rocm := d83c4129a92e4258081f92dfafd34345b3b06130
flash_att_v2_commit_rocm := 3cea2fb6ee54fb7e1aad9db6ac6c9331184b8647 # (Aug28)
build-flash-attention-v2-cuda:
pip install -U packaging wheel

View File

@ -1,5 +1,5 @@
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
commit_rocm := c06ccbf90a213688a2c6a85d2e7af3da7bc4b41b
commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
build-vllm-cuda:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \

View File

@ -42,6 +42,7 @@ def paged_attention(
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
num_kv_heads: int,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py

View File

@ -58,6 +58,7 @@ def paged_attention(
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
num_kv_heads: int,
):
out = torch.empty_like(query)
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(

View File

@ -8,11 +8,17 @@ from loguru import logger
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
_PARTITION_SIZE_V1V2 = 512
_PARTITION_SIZE_CUSTOM = 256
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
custom_attn_available = os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0"
if custom_attn_available:
from vllm._custom_C import paged_attention_custom
try:
import vllm._custom_ops as ops
except Exception as e:
@ -45,6 +51,7 @@ def paged_attention(
block_tables: torch.Tensor,
input_lengths: Seqlen,
max_s: int,
num_kv_heads: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
@ -66,6 +73,22 @@ def paged_attention(
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
gqa_ratio = num_heads // num_kv_heads
use_custom = (
custom_attn_available
and (query.dtype == torch.half or query.dtype == torch.bfloat16)
and (head_size == 128 or head_size == 64)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_s <= 32768
)
if not use_custom:
_PARTITION_SIZE = _PARTITION_SIZE_V1V2
else:
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = input_lengths.input_lengths
@ -78,7 +101,11 @@ def paged_attention(
# to parallelize.
import vllm._custom_ops as ops
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
use_v1 = (
max_s <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512)
and not use_custom
)
if use_v1:
ops.paged_attention_v1(
out,
@ -110,24 +137,44 @@ def paged_attention(
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
if not use_custom:
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
paged_attention_custom(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
)
return out

View File

@ -313,6 +313,7 @@ class FlashCohereAttention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
)
return self.o_proj(

View File

@ -352,6 +352,7 @@ class DbrxAttention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -380,6 +380,7 @@ class DeepseekV2Attention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
)
# Remove padding.
@ -424,6 +425,7 @@ class DeepseekV2MLP(nn.Module):
def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
if (
SYSTEM == "rocm"
and hidden_states.dtype == torch.float16
and self.hidden_act == "silu"
and hidden_states.shape[0] == 1
and not self.quantize

View File

@ -256,6 +256,7 @@ class FlashGemma2Attention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
softcap=self.softcap,
)

View File

@ -248,6 +248,7 @@ class FlashGemmaAttention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -247,6 +247,7 @@ class FlashGPT2Attention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -235,6 +235,7 @@ class FlashLlamaAttention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
)
return self.o_proj(
@ -318,6 +319,7 @@ class LlamaMLP(nn.Module):
def forward(self, hidden_states, adapter_data):
if (
SYSTEM == "rocm"
and hidden_states.dtype == torch.float16
and self.hidden_act == "silu"
and hidden_states.shape[0] == 1
and self.hidden_size
@ -557,6 +559,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model(
inputs_embeds,
position_ids,

View File

@ -235,6 +235,7 @@ class MistralAttention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
)
return self.o_proj(
@ -300,6 +301,7 @@ class MistralMLP(nn.Module):
def forward(self, hidden_states, adapter_data):
if (
SYSTEM == "rocm"
and hidden_states.dtype == torch.float16
and self.hidden_act == "silu"
and hidden_states.shape[0] == 1
and not self.quantize

View File

@ -292,6 +292,7 @@ class MixtralAttention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -180,6 +180,7 @@ class FlashNeoxAttention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -209,6 +209,7 @@ class FlashPhiAttention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -153,6 +153,7 @@ class Qwen2Attention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -223,6 +223,7 @@ class FlashRWAttention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -341,6 +342,7 @@ class FlashRWLargeAttention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
)
return self.dense(

View File

@ -308,6 +308,7 @@ class FlashMQAttention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
)
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -258,6 +258,7 @@ class Starcoder2Attention(torch.nn.Module):
block_tables,
input_lengths,
max_s,
self.num_key_value_heads,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))

View File

@ -1007,12 +1007,12 @@ class FlashCausalLM(Model):
else:
self.kv_cache = [
(
torch.empty(
torch.zeros(
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
dtype=dtype,
device=device,
),
torch.empty(
torch.zeros(
(num_blocks, num_heads, head_size, BLOCK_SIZE),
dtype=dtype,
device=device,