import os
import math

import torch
from loguru import logger

from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)
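
# Flash attention can be disabled explicitly by setting `USE_FLASH_ATTENTION=false`.
# Otherwise, support is probed at import time and recorded in the `HAS_FLASH_ATTN*`
# flags below, which `attention` dispatches on at call time.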
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

HAS_FLASH_ATTN = True
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
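
# On Intel XPU systems flash attention is provided by intel_extension_for_pytorch
# (ipex) rather than by the flash-attn CUDA/ROCm extensions.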
if IS_XPU_SYSTEM:
    import intel_extension_for_pytorch as ipex

if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
    if not torch.cuda.is_available():
        raise ImportError("CUDA is not available")

    major, minor = torch.cuda.get_device_capability()
    is_sm75 = major == 7 and minor == 5
    is_sm8x = major == 8 and minor >= 0
    is_sm90 = major == 9 and minor == 0

    HAS_FLASH_ATTN = False
    HAS_FLASH_ATTN_V2_CUDA = False
    HAS_FLASH_ATTN_V2_ROCM = False
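    # Prefer the flash attention v2 kernels; if they cannot be imported or the GPU
    # does not support them, fall back to flash attention v1 and log a warning.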
    try:
        try:
            import flash_attn_2_cuda
        except ImportError:
            architecture_suffix = ""
            if IS_CUDA_SYSTEM:
                architecture_suffix = "-cuda"
            elif IS_ROCM_SYSTEM:
                architecture_suffix = "-rocm"
            raise ImportError(
                "Flash Attention V2 is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
            )
        if not (is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported for "
                "Flash Attention V2"
            )
        HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
        HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
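    # Flash attention v2 is unavailable or unsupported here; fall back to v1.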
    except ImportError as e:
        try:
            import flash_attn_cuda
        except ImportError:
            raise ImportError(
                "Flash Attention is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                "or install flash attention with `cd server && make install install-flash-attention`"
            ) from e

        if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            ) from e
        elif IS_ROCM_SYSTEM:
            for idx in range(torch.cuda.device_count()):
                if "MI210" not in torch.cuda.get_device_name(
                    idx
                ) and "MI250" not in torch.cuda.get_device_name(idx):
                    raise ImportError(
                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
                    )

        logger.warning(f"Unable to use Flash Attention V2: {e}")
        HAS_FLASH_ATTN = True


def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
):
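    """Run variable-length (packed) attention with whichever flash attention
    backend this system supports: ipex on XPU, flash attention v2 on CUDA/ROCm,
    or flash attention v1 as a fallback.

    `q`, `k`, `v` and `out` are expected to be packed `(total_tokens, num_heads,
    head_size)` tensors (`k`/`v` may carry fewer heads than `q`; the v1 fallback
    expands them). `cu_seqlens` holds the cumulative sequence lengths of the
    batch, `max_s` is the longest sequence length, `softmax_scale` scales the
    attention scores, and `window_size_left` enables sliding-window attention
    (-1 disables it; only the CUDA flash attention v2 path supports it).

    Illustrative single-sequence call (a sketch; `seq_len` and `head_size` are
    placeholders, not values checked by this function):

        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
        out = torch.empty_like(q)
        attention(q, k, v, out, cu_seqlens, seq_len, softmax_scale=head_size**-0.5)
    """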
    if window_size_left <= 0 and window_size_left != -1:
        raise ValueError("`window_size_left` must be > 0 or -1")

    if IS_XPU_SYSTEM:
        if window_size_left != -1:
            raise ValueError(
                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )
        return ipex.llm.functional.varlen_attention(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            None,
        )

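    # CUDA / ROCm: dispatch to flash attention v2 when it was importable at module
    # load time, otherwise fall back to the v1 kernels.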
    if HAS_FLASH_ATTN_V2_CUDA:
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            None,
            None,
            None,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            window_size_left,
            0,
            False,
            None,
        )
    elif HAS_FLASH_ATTN_V2_ROCM:
        if window_size_left != -1:
            raise ValueError(
                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )

        # RoCm flash API does not take the window_size_left and window_size_right arguments.
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            None,
        )
    elif HAS_FLASH_ATTN:
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )

        # Flash attention v1 requires q, k and v to have the same number of heads
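        # (e.g. with 32 query heads and 8 KV heads, each KV head is repeated 4 times
        # below so that k and v end up with 32 heads like q.)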
        if k.shape[1] != q.shape[1]:
            # MQA expand
            if k.shape[1] == 1:
                k = k.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = k.shape
                k = (
                    k.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )
        if v.shape[1] != q.shape[1]:
            # MQA expand
            if v.shape[1] == 1:
                v = v.expand(-1, q.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = v.shape
                v = (
                    v.unsqueeze(2)
                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )

        return flash_attn_cuda.fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,
            cu_seqlens,
            max_s,
            max_s,
            0.0,
            softmax_scale,
            False,
            True,
            False,
            0,
            None,
        )

    raise NotImplementedError("flash attention is not installed")