hf_text-generation-inference/server/text_generation_server/utils/flash_attn.py

176 lines
5.2 KiB
Python
Raw Permalink Normal View History

import os
import torch
from loguru import logger
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
# Allow operators to opt out of flash attention entirely via the environment.
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
    raise ImportError("`USE_FLASH_ATTENTION` is false.")

if not torch.cuda.is_available():
    raise ImportError("CUDA is not available")

# Compute capability of the current device; used below to decide which
# flash attention build may be loaded.
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8  # any 8.x minor version
is_sm90 = major == 9 and minor == 0

# Backend flags consulted by `attention()` to pick a kernel.
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False

try:
    # Prefer Flash Attention v2. Any ImportError raised in this block (package
    # missing or GPU capability unsupported) is caught by the v1 fallback below.
    try:
        import flash_attn_2_cuda
    except ImportError as err:
        architecture_suffix = ""
        if IS_CUDA_SYSTEM:
            architecture_suffix = "-cuda"
        elif IS_ROCM_SYSTEM:
            architecture_suffix = "-rocm"
        raise ImportError(
            "Flash Attention V2 is not installed.\n"
            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
            f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
        ) from err
    if not (is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported for "
            "Flash Attention V2"
        )
    # NOTE(review): assumes IS_CUDA_SYSTEM and IS_ROCM_SYSTEM are mutually
    # exclusive — confirm in text_generation_server.utils.import_utils.
    HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
    HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
except ImportError as e:
    # v2 is unusable: fall back to Flash Attention v1, or fail the module import.
    try:
        import flash_attn_cuda
    except ImportError:
        raise ImportError(
            "Flash Attention is not installed.\n"
            "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
            "or install flash attention with `cd server && make install install-flash-attention`"
        ) from e

    if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
        raise ImportError(
            f"GPU with CUDA capability {major} {minor} is not supported"
        ) from e
    elif IS_ROCM_SYSTEM:
        # On ROCm, v1 is only accepted when every visible device is an MI210
        # or MI250 (matched by substring of the reported device name).
        for idx in range(torch.cuda.device_count()):
            device_name = torch.cuda.get_device_name(idx)
            if "MI210" not in device_name and "MI250" not in device_name:
                raise ImportError(
                    f"AMD GPU {device_name} does not support flash-attention"
                ) from e

    logger.warning(f"Unable to use Flash Attention V2: {e}")
    HAS_FLASH_ATTN = True
def _expand_kv_heads(kv, num_heads):
    """Repeat the heads of a key or value tensor so it has ``num_heads`` heads.

    Flash attention v1 requires q, k and v to have the same number of heads.
    Dimension 1 is the head axis and the tensor is 3-D (presumably
    ``(total_tokens, heads, head_dim)`` — confirm against callers).

    Args:
        kv: key or value tensor with ``kv.shape[1]`` heads.
        num_heads: number of query heads to match.

    Returns:
        ``kv`` itself if it already has ``num_heads`` heads. Otherwise, a tensor
        whose head ``i`` is ``kv`` head ``i // (num_heads // kv.shape[1])``.

    Raises:
        ValueError: if ``num_heads`` is not a multiple of the kv head count.
    """
    kv_heads = kv.shape[1]
    if kv_heads == num_heads:
        return kv
    if kv_heads == 1:
        # Multi-query attention: broadcast the single head. `expand` returns a
        # view, so no data is copied.
        return kv.expand(-1, num_heads, -1)
    if num_heads % kv_heads != 0:
        # Without this check the integer division below would silently build a
        # tensor with the wrong head count, failing later inside the kernel.
        raise ValueError(
            f"number of query heads ({num_heads}) must be a multiple of the "
            f"number of key/value heads ({kv_heads})"
        )
    # Grouped-query attention: repeat each kv head `group` times back to back,
    # so query head i is paired with kv head i // group.
    group = num_heads // kv_heads
    return (
        kv.unsqueeze(2)
        .expand(-1, -1, group, -1)
        .reshape(kv.shape[0], -1, kv.shape[2])
    )


