feat(server): flash attention v2 (#624)
parent 4d38a1c4ad
commit 3b71c38558

Dockerfile | 15
@@ -98,6 +98,16 @@ COPY server/Makefile-flash-att Makefile
 # Build specific version of flash attention
 RUN make build-flash-attention
 
+# Build Flash Attention v2 CUDA kernels
+FROM kernel-builder as flash-att-v2-builder
+
+WORKDIR /usr/src
+
+COPY server/Makefile-flash-att-v2 Makefile
+
+# Build specific version of flash attention v2
+RUN make build-flash-attention-v2
+
 # Build Transformers CUDA kernels
 FROM kernel-builder as custom-kernels-builder
 
@@ -146,8 +156,11 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 
+# Copy build artifacts from flash attention v2 builder
+COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
+
 # Copy build artifacts from custom kernels builder
-COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
+COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
 
 # Copy builds artifacts from vllm builder
 COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
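Note: the new flash-att-v2-builder stage follows the same pattern as the existing flash-attention, custom-kernels and vllm builder stages: kernels are compiled once in a builder stage and the build artifacts are copied into the runtime image's site-packages. A minimal sanity-check sketch for the final image follows; the module names come from imports elsewhere in this diff, while the loop itself is illustrative and not part of the PR.

# Hedged sketch: confirm the kernel extensions copied into site-packages are importable.
import importlib

import torch  # load torch first so the extensions can resolve its shared libraries

for name in ("flash_attn_2_cuda", "flash_attn_cuda", "vllm_cache_ops", "vllm_attention_ops"):
    try:
        importlib.import_module(name)
        print(f"{name}: ok")
    except ImportError as err:
        print(f"{name}: missing ({err})")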
server/Makefile
@@ -1,4 +1,5 @@
 include Makefile-flash-att
+include Makefile-flash-att-v2
 include Makefile-vllm
 
 unit-tests:
server/Makefile-flash-att-v2 (new file)
@@ -0,0 +1,13 @@
+flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
+
+flash-attention-v2:
+	# Clone flash attention
+	pip install packaging
+	git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
+
+build-flash-attention-v2: flash-attention-v2
+	cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
+	cd flash-attention-v2 && python setup.py build
+
+install-flash-attention-v2: build-flash-attention-v2
+	cd flash-attention-v2 && python setup.py install
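Note: the new Makefile pins flash-attention to a specific commit and builds it with setuptools, mirroring the existing Makefile-flash-att. A rough Python equivalent of the build target for environments without make is sketched below; the commands and pinned commit are taken straight from the Makefile above, the script itself is not part of the PR.

# Hedged sketch: replay the Makefile-flash-att-v2 build steps without make.
import subprocess

FLASH_ATT_V2_COMMIT = "4f285b354796fb17df8636485b9a04df3ebbb7dc"

subprocess.run(["pip", "install", "packaging"], check=True)
subprocess.run(
    ["git", "clone", "https://github.com/HazyResearch/flash-attention.git", "flash-attention-v2"],
    check=True,
)
subprocess.run(["git", "fetch"], cwd="flash-attention-v2", check=True)
subprocess.run(["git", "checkout", FLASH_ATT_V2_COMMIT], cwd="flash-attention-v2", check=True)
subprocess.run(["python", "setup.py", "build"], cwd="flash-attention-v2", check=True)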
server/text_generation_server/models/__init__.py
@@ -42,35 +42,10 @@ __all__ = [
     "get_model",
 ]
 
-FLASH_ATT_ERROR_MESSAGE = (
-    "{} requires CUDA and Flash Attention kernels to be installed.\n"
-    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-    "or install flash attention with `cd server && make install install-flash-attention`"
-)
+FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
+FLASH_ATTENTION = True
 try:
-    if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
-        if not torch.cuda.is_available():
-            FLASH_ATT_ERROR_MESSAGE = (
-                "{} requires CUDA. No compatible CUDA devices found."
-            )
-            raise ImportError("CUDA is not available")
-
-        major, minor = torch.cuda.get_device_capability()
-        is_sm75 = major == 7 and minor == 5
-        is_sm8x = major == 8 and minor >= 0
-        is_sm90 = major == 9 and minor == 0
-
-        supported = is_sm75 or is_sm8x or is_sm90
-        if not supported:
-            FLASH_ATT_ERROR_MESSAGE = (
-                "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
-                "No compatible CUDA device found."
-            )
-            raise ImportError(
-                f"GPU with CUDA capability {major} {minor} is not supported"
-            )
-
     from text_generation_server.models.flash_rw import FlashRWSharded
     from text_generation_server.models.flash_neox import FlashNeoXSharded
     from text_generation_server.models.flash_llama import (
@@ -80,13 +55,8 @@ try:
         FlashSantacoderSharded,
     )
 
-        FLASH_ATTENTION = True
-    else:
-        FLASH_ATTENTION = False
-except ImportError:
-    logger.opt(exception=True).warning(
-        "Could not import Flash Attention enabled models"
-    )
+except ImportError as e:
+    logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
 
 if FLASH_ATTENTION:
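Note: with the CUDA and compute-capability checks moved into the new utils/flash_attn.py module, this file only needs to know whether the flash models imported; the error string keeps its `{}` placeholder for the model name. A hedged sketch of how the flag and message are typically consumed downstream; the helper name is mine and the real call sites in get_model are outside this diff.

# Hedged sketch: gate a Flash-Attention-only model on the imported flag.
from text_generation_server.models import FLASH_ATTENTION, FLASH_ATT_ERROR_MESSAGE

def require_flash_attention(model_name: str) -> None:
    # Raise with the templated message when the optional kernels are unavailable.
    if not FLASH_ATTENTION:
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(model_name))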
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -26,13 +26,13 @@ from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
 # Flash attention imports
-import flash_attn_cuda
 import dropout_layer_norm
 
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -164,22 +164,14 @@ class FlashLlamaAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
             # Decode
         else:
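Note: the dispatch wrapper now hides the kernel-specific arguments (the duplicated cu_seqlens and max_s, the dropout probability, and the causal/return-softmax flags), so each model only passes the packed q/k/v slices, the output buffer, cu_seqlen_prefill, max_s and the softmax scale. For reference, cu_seqlen_prefill follows the usual varlen layout of cumulative token offsets; a small illustrative helper is sketched below (the function name is mine, not from the repo).

# Hedged sketch: build a cu_seqlens tensor for a packed batch of sequence lengths.
import torch

def make_cu_seqlens(lengths):
    # e.g. [3, 5] -> tensor([0, 3, 8], dtype=torch.int32)
    lengths = torch.tensor(lengths, dtype=torch.int32)
    return torch.nn.functional.pad(lengths.cumsum(0, dtype=torch.int32), (1, 0))

print(make_cu_seqlens([3, 5]))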
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -27,13 +27,11 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
 from typing import Optional, List, Tuple
 
-# Flash attention imports
-import flash_attn_cuda
-
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -153,22 +151,14 @@ class FlashNeoxAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
             # Decode
         else:
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -6,13 +6,11 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
-# Flash attention imports
-import flash_attn_cuda
-
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -182,27 +180,15 @@ class FlashRWAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            if self.num_heads_kv == 1:
-                # Expand to query shape
-                kv = kv.expand(-1, 2, self.num_heads, self.head_size)
-
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
             # Decode
         else:
@@ -314,30 +300,15 @@ class FlashRWLargeAttention(torch.nn.Module):
 
         # Prefill
        if cu_seqlen_prefill is not None:
-            # Expand to query shape
-            kv = (
-                kv.unsqueeze(2)
-                .expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
-                .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
-            )
-
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
             # Decode
         else:
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -5,13 +5,11 @@ from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
-# Flash attention imports
-import flash_attn_cuda
-
 # vllm imports
 import vllm_cache_ops
 import vllm_attention_ops
 
+from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -271,26 +269,15 @@ class FlashMQAttention(torch.nn.Module):
 
         # Prefill
         if cu_seqlen_prefill is not None:
-            # Expand from 1 to num_heads
-            key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
-
             # flash attention
-            flash_attn_cuda.fwd(
+            attention(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
-                cu_seqlen_prefill,
                 max_s,
-                max_s,
-                0.0,
                 self.softmax_scale,
-                False,
-                True,
-                False,
-                0,
-                None,
             )
             # Decode
         else:
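Note: for falcon (flash_rw) and santacoder, the multi-query/grouped key-value expansion that used to live in each model now happens inside the wrapper's Flash Attention v1 fallback, while the v2 kernel is passed the narrow K/V directly. A CPU-only sketch of that broadcast, copied from the v1 path in the new utils module (the tensor sizes are made up for illustration):

# Hedged sketch: expand a single shared KV head to match the query heads (MQA case).
import torch

num_tokens, num_heads, head_size = 4, 8, 64
q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn(num_tokens, 1, head_size)  # one shared key head

if k.shape[1] != q.shape[1]:
    if k.shape[1] == 1:
        # MQA expand
        k = k.expand(-1, q.shape[1], -1)
    else:
        # Grouped attention reshape
        original_shape = k.shape
        k = (
            k.unsqueeze(2)
            .expand(-1, -1, q.shape[1] // k.shape[1], -1)
            .reshape(original_shape[0], -1, original_shape[2])
        )

print(k.shape)  # torch.Size([4, 8, 64])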
server/text_generation_server/utils/flash_attn.py (new file)
@@ -0,0 +1,124 @@
+import os
+import torch
+
+from loguru import logger
+
+if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
+    raise ImportError("`USE_FLASH_ATTENTION` is false.")
+
+if not torch.cuda.is_available():
+    raise ImportError("CUDA is not available")
+
+major, minor = torch.cuda.get_device_capability()
+is_sm75 = major == 7 and minor == 5
+is_sm8x = major == 8 and minor >= 0
+is_sm90 = major == 9 and minor == 0
+
+HAS_FLASH_ATTN = False
+HAS_FLASH_ATTN_V2 = False
+try:
+    try:
+        import flash_attn_2_cuda
+    except ImportError:
+        raise ImportError(
+            "Flash Attention V2 is not installed.\n"
+            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+            "or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
+        )
+    if not (is_sm8x or is_sm90):
+        raise ImportError(
+            f"GPU with CUDA capability {major} {minor} is not supported for "
+            "Flash Attention V2"
+        )
+    HAS_FLASH_ATTN_V2 = True
+except ImportError as e:
+    try:
+        import flash_attn_cuda
+    except ImportError:
+        raise ImportError(
+            "Flash Attention is not installed.\n"
+            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+            "or install flash attention with `cd server && make install install-flash-attention`"
+        ) from e
+
+    if not (is_sm75 or is_sm8x or is_sm90):
+        raise ImportError(
+            f"GPU with CUDA capability {major} {minor} is not supported"
+        ) from e
+    logger.warning(f"Unable to use Flash Attention V2: {e}")
+    HAS_FLASH_ATTN = True
+
+
+def attention(
+    q,
+    k,
+    v,
+    out,
+    cu_seqlens,
+    max_s,
+    softmax_scale,
+):
+    if HAS_FLASH_ATTN_V2:
+        return flash_attn_2_cuda.varlen_fwd(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            None,
+        )
+
+    if HAS_FLASH_ATTN:
+        # Flash attention v1 requires q, k and v to have the same number of heads
+        if k.shape[1] != q.shape[1]:
+            # MQA expand
+            if k.shape[1] == 1:
+                k = k.expand(-1, q.shape[1], -1)
+            # Grouped attention reshape
+            else:
+                original_shape = k.shape
+                k = (
+                    k.unsqueeze(2)
+                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
+                    .reshape(original_shape[0], -1, original_shape[2])
+                )
+        if v.shape[1] != q.shape[1]:
+            # MQA expand
+            if v.shape[1] == 1:
+                v = v.expand(-1, q.shape[1], -1)
+            # Grouped attention reshape
+            else:
+                original_shape = v.shape
+                v = (
+                    v.unsqueeze(2)
+                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
+                    .reshape(original_shape[0], -1, original_shape[2])
+                )
+
+        return flash_attn_cuda.fwd(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            0,
+            None,
+        )
+
+    raise NotImplementedError("flash attention is not installed")
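Note: taken together, the new module picks the best available kernel at import time (v2 on SM 8.x/9.0, v1 on SM 7.5/8.x/9.0) and gives the modeling code a single prefill attention call. A hedged end-to-end usage sketch, assuming a CUDA device and one of the kernel builds from this commit installed; the shapes follow the varlen convention of tokens packed across the batch.

# Hedged sketch: call the dispatch wrapper the way the modeling files do at prefill.
import torch
from text_generation_server.utils.flash_attn import attention

num_tokens, num_heads, head_size = 16, 8, 64
q = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)

# Two packed sequences of 10 and 6 tokens.
cu_seqlens = torch.tensor([0, 10, 16], dtype=torch.int32, device="cuda")
max_s = 10
softmax_scale = head_size ** -0.5

attention(q, k, v, out, cu_seqlens, max_s, softmax_scale)
print(out.shape)  # torch.Size([16, 8, 64])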