From 3b71c38558a9198c8424c09bbcfac03a93282d08 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Tue, 18 Jul 2023 16:21:18 +0200 Subject: [PATCH] feat(server): flash attention v2 (#624) --- Dockerfile | 15 ++- server/Makefile | 1 + server/Makefile-flash-att-v2 | 13 ++ .../text_generation_server/models/__init__.py | 54 ++------ .../custom_modeling/flash_llama_modeling.py | 12 +- .../custom_modeling/flash_neox_modeling.py | 14 +- .../custom_modeling/flash_rw_modeling.py | 35 +---- .../flash_santacoder_modeling.py | 17 +-- .../utils/flash_attn.py | 124 ++++++++++++++++++ 9 files changed, 173 insertions(+), 112 deletions(-) create mode 100644 server/Makefile-flash-att-v2 create mode 100644 server/text_generation_server/utils/flash_attn.py diff --git a/Dockerfile b/Dockerfile index 66e0091..168f2f9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -98,6 +98,16 @@ COPY server/Makefile-flash-att Makefile # Build specific version of flash attention RUN make build-flash-attention +# Build Flash Attention v2 CUDA kernels +FROM kernel-builder as flash-att-v2-builder + +WORKDIR /usr/src + +COPY server/Makefile-flash-att-v2 Makefile + +# Build specific version of flash attention v2 +RUN make build-flash-attention-v2 + # Build Transformers CUDA kernels FROM kernel-builder as custom-kernels-builder @@ -146,8 +156,11 @@ COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cp COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages +# Copy build artifacts from flash attention v2 builder +COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages + # Copy build artifacts from custom kernels builder -COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels +COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages # Copy builds artifacts from vllm builder COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages diff --git a/server/Makefile b/server/Makefile index d008692..0dc0b5c 100644 --- a/server/Makefile +++ b/server/Makefile @@ -1,4 +1,5 @@ include Makefile-flash-att +include Makefile-flash-att-v2 include Makefile-vllm unit-tests: diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2 new file mode 100644 index 0000000..a7d6335 --- /dev/null +++ b/server/Makefile-flash-att-v2 @@ -0,0 +1,13 @@ +flash_att_v2_commit := 4f285b354796fb17df8636485b9a04df3ebbb7dc + +flash-attention-v2: + # Clone flash attention + pip install packaging + git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2 + +build-flash-attention-v2: flash-attention-v2 + cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit) + cd flash-attention-v2 && python setup.py build + +install-flash-attention-v2: build-flash-attention-v2 + cd flash-attention-v2 && python setup.py install \ No newline at end of file diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index fd97f8b..ffc224c 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -42,51 +42,21 @@ 
__all__ = [ "get_model", ] -FLASH_ATT_ERROR_MESSAGE = ( - "{} requires CUDA and Flash Attention kernels to be installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" -) +FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models." +FLASH_ATTENTION = True try: - if not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": - if not torch.cuda.is_available(): - FLASH_ATT_ERROR_MESSAGE = ( - "{} requires CUDA. No compatible CUDA devices found." - ) - raise ImportError("CUDA is not available") - - major, minor = torch.cuda.get_device_capability() - is_sm75 = major == 7 and minor == 5 - is_sm8x = major == 8 and minor >= 0 - is_sm90 = major == 9 and minor == 0 - - supported = is_sm75 or is_sm8x or is_sm90 - if not supported: - FLASH_ATT_ERROR_MESSAGE = ( - "{} requires a CUDA device with capability 7.5, > 8.0 or 9.0. " - "No compatible CUDA device found." - ) - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) - - from text_generation_server.models.flash_rw import FlashRWSharded - from text_generation_server.models.flash_neox import FlashNeoXSharded - from text_generation_server.models.flash_llama import ( - FlashLlama, - ) - from text_generation_server.models.flash_santacoder import ( - FlashSantacoderSharded, - ) - - FLASH_ATTENTION = True - else: - FLASH_ATTENTION = False -except ImportError: - logger.opt(exception=True).warning( - "Could not import Flash Attention enabled models" + from text_generation_server.models.flash_rw import FlashRWSharded + from text_generation_server.models.flash_neox import FlashNeoXSharded + from text_generation_server.models.flash_llama import ( + FlashLlama, ) + from text_generation_server.models.flash_santacoder import ( + FlashSantacoderSharded, + ) + +except ImportError as e: + logger.warning(f"Could not import Flash Attention enabled models: {e}") FLASH_ATTENTION = False if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index d9f3c7b..d3c719d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -26,13 +26,13 @@ from transformers.activations import ACT2FN from typing import Optional, List, Tuple # Flash attention imports -import flash_attn_cuda import dropout_layer_norm # vllm imports import vllm_cache_ops import vllm_attention_ops +from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -164,22 +164,14 @@ class FlashLlamaAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn_cuda.fwd( + attention( qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, cu_seqlen_prefill, - cu_seqlen_prefill, max_s, - max_s, - 0.0, self.softmax_scale, - False, - True, - False, - 0, - None, ) # Decode else: diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b2dce22..e7c8ced 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -27,13 +27,11 @@ from 
transformers.modeling_utils import PreTrainedModel from transformers.models.gpt_neox import GPTNeoXConfig from typing import Optional, List, Tuple -# Flash attention imports -import flash_attn_cuda - # vllm imports import vllm_cache_ops import vllm_attention_ops +from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -153,22 +151,14 @@ class FlashNeoxAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: # flash attention - flash_attn_cuda.fwd( + attention( qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, cu_seqlen_prefill, - cu_seqlen_prefill, max_s, - max_s, - 0.0, self.softmax_scale, - False, - True, - False, - 0, - None, ) # Decode else: diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index acac274..1e9539c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -6,13 +6,11 @@ from transformers.modeling_utils import PreTrainedModel from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -# Flash attention imports -import flash_attn_cuda - # vllm imports import vllm_cache_ops import vllm_attention_ops +from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -182,27 +180,15 @@ class FlashRWAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - if self.num_heads_kv == 1: - # Expand to query shape - kv = kv.expand(-1, 2, self.num_heads, self.head_size) - # flash attention - flash_attn_cuda.fwd( + attention( query, torch.select(kv, dim=1, index=0), torch.select(kv, dim=1, index=1), attn_output, cu_seqlen_prefill, - cu_seqlen_prefill, max_s, - max_s, - 0.0, self.softmax_scale, - False, - True, - False, - 0, - None, ) # Decode else: @@ -314,30 +300,15 @@ class FlashRWLargeAttention(torch.nn.Module): # Prefill if cu_seqlen_prefill is not None: - # Expand to query shape - kv = ( - kv.unsqueeze(2) - .expand(-1, self.num_groups, self.num_heads, 2, self.head_size) - .reshape(-1, self.num_groups * self.num_heads, 2, self.head_size) - ) - # flash attention - flash_attn_cuda.fwd( + attention( query, torch.select(kv, dim=2, index=0), torch.select(kv, dim=2, index=1), attn_output, cu_seqlen_prefill, - cu_seqlen_prefill, max_s, - max_s, - 0.0, self.softmax_scale, - False, - True, - False, - 0, - None, ) # Decode else: diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index a19623a..6f5c60f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -5,13 +5,11 @@ from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple -# Flash attention imports -import flash_attn_cuda - # vllm imports import vllm_cache_ops import vllm_attention_ops +from text_generation_server.utils.flash_attn import attention from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -271,26 +269,15 @@ class FlashMQAttention(torch.nn.Module): # Prefill if 
cu_seqlen_prefill is not None: - # Expand from 1 to num_heads - key_value = key_value.expand(-1, 2, self.num_heads, self.head_size) - # flash attention - flash_attn_cuda.fwd( + attention( query, torch.select(key_value, dim=1, index=0), torch.select(key_value, dim=1, index=1), attn_output, cu_seqlen_prefill, - cu_seqlen_prefill, max_s, - max_s, - 0.0, self.softmax_scale, - False, - True, - False, - 0, - None, ) # Decode else: diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py new file mode 100644 index 0000000..c472d1f --- /dev/null +++ b/server/text_generation_server/utils/flash_attn.py @@ -0,0 +1,124 @@ +import os +import torch + +from loguru import logger + +if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": + raise ImportError("`USE_FLASH_ATTENTION` is false.") + +if not torch.cuda.is_available(): + raise ImportError("CUDA is not available") + +major, minor = torch.cuda.get_device_capability() +is_sm75 = major == 7 and minor == 5 +is_sm8x = major == 8 and minor >= 0 +is_sm90 = major == 9 and minor == 0 + +HAS_FLASH_ATTN = False +HAS_FLASH_ATTN_V2 = False +try: + try: + import flash_attn_2_cuda + except ImportError: + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention v2 with `cd server && make install install-flash-attention-v2`" + ) + if not (is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported for " + "Flash Attention V2" + ) + HAS_FLASH_ATTN_V2 = True +except ImportError as e: + try: + import flash_attn_cuda + except ImportError: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + + if not (is_sm75 or is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported" + ) from e + logger.warning(f"Unable to use Flash Attention V2: {e}") + HAS_FLASH_ATTN = True + + +def attention( + q, + k, + v, + out, + cu_seqlens, + max_s, + softmax_scale, +): + if HAS_FLASH_ATTN_V2: + return flash_attn_2_cuda.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + None, + ) + + if HAS_FLASH_ATTN: + # Flash attention v1 requires q, k and v to have the same number of heads + if k.shape[1] != q.shape[1]: + # MQA expand + if k.shape[1] == 1: + k = k.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = k.shape + k = ( + k.unsqueeze(2) + .expand(-1, -1, q.shape[1] // k.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + if v.shape[1] != q.shape[1]: + # MQA expand + if v.shape[1] == 1: + v = v.expand(-1, q.shape[1], -1) + # Grouped attention reshape + else: + original_shape = v.shape + v = ( + v.unsqueeze(2) + .expand(-1, -1, q.shape[1] // v.shape[1], -1) + .reshape(original_shape[0], -1, original_shape[2]) + ) + + return flash_attn_cuda.fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + 0, + None, + ) + + raise NotImplementedError("flash attention is not installed")
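
The new attention() helper in server/text_generation_server/utils/flash_attn.py gives every flash model a single prefill entry point and hides the v1/v2 kernel dispatch behind HAS_FLASH_ATTN / HAS_FLASH_ATTN_V2. A minimal usage sketch of that call path, assuming a supported CUDA GPU with either flash_attn_cuda or flash_attn_2_cuda installed; the batch layout (three sequences of lengths 2, 3 and 4) and head dimensions are illustrative only:

    import torch

    from text_generation_server.utils.flash_attn import attention

    device = torch.device("cuda")
    total_tokens, num_heads, head_size = 9, 8, 64

    # Packed (varlen) layout: all tokens of the batch concatenated along dim 0
    q = torch.randn(total_tokens, num_heads, head_size, dtype=torch.float16, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = torch.empty_like(q)

    # Cumulative sequence lengths for three sequences of lengths 2, 3 and 4
    cu_seqlens = torch.tensor([0, 2, 5, 9], dtype=torch.int32, device=device)
    max_s = 4  # longest sequence in the batch
    softmax_scale = head_size ** -0.5

    # Dispatches to flash_attn_2_cuda.varlen_fwd on SM 8.x/9.0 GPUs, otherwise
    # falls back to flash_attn_cuda.fwd (expanding MQA/GQA heads to the query
    # head count first, since v1 requires matching head counts)
    attention(q, k, v, out, cu_seqlens, max_s, softmax_scale)

    # `out` now holds the causal attention output that the prefill branches in
    # the flash_*_modeling.py files consume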