def attention(
    q,
    k,
    v,
    out,
    cu_seqlens,
    max_s,
    softmax_scale,
    window_size_left=-1,
):
    """Run variable-length flash attention with the backend chosen at import time.

    The kernel is selected by the module-level flags, in this order: Flash
    Attention v2 (CUDA build), Flash Attention v2 (ROCm build), then Flash
    Attention v1. The same ``cu_seqlens`` and ``max_s`` are used for both
    queries and keys, and the causal flag is passed as ``True``.

    Args:
        q, k, v: query/key/value tensors; dimension 1 is the head axis.
        out: output tensor handed to the kernel (presumably filled in place —
            confirm against the flash-attn API).
        cu_seqlens: sequence offsets, passed as both cu_seqlens_q and cu_seqlens_k.
        max_s: maximum sequence length, passed for both q and k.
        softmax_scale: scale passed to the kernel's softmax.
        window_size_left: sliding-window size. -1 (the default) disables
            windowing; any other value must be > 0 and is only supported by
            the CUDA build of v2.

    Returns:
        Whatever the selected flash-attn kernel returns.

    Raises:
        ValueError: if ``window_size_left`` is neither -1 nor positive, if the
            ROCm v2 backend is given a window, or if v1 needs to expand kv
            heads that do not evenly divide the query heads.
        NotImplementedError: if v1 is given a window, or no backend is available.
    """
    if window_size_left <= 0 and window_size_left != -1:
        raise ValueError("`window_size_left` must be > 0 or -1")

    # Positional argument names in the comments below follow the upstream
    # flash-attn C++ signatures; re-check them when upgrading flash-attn.
    if HAS_FLASH_ATTN_V2_CUDA:
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,  # cu_seqlens_q
            cu_seqlens,  # cu_seqlens_k
            max_s,  # max_seqlen_q
            max_s,  # max_seqlen_k
            0.0,  # dropout probability
            softmax_scale,
            False,  # zero_tensors
            True,  # is_causal
            window_size_left,
            0,  # window_size_right
            False,  # return_softmax
            None,  # RNG generator
        )

    if HAS_FLASH_ATTN_V2_ROCM:
        if window_size_left != -1:
            raise ValueError(
                f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
            )
        # RoCm flash API does not take the window_size_left and window_size_right arguments.
        return flash_attn_2_cuda.varlen_fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,  # cu_seqlens_q
            cu_seqlens,  # cu_seqlens_k
            max_s,  # max_seqlen_q
            max_s,  # max_seqlen_k
            0.0,  # dropout probability
            softmax_scale,
            False,  # zero_tensors
            True,  # is_causal
            False,  # return_softmax
            None,  # RNG generator
        )

    if HAS_FLASH_ATTN:
        # Flash attention v1 has no sliding-window support. -1 is the "no
        # window" default, which callers such as flash_llama_modeling.py leave
        # unchanged, so only other values are rejected (fixed in PR #1089;
        # checking `!= 0` previously made v1 always fail).
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )
        # Flash attention v1 requires q, k and v to have the same number of heads.
        num_heads = q.shape[1]
        k = _expand_kv_heads(k, num_heads)
        v = _expand_kv_heads(v, num_heads)
        return flash_attn_cuda.fwd(
            q,
            k,
            v,
            out,
            cu_seqlens,  # cu_seqlens_q
            cu_seqlens,  # cu_seqlens_k
            max_s,  # max_seqlen_q
            max_s,  # max_seqlen_k
            0.0,  # dropout probability
            softmax_scale,
            False,  # zero_tensors
            True,  # is_causal
            False,  # return_softmax
            0,  # num_splits (0 = kernel heuristic)
            None,  # RNG generator
        )

    raise NotImplementedError("flash attention is not installed")