feat(server): flash attention v2 (#624)

This commit is contained in:
OlivierDehaene 2023-07-18 16:21:18 +02:00 committed by GitHub
parent 4d38a1c4ad
commit 3b71c38558
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 173 additions and 112 deletions

View File

@ -98,6 +98,16 @@ COPY server/Makefile-flash-att Makefile
# Build specific version of flash attention # Build specific version of flash attention
RUN make build-flash-attention RUN make build-flash-attention
# Build Flash Attention v2 CUDA kernels
FROM kernel-builder as flash-att-v2-builder
WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 Makefile
# Build specific version of flash attention v2
RUN make build-flash-attention-v2
# Build Transformers CUDA kernels # Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder FROM kernel-builder as custom-kernels-builder
@ -146,8 +156,11 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from custom kernels builder # Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy builds artifacts from vllm builder # Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

View File

@ -1,4 +1,5 @@
include Makefile-flash-att include Makefile-flash-att
include Makefile-flash-att-v2
include Makefile-vllm include Makefile-vllm
unit-tests: unit-tests:

View File

@ -0,0 +1,13 @@
flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc
flash-attention-v2:
# Clone flash attention
pip install packaging
git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
build-flash-attention-v2: flash-attention-v2
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit)
cd flash-attention-v2 && python setup.py build
install-flash-attention-v2: build-flash-attention-v2
cd flash-attention-v2 && python setup.py install

View File

@ -42,51 +42,21 @@ __all__ = [
"get_model", "get_model",
] ]
FLASH_ATT_ERROR_MESSAGE = ( FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
"{} requires CUDA and Flash Attention kernels to be installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
)
FLASH_ATTENTION = True
try: try:
if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": from text_generation_server.models.flash_rw import FlashRWSharded
if not torch.cuda.is_available(): from text_generation_server.models.flash_neox import FlashNeoXSharded
FLASH_ATT_ERROR_MESSAGE = ( from text_generation_server.models.flash_llama import (
"{} requires CUDA. No compatible CUDA devices found." FlashLlama,
)
raise ImportError("CUDA is not available")
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
supported = is_sm75 or is_sm8x or is_sm90
if not supported:
FLASH_ATT_ERROR_MESSAGE = (
"{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. "
"No compatible CUDA device found."
)
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
)
from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_neox import FlashNeoXSharded
from text_generation_server.models.flash_llama import (
FlashLlama,
)
from text_generation_server.models.flash_santacoder import (
FlashSantacoderSharded,
)
FLASH_ATTENTION = True
else:
FLASH_ATTENTION = False
except ImportError:
logger.opt(exception=True).warning(
"Could not import Flash Attention enabled models"
) )
from text_generation_server.models.flash_santacoder import (
FlashSantacoderSharded,
)
except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}")
FLASH_ATTENTION = False FLASH_ATTENTION = False
if FLASH_ATTENTION: if FLASH_ATTENTION:

View File

@ -26,13 +26,13 @@ from transformers.activations import ACT2FN
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
# Flash attention imports # Flash attention imports
import flash_attn_cuda
import dropout_layer_norm import dropout_layer_norm
# vllm imports # vllm imports
import vllm_cache_ops import vllm_cache_ops
import vllm_attention_ops import vllm_attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -164,22 +164,14 @@ class FlashLlamaAttention(torch.nn.Module):
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
flash_attn_cuda.fwd( attention(
qkv[:, 0], qkv[:, 0],
qkv[:, 1], qkv[:, 1],
qkv[:, 2], qkv[:, 2],
attn_output, attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
cu_seqlen_prefill,
max_s, max_s,
max_s,
0.0,
self.softmax_scale, self.softmax_scale,
False,
True,
False,
0,
None,
) )
# Decode # Decode
else: else:

View File

@ -27,13 +27,11 @@ from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports # vllm imports
import vllm_cache_ops import vllm_cache_ops
import vllm_attention_ops import vllm_attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -153,22 +151,14 @@ class FlashNeoxAttention(torch.nn.Module):
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
flash_attn_cuda.fwd( attention(
qkv[:, 0], qkv[:, 0],
qkv[:, 1], qkv[:, 1],
qkv[:, 2], qkv[:, 2],
attn_output, attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
cu_seqlen_prefill,
max_s, max_s,
max_s,
0.0,
self.softmax_scale, self.softmax_scale,
False,
True,
False,
0,
None,
) )
# Decode # Decode
else: else:

View File

@ -6,13 +6,11 @@ from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports # vllm imports
import vllm_cache_ops import vllm_cache_ops
import vllm_attention_ops import vllm_attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -182,27 +180,15 @@ class FlashRWAttention(torch.nn.Module):
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
if self.num_heads_kv == 1:
# Expand to query shape
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
# flash attention # flash attention
flash_attn_cuda.fwd( attention(
query, query,
torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1), torch.select(kv, dim=1, index=1),
attn_output, attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
cu_seqlen_prefill,
max_s, max_s,
max_s,
0.0,
self.softmax_scale, self.softmax_scale,
False,
True,
False,
0,
None,
) )
# Decode # Decode
else: else:
@ -314,30 +300,15 @@ class FlashRWLargeAttention(torch.nn.Module):
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# Expand to query shape
kv = (
kv.unsqueeze(2)
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)
# flash attention # flash attention
flash_attn_cuda.fwd( attention(
query, query,
torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1), torch.select(kv, dim=2, index=1),
attn_output, attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
cu_seqlen_prefill,
max_s, max_s,
max_s,
0.0,
self.softmax_scale, self.softmax_scale,
False,
True,
False,
0,
None,
) )
# Decode # Decode
else: else:

View File

@ -5,13 +5,11 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
# Flash attention imports
import flash_attn_cuda
# vllm imports # vllm imports
import vllm_cache_ops import vllm_cache_ops
import vllm_attention_ops import vllm_attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
@ -271,26 +269,15 @@ class FlashMQAttention(torch.nn.Module):
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# Expand from 1 to num_heads
key_value = key_value.expand(-1, 2, self.num_heads, self.head_size)
# flash attention # flash attention
flash_attn_cuda.fwd( attention(
query, query,
torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1), torch.select(key_value, dim=1, index=1),
attn_output, attn_output,
cu_seqlen_prefill, cu_seqlen_prefill,
cu_seqlen_prefill,
max_s, max_s,
max_s,
0.0,
self.softmax_scale, self.softmax_scale,
False,
True,
False,
0,
None,
) )
# Decode # Decode
else: else:

View File

@ -0,0 +1,124 @@
import os
import torch
from loguru import logger
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if not torch.cuda.is_available():
raise ImportError("CUDA is not available")
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False
try:
try:
import flash_attn_2_cuda
except ImportError:
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
)
if not (is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported for "
"Flash Attention V2"
)
HAS_FLASH_ATTN_V2 = True
except ImportError as e:
try:
import flash_attn_cuda
except ImportError:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
if not (is_sm75 or is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
logger.warning(f"Unable to use Flash Attention V2: {e}")
HAS_FLASH_ATTN = True
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
):
if HAS_FLASH_ATTN_V2:
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
None,
)
if HAS_FLASH_ATTN:
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
return flash_attn_cuda.fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
0,
None,
)
raise NotImplementedError("flash attention is not installed")