Merge branch 'fix_rocm_fa' into rocm_6.2_fixes
commit 41b297a26b
@@ -41,7 +41,7 @@ COPY launcher launcher
 RUN cargo build --profile release-opt
 
 # Text Generation Inference base image for RoCm
-FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
+FROM rocm/dev-ubuntu-22.04:6.2 AS base
 
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
 build-essential \
@@ -50,23 +50,25 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
 curl \
 git \
 make \
+libmsgpack-dev \
 libssl-dev \
+llvm-dev \
 g++ \
 # Needed to build VLLM & flash.
 rocthrust-dev \
 hipsparse-dev \
 hipblas-dev \
-hipblaslt-dev \
+hipcub-dev \
 rocblas-dev \
 hiprand-dev \
+hipfft-dev \
 rocrand-dev \
 miopen-hip-dev \
-hipfft-dev \
-hipcub-dev \
 hipsolver-dev \
 rccl-dev \
 cmake \
-python3.11-dev && \
+python3.11-dev \
+python3.11-venv && \
 rm -rf /var/lib/apt/lists/*
 
 # Keep in sync with `server/pyproject.toml
@@ -76,7 +78,14 @@ ARG ROCM_VERSION='6.0.2'
 ARG PYTHON_VERSION='3.11.10'
 # Automatically set by buildx
 ARG TARGETPLATFORM
-ENV PATH /opt/conda/bin:$PATH
+ENV PATH=/opt/conda/bin:$PATH
 
+ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+
+RUN curl -fsSL -v -o cmake-3.30.2-linux-x86_64.sh https://github.com/Kitware/CMake/releases/download/v3.30.2/cmake-3.30.2-linux-x86_64.sh \
+&& chmod +x cmake-3.30.2-linux-x86_64.sh \
+&& ./cmake-3.30.2-linux-x86_64.sh --skip-license --prefix=/usr/local \
+&& rm cmake-3.30.2-linux-x86_64.sh
+
 # TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
 # Install mamba
@@ -100,19 +109,37 @@ RUN case ${TARGETPLATFORM} in \
 /opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
 esac && \
 /opt/conda/bin/conda clean -ya
 
 # Install flash-attention, torch dependencies
-RUN pip install numpy einops ninja --no-cache-dir
+RUN pip install numpy einops ninja joblib msgpack --no-cache-dir
 
+# Install HIPBLASLt
+ARG HIPBLASLT_BRANCH="6f65c6e"
+RUN git clone https://github.com/ROCm/hipBLASLt \
+&& cd hipBLASLt \
+&& git checkout ${HIPBLASLT_BRANCH} \
+&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
+&& cd build/release \
+&& make package
+RUN dpkg -i hipBLASLt/build/release/*.deb \
+&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
+&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status \
+&& rm -rf hipBLASLt
+
 RUN pip uninstall -y triton && \
 git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
 cd triton/python && \
 pip install .
 
-RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
+ARG PYTORCH_COMMIT="da320214e66b5af0f7db8fd18a64dbb519d17b27"
+RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \
+cd pytorch && git fetch --depth 1 origin ${PYTORCH_COMMIT} && \
+git checkout ${PYTORCH_COMMIT} && \
+git submodule update --init --recursive && \
+pip install -r requirements.txt --no-cache-dir
+
 ARG _GLIBCXX_USE_CXX11_ABI="1"
 ARG CMAKE_PREFIX_PATH="/opt/conda"
-ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
 ARG BUILD_CAFFE2="0" \
 BUILD_CAFFE2_OPS="0" \
 USE_CUDA="0" \
@@ -224,6 +251,13 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base-copy
 
+ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
+ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
+ENV VLLM_MOE_PADDING=0
+ENV ATTENTION=paged
+ENV USE_PREFIX_CACHING=0
+ENV ROCM_USE_SKINNY_GEMM=1
+
 COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh
 
@@ -31,6 +31,12 @@ Two implementations of Flash Attention are available for ROCm, the first is [ROC
 
 By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup which can be prohibitive as it needs to be done again for each new prompt length. If needed, the FA Triton implementation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="1"` when launching TGI's docker container.
 
+## Custom PagedAttention
+
+For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.
+
+The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel.
+
 ## Unsupported features
 
 The following features are currently not supported in the ROCm version of TGI, and the support may be extended in the future:
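For quick reference, a minimal sketch (editorial illustration, not part of the diff) of how these two toggles are parsed by the server; the parsing mirrors the `rocm.py` changes later in this commit, and the final `print` line is illustrative only.

```python
# Sketch: how the ROCm attention toggles documented above are read at startup.
import os

# Triton flash-attention is opt-in: "1" or "true" selects the Triton engine,
# anything else keeps the Composable Kernel (ck) engine.
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"

# The custom paged-attention kernel is opt-out: enabled unless the variable is "0".
custom_attn_available = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"

print(f"attention engine: {ENGINE}, custom paged attention: {custom_attn_available}")
```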
@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := v2.6.1
-flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
+flash_att_v2_commit_rocm := 3cea2fb6ee54fb7e1aad9db6ac6c9331184b8647 # (Aug28)
 
 build-flash-attention-v2-cuda:
 pip install -U packaging wheel
@@ -1,5 +1,5 @@
 commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
-commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
+commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
 build-vllm-cuda:
 if [ ! -d 'vllm' ]; then \
 pip install -U ninja packaging --no-cache-dir && \
@@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda
 build-vllm-rocm:
 if [ ! -d 'vllm' ]; then \
 pip install -U ninja packaging --no-cache-dir && \
-git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
+git clone https://github.com/mht-sharma/vllm.git vllm; \
 fi
 cd vllm && git fetch && git checkout $(commit_rocm) && \
 PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
@@ -1,4 +1,5 @@
 import os
+from typing import Optional
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import ATTENTION
@@ -8,13 +9,19 @@ from loguru import logger
 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
-_PARTITION_SIZE = 512
+_PARTITION_SIZE_V1V2 = 512
+_PARTITION_SIZE_CUSTOM = 256
 
 use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
 ENGINE = "triton" if use_triton else "ck"
 
+custom_attn_available = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
+if custom_attn_available:
+    from vllm._custom_C import paged_attention_custom
+
 try:
-    from vllm._C import cache_ops
+    import vllm._custom_ops as ops
 except Exception as e:
     raise ImportError(
         f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
@@ -33,9 +40,7 @@ def reshape_and_cache(
         key_cache.view(-1, shape[-2], shape[-1])[slots] = key
         value_cache.view(-1, shape[-2], shape[-1])[slots] = value
     else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
+        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
 
 
 def paged_attention(
@@ -45,8 +50,9 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    input_lengths: Seqlen,
+    seqlen: Seqlen,
     max_s: int,
+    softcap: Optional[float] = None,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
     # Copyright 2023 The vLLM team. All rights
@@ -68,8 +74,25 @@ def paged_attention(
     # value_cache => [num_blocks, num_heads, head_size, block_size]
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
 
+    num_kv_heads = key_cache.shape[1]
+    gqa_ratio = num_heads // num_kv_heads
+    use_custom = (
+        custom_attn_available
+        and (query.dtype == torch.half or query.dtype == torch.bfloat16)
+        and (head_size == 128 or head_size == 64)
+        and (block_size == 16 or block_size == 32)
+        and (gqa_ratio >= 1 and gqa_ratio <= 16)
+        and max_s <= 32768
+    )
+
+    if not use_custom:
+        _PARTITION_SIZE = _PARTITION_SIZE_V1V2
+    else:
+        _PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
+
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = input_lengths.input_lengths
+    input_lengths = seqlen.input_lengths
 
     out = torch.empty_like(query)
 
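The hunk above carries the core of the change. Below is a standalone sketch of the eligibility check and partition-size selection it introduces; the helper name and the sample values are illustrative, not part of the diff.

```python
# Sketch: the custom ROCm paged-attention kernel is used only for shapes it
# supports, and it reduces the partition size from 512 to 256.
import torch

_PARTITION_SIZE_V1V2 = 512
_PARTITION_SIZE_CUSTOM = 256


def select_partition_size(
    query_dtype: torch.dtype,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_s: int,
    custom_attn_available: bool = True,
) -> tuple[bool, int]:
    # Mirrors the use_custom condition added to paged_attention above.
    use_custom = (
        custom_attn_available
        and query_dtype in (torch.half, torch.bfloat16)
        and head_size in (64, 128)
        and block_size in (16, 32)
        and 1 <= gqa_ratio <= 16
        and max_s <= 32768
    )
    partition_size = _PARTITION_SIZE_CUSTOM if use_custom else _PARTITION_SIZE_V1V2
    return use_custom, partition_size


use_custom, partition_size = select_partition_size(
    torch.float16, head_size=128, block_size=16, gqa_ratio=8, max_s=4096
)
max_num_partitions = (4096 + partition_size - 1) // partition_size
print(use_custom, partition_size, max_num_partitions)  # True 256 16
```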
@@ -78,9 +101,13 @@ def paged_attention(
     # V1 to avoid the overhead of reduction. Also, if the number of
     # sequences or heads is large, we use V1 since there is enough work
     # to parallelize.
-    from vllm._C import ops
+    import vllm._custom_ops as ops
 
-    use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
+    use_v1 = (
+        max_s <= 8192
+        and (max_num_partitions == 1 or num_seqs * num_heads > 512)
+        and not use_custom
+    )
     if use_v1:
         ops.paged_attention_v1(
             out,
@@ -112,6 +139,7 @@ def paged_attention(
         )
         max_logits = torch.empty_like(exp_sums)
 
+        if not use_custom:
             ops.paged_attention_v2(
                 out,
                 exp_sums,
@@ -130,6 +158,25 @@ def paged_attention(
                 "auto",
                 1.0,
             )
+        else:
+            paged_attention_custom(
+                out,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                softmax_scale,
+                block_tables,
+                input_lengths,
+                block_size,
+                max_s,
+                None,
+                "auto",
+            )
 
     return out
 
 
@@ -173,13 +220,14 @@ if ENGINE == "ck":
 
     def attention(
         q,
-        k,
-        v,
-        cu_seqlens,
-        max_s,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
         softmax_scale,
         window_size_left=-1,
         causal=True,
+        softcap=0.0,
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
@@ -189,46 +237,54 @@ if ENGINE == "ck":
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             q,
-            k,
-            v,
+            key_cache,
+            value_cache,
             out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            None,
+            None,
+            None,
+            None,
+            seqlen.max_q,
+            seqlen.max_k,
             0.0,
             softmax_scale,
             False,
             causal,
+            window_size_left,
+            0,
+            softcap,
             False,
             None,
-        )
+        )[0]
 
 elif ENGINE == "triton":
     from .flash_attn_triton import triton_attention
 
     def attention(
         q,
-        k,
-        v,
-        cu_seqlens,
-        max_s,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        seqlen: Seqlen,
+        block_tables: torch.Tensor,
         softmax_scale,
         window_size_left=-1,
         causal=True,
+        softcap=0.0,
     ):
         out = torch.empty_like(q)
 
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         output, _ = triton_attention(
             q,
-            k,
-            v,
+            key_cache,
+            value_cache,
             out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_q,
+            seqlen.max_q,
+            seqlen.max_k,
             causal,
             softmax_scale,
         )
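For clarity, a sketch of the resulting kernel dispatch inside `paged_attention` after this change. Kernel names are returned as strings here; the real code calls `ops.paged_attention_v1`/`ops.paged_attention_v2` from `vllm._custom_ops` or `paged_attention_custom`. The helper name and the sample values are illustrative.

```python
# Sketch: V1 for short sequences with little reduction work, V2 otherwise,
# and the custom ROCm kernel whenever the eligibility check passed.
def choose_paged_attention_kernel(
    max_s: int, max_num_partitions: int, num_seqs: int, num_heads: int, use_custom: bool
) -> str:
    use_v1 = (
        max_s <= 8192
        and (max_num_partitions == 1 or num_seqs * num_heads > 512)
        and not use_custom
    )
    if use_v1:
        return "paged_attention_v1"
    if not use_custom:
        return "paged_attention_v2"
    return "paged_attention_custom"


print(choose_paged_attention_kernel(2048, 1, 4, 32, use_custom=False))   # paged_attention_v1
print(choose_paged_attention_kernel(65536, 8, 4, 32, use_custom=False))  # paged_attention_v2
print(choose_paged_attention_kernel(2048, 8, 4, 32, use_custom=True))    # paged_attention_custom
```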
@@ -1,8 +1,15 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from torch.nn import functional as F
+import os
 
 if SYSTEM == "rocm":
+    ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
+        "true",
+        "1",
+    )
+
+    if ROCM_USE_SKINNY_GEMM:
         try:
             from vllm import _custom_C
         except Exception as e:
@@ -48,6 +55,14 @@ class FastLinearROCm(torch.nn.Module):
         else:
             self.bias = None
 
+        self.cu_count = torch.cuda.get_device_properties(
+            device="cuda"
+        ).multi_processor_count
+        self.use_skinny_gemm = (
+            ROCM_USE_SKINNY_GEMM
+            and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
+        )
+
     @classmethod
     def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_tensor(f"{prefix}.weight")
@@ -61,7 +76,11 @@ class FastLinearROCm(torch.nn.Module):
         weight = self.weight
         bias = self.bias
 
-        if SYSTEM == "rocm" and inp.numel() // inp.shape[-1] == 1:
+        if (
+            self.use_skinny_gemm
+            and inp.dtype == torch.float16
+            and inp.shape[-1] % 8 == 0
+        ):
             batched = False
             inp_shape = inp.shape
 
@@ -69,13 +88,16 @@ class FastLinearROCm(torch.nn.Module):
                 inp = inp.view(-1, inp_shape[-1])
                 batched = True
 
-            m, k = weight.shape[0], inp_shape[1]
-            out = torch.empty(
-                inp_shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
-            )
-            if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
-                _custom_C.LLMM1(weight, inp, out, 8)
-            elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
+            m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
+            if m > 8 and n <= 4:
+                out = torch.empty(
+                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
+                )
+                _custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
+            elif m % 4 == 0 and n == 1 and k <= 8192:
+                out = torch.empty(
+                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
+                )
                 _custom_C.LLMM1(weight, inp, out, 4)
             else:
                 out = F.linear(inp, weight)
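A sketch of the skinny-GEMM dispatch that `FastLinearROCm.forward` implements after this change. The helper name and its string return values are illustrative; the real code calls `_custom_C.wvSpltK` and `_custom_C.LLMM1`, or falls back to `F.linear`.

```python
# Sketch: for fp16 inputs with a small token dimension, prefer the ROCm
# custom skinny-GEMM kernels over torch.nn.functional.linear.
def choose_gemm_kernel(m: int, n: int, k: int, dtype_is_fp16: bool, use_skinny_gemm: bool) -> str:
    # m: output features, n: number of input rows (tokens), k: input features
    if not (use_skinny_gemm and dtype_is_fp16 and k % 8 == 0):
        return "F.linear"
    if m > 8 and n <= 4:
        return "wvSpltK"
    if m % 4 == 0 and n == 1 and k <= 8192:
        return "LLMM1"
    return "F.linear"


print(choose_gemm_kernel(m=4096, n=1, k=4096, dtype_is_fp16=True, use_skinny_gemm=True))    # wvSpltK
print(choose_gemm_kernel(m=11008, n=1, k=16384, dtype_is_fp16=True, use_skinny_gemm=True))  # wvSpltK
print(choose_gemm_kernel(m=4096, n=16, k=4096, dtype_is_fp16=True, use_skinny_gemm=True))   # F.linear
```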
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 
@@ -297,8 +298,8 @@ class FlashCohereAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 
@@ -336,8 +337,8 @@ class DbrxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -15,6 +15,7 @@
 
 from typing import List, Optional, Tuple
 
+from text_generation_server.models.globals import PAGED_KV
 from moe_kernels.fused_moe import grouped_topk
 import torch
 import torch.distributed
@@ -327,8 +328,8 @@ class DeepseekV2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -388,6 +389,7 @@ class DeepseekV2MLP(nn.Module):
     def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
         if (
             SYSTEM == "rocm"
+            and hidden_states.dtype == torch.float16
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
             and not self.quantize

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 
@@ -25,7 +26,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -237,8 +237,8 @@ class FlashGemma2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 
@@ -25,7 +26,6 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGemmaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -18,13 +18,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -231,8 +231,8 @@ class FlashGPT2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 
@@ -192,8 +193,8 @@ class FlashGPTJAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key,
-                kv_cache[1] if SYSTEM != "ipex" else value,
+                kv_cache[0] if PAGED_KV else key,
+                kv_cache[1] if PAGED_KV else value,
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -28,6 +28,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 
 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.models.globals import PAGED_KV
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -220,8 +221,8 @@ class FlashLlamaAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -315,11 +316,16 @@ class LlamaMLP(nn.Module):
         # TODO: This is a hotfix to be removed & properly refactored.
         self.quantize = config.quantize
 
+        self.hidden_size = config.hidden_size
+
     def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
+            and hidden_states.dtype == torch.float16
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and self.hidden_size
+            != 16384  # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed.
             and not self.quantize
         ):
             out = torch.empty(
@@ -555,6 +561,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.embed_tokens(input_ids)
 
         hidden_states = self.model(
             inputs_embeds,
             position_ids,

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 
@@ -218,8 +219,8 @@ class MistralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -301,6 +302,7 @@ class MistralMLP(nn.Module):
     def forward(self, hidden_states, adapter_data):
         if (
             SYSTEM == "rocm"
+            and hidden_states.dtype == torch.float16
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
             and not self.quantize

@@ -18,12 +18,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 
 
 from torch import nn
-from text_generation_server.utils.import_utils import SYSTEM
 
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
@@ -274,8 +274,8 @@ class MixtralAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 
@@ -26,7 +27,6 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig
 from typing import Optional, List, Tuple
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
@@ -172,8 +172,8 @@ class FlashNeoxAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 qkv[:, 0],
-                kv_cache[0] if SYSTEM != "ipex" else qkv[:, 1],
-                kv_cache[1] if SYSTEM != "ipex" else qkv[:, 2],
+                kv_cache[0] if PAGED_KV else qkv[:, 1],
+                kv_cache[1] if PAGED_KV else qkv[:, 2],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -1,3 +1,4 @@
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 
@@ -25,7 +26,6 @@ from text_generation_server.layers.layernorm import (
 from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
-from text_generation_server.utils.import_utils import SYSTEM
 
 
 class PhiConfig(PretrainedConfig):
@@ -194,8 +194,8 @@ class FlashPhiAttention(torch.nn.Module):
         if cu_seqlen_prefill is not None:
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -1,3 +1,4 @@
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 
@@ -21,7 +22,6 @@ from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM
 
 
 def load_attention(config, prefix, weights):
@@ -137,8 +137,8 @@ class Qwen2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -1,11 +1,11 @@
 from typing import List, Optional, Tuple
 
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_utils import PreTrainedModel
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers import (
     SpeculativeHead,
     TensorParallelColumnLinear,
@@ -207,8 +207,8 @@ class FlashRWAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, 1],
+                kv_cache[0] if PAGED_KV else kv[:, 0],
+                kv_cache[1] if PAGED_KV else kv[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -325,8 +325,8 @@ class FlashRWLargeAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv[:, :, 0].contiguous(),
-                kv_cache[1] if SYSTEM != "ipex" else kv[:, :, 1].contiguous(),
+                kv_cache[0] if PAGED_KV else kv[:, :, 0].contiguous(),
+                kv_cache[1] if PAGED_KV else kv[:, :, 1].contiguous(),
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -1,3 +1,4 @@
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 
@@ -22,7 +23,6 @@ from text_generation_server.layers.gptq import GPTQWeightsLoader
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
-from text_generation_server.utils.import_utils import SYSTEM
 
 
 def load_multi_mqa(
@@ -293,8 +293,8 @@ class FlashMQAttention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else key_value[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else key_value[:, 1],
+                kv_cache[0] if PAGED_KV else key_value[:, 0],
+                kv_cache[1] if PAGED_KV else key_value[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from text_generation_server.models.globals import PAGED_KV
 import torch
 import torch.distributed
 
@@ -47,7 +48,6 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
-from text_generation_server.utils.import_utils import SYSTEM
 
 
 class Starcoder2Config(PretrainedConfig):
@@ -242,8 +242,8 @@ class Starcoder2Attention(torch.nn.Module):
             # flash attention
             attn_output = attention(
                 query,
-                kv_cache[0] if SYSTEM != "ipex" else kv_to_cache[:, 0],
-                kv_cache[1] if SYSTEM != "ipex" else kv_to_cache[:, 1],
+                kv_cache[0] if PAGED_KV else kv_to_cache[:, 0],
+                kv_cache[1] if PAGED_KV else kv_to_cache[:, 1],
                 seqlen,
                 block_tables,
                 self.softmax_scale,
@@ -1125,12 +1125,12 @@ class FlashCausalLM(Model):
         else:
             self.kv_cache = [
                 (
-                    torch.empty(
+                    torch.zeros(
                         (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
                         dtype=dtype,
                         device=device,
                     ),
-                    torch.empty(
+                    torch.zeros(
                         (num_blocks, num_heads, head_size, BLOCK_SIZE),
                         dtype=dtype,
                         device=device,
@@ -1320,8 +1320,7 @@ class FlashCausalLM(Model):
             elif CUDA_GRAPHS is not None:
                 tuning_sequences = CUDA_GRAPHS
             else:
-                # For seqlen = 1, we dispatch to LLMM1 kernel.
-                tuning_sequences = [2, 3, 4, 5, 6, 7]
+                tuning_sequences = [1, 2, 3, 4, 5, 6, 7]
 
             tunableop_filepath = os.path.join(
                 HUGGINGFACE_HUB_CACHE,
@@ -1330,7 +1329,11 @@ class FlashCausalLM(Model):
 
             log_master(
                 logger.info,
-                f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
+                f"PyTorch TunableOp is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.",
             )
+
+            torch.cuda.tunable.set_filename(
+                tunableop_filepath, insert_device_ordinal=False
+            )
 
             if os.path.isfile(tunableop_filepath):
@@ -1346,6 +1349,7 @@ class FlashCausalLM(Model):
                 log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
                 self.tunableop_warmup(seqlen)
             torch.cuda.tunable.write_file(tunableop_filepath)
+            if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
                 torch.cuda.tunable.tuning_enable(False)
         else:
             log_master(
@@ -1382,6 +1386,7 @@ class FlashCausalLM(Model):
         cu_seqlen_prefill = torch.tensor(
             [0, seqlen], device=self.device, dtype=torch.int32
         )
+        max_s = seqlen
         seqlen = Seqlen(
             input_lengths=input_lengths,
             prefix_lengths=prefix_lens_tensor,
@@ -1399,7 +1404,7 @@ class FlashCausalLM(Model):
             block_tables=None,
             seqlen=seqlen,
             slots=slots,
-            max_s=seqlen,
+            max_s=max_s,
             lm_head_indices=None,
             prefill_cache_indices=None,
         )
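A sketch of the TunableOp warmup flow after this change, assuming the `PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0` default set in the Dockerfile above. The function name and `warmup_model` are placeholders, and the `read_file` call is an assumption about how a pre-existing tuning file would be reused.

```python
# Sketch: tune GEMMs during warmup, persist the results, then stop tuning
# unless the operator explicitly asked to keep tuning after warmup.
import os
import torch


def tunableop_warmup_sketch(tunableop_filepath: str, tuning_sequences: list[int], warmup_model) -> None:
    torch.cuda.tunable.set_filename(tunableop_filepath, insert_device_ordinal=False)

    if os.path.isfile(tunableop_filepath):
        # Assumed reuse of previously tuned GEMMs from an earlier run.
        torch.cuda.tunable.read_file(tunableop_filepath)

    for seqlen in tuning_sequences:
        warmup_model(seqlen)
    torch.cuda.tunable.write_file(tunableop_filepath)

    # Unless explicitly requested, disable tuning after warmup so steady-state
    # requests do not pay the tuning overhead.
    if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
        torch.cuda.tunable.tuning_enable(False)
```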
@@ -4,6 +4,7 @@ from loguru import logger
 from typing import Dict, Optional
 
 from text_generation_server.utils.log import log_master
+from text_generation_server.utils.import_utils import SYSTEM
 
 PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
@@ -52,6 +53,12 @@ CUDA_GRAPHS = cuda_graphs
 # index in all cases.
 ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
 
+PAGED_KV: bool
+if SYSTEM in {"rocm", "ipex"}:
+    PAGED_KV = False
+else:
+    PAGED_KV = True
+
 
 def set_adapter_to_index(adapter_to_index: Dict[str, int]):
     global ADAPTER_TO_INDEX
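A sketch of how the new `PAGED_KV` flag behaves. `SYSTEM` is normally detected at import time; the helper below takes it as a parameter only to stay self-contained.

```python
# Sketch: on ROCm and IPEX the attention kernels receive the unpadded key/value
# tensors directly, while elsewhere they read from the paged KV cache, which is
# why the modeling files above now write `kv_cache[0] if PAGED_KV else key`.
def paged_kv_enabled(system: str) -> bool:
    return system not in {"rocm", "ipex"}


assert paged_kv_enabled("cuda") is True
assert paged_kv_enabled("rocm") is False
```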