Merge branch 'fix_rocm_fa' into rocm_6.2_fixes

Mohit Sharma 2024-09-19 08:35:31 +00:00
commit 41b297a26b
24 changed files with 267 additions and 123 deletions

View File

@@ -41,7 +41,7 @@ COPY launcher launcher
 RUN cargo build --profile release-opt

 # Text Generation Inference base image for RoCm
-FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
+FROM rocm/dev-ubuntu-22.04:6.2 AS base

 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
     build-essential \
@@ -50,23 +50,25 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     curl \
     git \
     make \
+    libmsgpack-dev \
     libssl-dev \
+    llvm-dev \
     g++ \
     # Needed to build VLLM & flash.
     rocthrust-dev \
     hipsparse-dev \
     hipblas-dev \
-    hipblaslt-dev \
+    hipcub-dev \
     rocblas-dev \
     hiprand-dev \
+    hipfft-dev \
     rocrand-dev \
     miopen-hip-dev \
-    hipfft-dev \
-    hipcub-dev \
     hipsolver-dev \
     rccl-dev \
     cmake \
-    python3.11-dev && \
+    python3.11-dev \
+    python3.11-venv && \
     rm -rf /var/lib/apt/lists/*

 # Keep in sync with `server/pyproject.toml
@@ -76,7 +78,14 @@ ARG ROCM_VERSION='6.0.2'
 ARG PYTHON_VERSION='3.11.10'
 # Automatically set by buildx
 ARG TARGETPLATFORM
-ENV PATH /opt/conda/bin:$PATH
+ENV PATH=/opt/conda/bin:$PATH
+
+ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+
+RUN curl -fsSL -v -o cmake-3.30.2-linux-x86_64.sh https://github.com/Kitware/CMake/releases/download/v3.30.2/cmake-3.30.2-linux-x86_64.sh \
+    && chmod +x cmake-3.30.2-linux-x86_64.sh \
+    && ./cmake-3.30.2-linux-x86_64.sh --skip-license --prefix=/usr/local \
+    && rm cmake-3.30.2-linux-x86_64.sh

 # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
 # Install mamba
@@ -100,19 +109,37 @@ RUN case ${TARGETPLATFORM} in \
         /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
     esac && \
     /opt/conda/bin/conda clean -ya

 # Install flash-attention, torch dependencies
-RUN pip install numpy einops ninja --no-cache-dir
+RUN pip install numpy einops ninja joblib msgpack --no-cache-dir
+
+# Install HIPBLASLt
+ARG HIPBLASLT_BRANCH="6f65c6e"
+RUN git clone https://github.com/ROCm/hipBLASLt \
+    && cd hipBLASLt \
+    && git checkout ${HIPBLASLT_BRANCH} \
+    && SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
+    && cd build/release \
+    && make package
+
+RUN dpkg -i hipBLASLt/build/release/*.deb \
+    && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+    && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status \
+    && rm -rf hipBLASLt

 RUN pip uninstall -y triton && \
     git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
     cd triton/python && \
     pip install .

-RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
+ARG PYTORCH_COMMIT="da320214e66b5af0f7db8fd18a64dbb519d17b27"
+RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \
+    cd pytorch && git fetch --depth 1 origin ${PYTORCH_COMMIT} && \
+    git checkout ${PYTORCH_COMMIT} && \
+    git submodule update --init --recursive && \
+    pip install -r requirements.txt --no-cache-dir

 ARG _GLIBCXX_USE_CXX11_ABI="1"
 ARG CMAKE_PREFIX_PATH="/opt/conda"
-ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
 ARG BUILD_CAFFE2="0" \
     BUILD_CAFFE2_OPS="0" \
     USE_CUDA="0" \
@@ -224,6 +251,13 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base-copy

+ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
+ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
+ENV VLLM_MOE_PADDING=0
+ENV ATTENTION=paged
+ENV USE_PREFIX_CACHING=0
+ENV ROCM_USE_SKINNY_GEMM=1
+
 COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh

View File

@@ -31,6 +31,12 @@ Two implementations of Flash Attention are available for ROCm, the first is [ROC
 By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, the FA Triton implementation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="1"` when launching TGI's docker container.

+## Custom PagedAttention
+
+For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.
+
+The custom kernel supports bf16 and fp16 data types, a block size of 16, a head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, the PagedAttention v2 kernel is used.
+
 ## Unsupported features

 The following features are currently not supported in the ROCm version of TGI, and the support may be extended in the future:
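As a cross-reference for the Custom PagedAttention note above: the documented envelope corresponds to the `use_custom` gate added to the ROCm attention backend later in this diff, which accepts a slightly wider set (head size 64 or 128, block size 16 or 32, contexts up to 32k). A minimal illustrative sketch of that check, using a hypothetical helper name:

import torch

def can_use_custom_paged_attn(
    dtype: torch.dtype, head_size: int, block_size: int, gqa_ratio: int, max_s: int
) -> bool:
    # Mirrors the use_custom condition added in this commit; anything outside
    # this envelope falls back to the PagedAttention v1/v2 kernels.
    return (
        dtype in (torch.half, torch.bfloat16)
        and head_size in (64, 128)
        and block_size in (16, 32)
        and 1 <= gqa_ratio <= 16
        and max_s <= 32768
    )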

View File

@ -1,5 +1,5 @@
flash_att_v2_commit_cuda := v2.6.1 flash_att_v2_commit_cuda := v2.6.1
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 flash_att_v2_commit_rocm := 3cea2fb6ee54fb7e1aad9db6ac6c9331184b8647 # (Aug28)
build-flash-attention-v2-cuda: build-flash-attention-v2-cuda:
pip install -U packaging wheel pip install -U packaging wheel

View File

@ -1,5 +1,5 @@
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921 commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
build-vllm-cuda: build-vllm-cuda:
if [ ! -d 'vllm' ]; then \ if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \ pip install -U ninja packaging --no-cache-dir && \
@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda
build-vllm-rocm: build-vllm-rocm:
if [ ! -d 'vllm' ]; then \ if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \ pip install -U ninja packaging --no-cache-dir && \
git clone https://github.com/fxmarty/rocm-vllm.git vllm; \ git clone https://github.com/mht-sharma/vllm.git vllm; \
fi fi
cd vllm && git fetch && git checkout $(commit_rocm) && \ cd vllm && git fetch && git checkout $(commit_rocm) && \
PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build

View File

@@ -1,4 +1,5 @@
 import os
+from typing import Optional

 import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import ATTENTION
@@ -8,13 +9,19 @@ from loguru import logger
 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5

-_PARTITION_SIZE = 512
+_PARTITION_SIZE_V1V2 = 512
+_PARTITION_SIZE_CUSTOM = 256

 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
 ENGINE = "triton" if use_triton else "ck"

+custom_attn_available = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
+if custom_attn_available:
+    from vllm._custom_C import paged_attention_custom
+
 try:
-    from vllm._C import cache_ops
+    import vllm._custom_ops as ops
 except Exception as e:
     raise ImportError(
         f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
@@ -33,9 +40,7 @@ def reshape_and_cache(
         key_cache.view(-1, shape[-2], shape[-1])[slots] = key
         value_cache.view(-1, shape[-2], shape[-1])[slots] = value
     else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
+        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)


 def paged_attention(
@@ -45,8 +50,9 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    input_lengths: Seqlen,
+    seqlen: Seqlen,
     max_s: int,
+    softcap: Optional[float] = None,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
     # Copyright 2023 The vLLM team. All rights
@@ -68,8 +74,25 @@ def paged_attention(
     # value_cache => [num_blocks, num_heads, head_size, block_size]
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
+
+    num_kv_heads = key_cache.shape[1]
+    gqa_ratio = num_heads // num_kv_heads
+    use_custom = (
+        custom_attn_available
+        and (query.dtype == torch.half or query.dtype == torch.bfloat16)
+        and (head_size == 128 or head_size == 64)
+        and (block_size == 16 or block_size == 32)
+        and (gqa_ratio >= 1 and gqa_ratio <= 16)
+        and max_s <= 32768
+    )
+
+    if not use_custom:
+        _PARTITION_SIZE = _PARTITION_SIZE_V1V2
+    else:
+        _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
+
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = input_lengths.input_lengths
+    input_lengths = seqlen.input_lengths

     out = torch.empty_like(query)
@@ -78,9 +101,13 @@
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    from vllm._C import ops
+    import vllm._custom_ops as ops

-    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
+    use_v1 = (
+        max_s <= 8192
+        and (max_num_partitions == 1 or num_seqs * num_heads > 512)
+        and not use_custom
+    )
     if use_v1:
         ops.paged_attention_v1(
             out,
@@ -112,6 +139,7 @@
         )
         max_logits = torch.empty_like(exp_sums)

+        if not use_custom:
             ops.paged_attention_v2(
                 out,
                 exp_sums,
@@ -130,6 +158,25 @@
                 "auto",
                 1.0,
             )
+        else:
+            paged_attention_custom(
+                out,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+                "auto",
+            )

     return out
@@ -173,13 +220,14 @@ if ENGINE == "ck":

     def attention(
         q,
-        k,
-        v,
-        cu_seqlens,
-        max_s,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
         softmax_scale,
         window_size_left=-1,
         causal=True,
+        softcap=0.0,
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
@@ -189,46 +237,54 @@
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             q,
-            k,
-            v,
+            key_cache,
+            value_cache,
             out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            None,
+            None,
+            None,
+            None,
+            seqlen.max_q,
+            seqlen.max_k,
             0.0,
             softmax_scale,
             False,
             causal,
+            window_size_left,
+            0,
+            softcap,
             False,
             None,
-        )
+        )[0]

 elif ENGINE == "triton":
     from .flash_attn_triton import triton_attention

     def attention(
         q,
-        k,
-        v,
-        cu_seqlens,
-        max_s,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
         softmax_scale,
         window_size_left=-1,
         causal=True,
+        softcap=0.0,
     ):
         out = torch.empty_like(q)

         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         output, _ = triton_attention(
             q,
-            k,
-            v,
+            key_cache,
+            value_cache,
             out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            seqlen.max_q,
+            seqlen.max_k,
             causal,
             softmax_scale,
         )
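To summarize the kernel selection introduced above, here is an illustrative sketch with a hypothetical helper name (the real code inlines this branching inside `paged_attention`): the custom kernel runs with a 256-wide partition, v1/v2 keep the 512-wide partition, and v1 is only chosen when the custom kernel is not used.

def pick_paged_attention_kernel(
    use_custom: bool, max_s: int, max_num_partitions: int, num_seqs: int, num_heads: int
) -> str:
    # Mirrors the use_v1 / use_custom branching added in this commit.
    if use_custom:
        return "paged_attention_custom"  # ROCm custom kernel, _PARTITION_SIZE_CUSTOM = 256
    if max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512):
        return "paged_attention_v1"      # single pass, avoids the reduction overhead
    return "paged_attention_v2"          # partitioned pass + reduction, _PARTITION_SIZE_V1V2 = 512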

View File

@@ -1,8 +1,15 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from torch.nn import functional as F
+import os

 if SYSTEM == "rocm":
+    ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
+        "true",
+        "1",
+    )
+
+    if ROCM_USE_SKINNY_GEMM:
         try:
             from vllm import _custom_C
         except Exception as e:
@@ -48,6 +55,14 @@ class FastLinearROCm(torch.nn.Module):
         else:
             self.bias = None

+        self.cu_count = torch.cuda.get_device_properties(
+            device="cuda"
+        ).multi_processor_count
+        self.use_skinny_gemm = (
+            ROCM_USE_SKINNY_GEMM
+            and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
+        )
+
     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_tensor(f"{prefix}.weight")
@@ -61,7 +76,11 @@ class FastLinearROCm(torch.nn.Module):
         weight = self.weight
         bias = self.bias

-        if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
+        if (
+            self.use_skinny_gemm
+            and inp.dtype == torch.float16
+            and inp.shape[-1] % 8 == 0
+        ):
             batched = False
             inp_shape = inp.shape
@@ -69,13 +88,16 @@ class FastLinearROCm(torch.nn.Module):
                 inp = inp.view(-1, inp_shape[-1])
                 batched = True

-            m, k = weight.shape[0], inp_shape[1]
-            out = torch.empty(
-                inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
-            )
-            if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
-                _custom_C.LLMM1(weight, inp, out, 8)
-            elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
+            m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
+            if m > 8 and n <= 4:
+                out = torch.empty(
+                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
+                )
+                _custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
+            elif m % 4 == 0 and n == 1 and k <= 8192:
+                out = torch.empty(
+                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
+                )
                 _custom_C.LLMM1(weight, inp, out, 4)
             else:
                 out = F.linear(inp, weight)
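In short, the forward path above now dispatches skinny fp16 GEMMs as follows, on top of the outer gate (`ROCM_USE_SKINNY_GEMM` enabled, fp16 input, last dimension divisible by 8, non-gfx1 GPU). A hypothetical helper, for illustration only:

def pick_skinny_gemm_path(m: int, n: int, k: int) -> str:
    # m: output features, n: number of input rows (tokens), k: input features.
    if m > 8 and n <= 4:
        return "wvSpltK"   # vLLM skinny-GEMM custom kernel
    if m % 4 == 0 and n == 1 and k <= 8192:
        return "LLMM1"     # vLLM custom GEMV kernel
    return "F.linear"      # default PyTorch matmul path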

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -297,8 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                seqlen,
                block_tables,
                self.softmax_scale,

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -336,8 +337,8 @@ class DbrxAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,

View File

@@ -15,6 +15,7 @@
 from typing import List, Optional, Tuple

+from text_generation_server.models.globals import PAGED_KV
 from moe_kernels.fused_moe import grouped_topk
 import torch
 import torch.distributed
@@ -327,8 +328,8 @@ class DeepseekV2Attention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -388,6 +389,7 @@ class DeepseekV2MLP(nn.Module):
     def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
         if (
             SYSTEM == "rocm"
+            and hidden_states.dtype == torch.float16
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
             and not self.quantize

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -25,7 +26,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -25,7 +26,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,

View File

@@ -18,13 +18,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed

 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                seqlen,
                block_tables,
                self.softmax_scale,

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -192,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                seqlen,
                block_tables,
                self.softmax_scale,

View File

@@ -28,6 +28,7 @@ from torch import nn
 from transformers.activations import ACT2FN

 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import PAGED_KV
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -220,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -315,11 +316,16 @@ class LlamaMLP(nn.Module):
         # TODO: This is a hotfix to be removed & properly refactored.
         self.quantize = config.quantize

+        self.hidden_size = config.hidden_size
+
     def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
+            and hidden_states.dtype == torch.float16
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and self.hidden_size
+            != 16384  # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed.
             and not self.quantize
         ):
             out = torch.empty(
@@ -555,6 +561,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.embed_tokens(input_ids)

         hidden_states = self.model(
             inputs_embeds,
             position_ids,

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -218,8 +219,8 @@ class MistralAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -301,6 +302,7 @@ class MistralMLP(nn.Module):
     def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
+            and hidden_states.dtype == torch.float16
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
             and not self.quantize

View File

@@ -18,12 +18,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed

 from torch import nn
-from text_generation_server.utils.import_utils import SYSTEM
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
@@ -274,8 +274,8 @@ class MixtralAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -26,7 +27,6 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                qkv[:, 0],
-                kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
-                kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
+                kv_cache[0] if PAGED_KV else qkv[:, 1],
+                kv_cache[1] if PAGED_KV else qkv[:, 2],
                seqlen,
                block_tables,
                self.softmax_scale,

View File

@@ -1,3 +1,4 @@
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -25,7 +26,6 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 class PhiConfig(PretrainedConfig):
@@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
        if cu_seqlen_prefill is not None:
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,

View File

@@ -1,3 +1,4 @@
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -21,7 +22,6 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 def load_attention(config, prefix, weights):
@@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,

View File

@@ -1,11 +1,11 @@
 from typing import List, Optional, Tuple

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     SpeculativeHead,
     TensorParallelColumnLinear,
@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,
@@ -325,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
+                kv_cache[0] if PAGED_KV else kv[:, :, 0].contiguous(),
+                kv_cache[1] if PAGED_KV else kv[:, :, 1].contiguous(),
                seqlen,
                block_tables,
                self.softmax_scale,

View File

@@ -1,3 +1,4 @@
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -22,7 +23,6 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM


 def load_multi_mqa(
@@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
+                kv_cache[0] if PAGED_KV else key_value[:, 0],
+                kv_cache[1] if PAGED_KV else key_value[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,

View File

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
@@ -47,7 +48,6 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
-from text_generation_server.utils.import_utils import SYSTEM


 class Starcoder2Config(PretrainedConfig):
@@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
            # flash attention
            attn_output = attention(
                query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                seqlen,
                block_tables,
                self.softmax_scale,

View File

@@ -1125,12 +1125,12 @@ class FlashCausalLM(Model):
         else:
             self.kv_cache = [
                 (
-                    torch.empty(
+                    torch.zeros(
                         (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
                         dtype=dtype,
                         device=device,
                     ),
-                    torch.empty(
+                    torch.zeros(
                         (num_blocks, num_heads, head_size, BLOCK_SIZE),
                         dtype=dtype,
                         device=device,
@@ -1320,8 +1320,7 @@ class FlashCausalLM(Model):
             elif CUDA_GRAPHS is not None:
                 tuning_sequences = CUDA_GRAPHS
             else:
-                # For seqlen = 1, we dispatch to LLMM1 kernel.
-                tuning_sequences = [2, 3, 4, 5, 6, 7]
+                tuning_sequences = [1, 2, 3, 4, 5, 6, 7]

             tunableop_filepath = os.path.join(
                 HUGGINGFACE_HUB_CACHE,
@@ -1330,7 +1329,11 @@ class FlashCausalLM(Model):
             log_master(
                 logger.info,
-                f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
+                f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
             )

+            torch.cuda.tunable.set_filename(
+                tunableop_filepath, insert_device_ordinal=False
+            )

             if os.path.isfile(tunableop_filepath):
@@ -1346,6 +1349,7 @@ class FlashCausalLM(Model):
                 log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
                 self.tunableop_warmup(seqlen)
             torch.cuda.tunable.write_file(tunableop_filepath)
+            if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
                 torch.cuda.tunable.tuning_enable(False)
         else:
             log_master(
@@ -1382,6 +1386,7 @@ class FlashCausalLM(Model):
         cu_seqlen_prefill = torch.tensor(
             [0, seqlen], device=self.device, dtype=torch.int32
         )
+        max_s = seqlen
         seqlen = Seqlen(
             input_lengths=input_lengths,
             prefix_lengths=prefix_lens_tensor,
@@ -1399,7 +1404,7 @@ class FlashCausalLM(Model):
             block_tables=None,
             seqlen=seqlen,
             slots=slots,
-            max_s=seqlen,
+            max_s=max_s,
             lm_head_indices=None,
             prefill_cache_indices=None,
         )
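Condensed, the TunableOp warmup now proceeds roughly as sketched below. This is illustrative only: it assumes the `torch.cuda.tunable` API of the pinned PyTorch build (including `read_file`), the `tunableop_warmup` method shown above, and the `PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP` variable that the Dockerfile now defaults to 0.

import os
import torch

def tunableop_warmup_flow(model, tuning_sequences, tunableop_filepath):
    # Keep tuned GEMM selections in a per-model CSV and reuse them when present.
    torch.cuda.tunable.set_filename(tunableop_filepath, insert_device_ordinal=False)
    if os.path.isfile(tunableop_filepath):
        torch.cuda.tunable.read_file(tunableop_filepath)
    for seqlen in tuning_sequences:  # now includes seqlen=1 again
        model.tunableop_warmup(seqlen)
    torch.cuda.tunable.write_file(tunableop_filepath)
    # Tuning stays enabled after warmup only when explicitly requested.
    if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
        torch.cuda.tunable.tuning_enable(False)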

View File

@@ -4,6 +4,7 @@ from loguru import logger
 from typing import Dict, Optional

 from text_generation_server.utils.log import log_master
+from text_generation_server.utils.import_utils import SYSTEM

 PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
@@ -52,6 +53,12 @@ CUDA_GRAPHS = cuda_graphs
 # index in all cases.
 ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None

+PAGED_KV: bool
+if SYSTEM in {"rocm", "ipex"}:
+    PAGED_KV = False
+else:
+    PAGED_KV = True
+

 def set_adapter_to_index(adapter_to_index: Dict[str, int]):
     global ADAPTER_TO_INDEX
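All of the modeling-file changes above reduce to one pattern built on this flag: pass the paged KV cache to `attention(...)` when `PAGED_KV` is true, otherwise pass the freshly computed key/value tensors. A tiny hypothetical helper expressing that selection (the real code inlines the conditional at each call site):

from typing import Tuple
import torch

def select_attention_kv(
    paged_kv: bool,
    kv_cache: Tuple[torch.Tensor, torch.Tensor],
    key: torch.Tensor,
    value: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # CUDA (PAGED_KV=True): prefill attention reads from the paged cache;
    # rocm/ipex (PAGED_KV=False): it uses the current key/value tensors directly.
    return (kv_cache[0], kv_cache[1]) if paged_kv else (key, value)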