2024-05-31 09:57:01 -06:00
|
|
|
import torch
|
2024-10-17 02:42:52 -06:00
|
|
|
from text_generation_server.layers.attention.kv_cache import KVCache
|
2024-05-31 09:57:01 -06:00
|
|
|
from text_generation_server.utils.import_utils import SYSTEM
|
2024-08-09 03:42:00 -06:00
|
|
|
from text_generation_server.models.globals import (
|
2024-08-09 08:41:17 -06:00
|
|
|
ATTENTION,
|
2024-08-09 03:42:00 -06:00
|
|
|
BLOCK_SIZE,
|
|
|
|
)
|
2024-07-01 15:28:00 -06:00
|
|
|
from text_generation_server.layers.attention import Seqlen
|
2024-07-22 10:27:10 -06:00
|
|
|
from typing import Optional
|
2024-05-31 09:57:01 -06:00
|
|
|
|
|
|
|
# CUDA compute capability of the current device, e.g. (8, 0) for A100.
# Queried once at import time; this module is only loaded on CUDA systems.
major, minor = torch.cuda.get_device_capability()
# Turing sm75 (e.g. T4) supports flash-attention v1 only; used below to
# pick the right error message when neither extension is installed.
is_sm75 = major == 7 and minor == 5
# Sequence-partition size for the vLLM paged-attention V2 kernel; must
# evenly divide into BLOCK_SIZE-aligned chunks (asserted in paged_attention).
_PARTITION_SIZE = 512
|
|
|
|
|
|
|
|
|
|
|
|
def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    max_s: int,
    softcap: Optional[float] = None,
):
    """Decode-time attention over a paged KV cache.

    Dispatches on the global ``ATTENTION`` backend:

    * ``"flashinfer"``: forwards to the flashinfer decode state.
    * ``"flashdecoding"``: calls ``flash_attn_2_cuda.varlen_fwd`` with the
      paged cache and block tables.
    * otherwise: calls the vLLM paged-attention V1/V2 CUDA kernels
      (softcapping unsupported on this path).

    Args:
        query: decode queries; assumed shape (num_seqs, num_heads, head_size)
            per the unpacking below.
        kv_cache: paged key/value cache.
        kv_head_mapping: query-head -> kv-head mapping for the vLLM kernels.
        softmax_scale: scaling applied to QK^T inside the kernel.
        block_tables: per-sequence tables of cache block indices.
        seqlen: sequence-length metadata (cumulative lengths, cache lengths).
        max_s: maximum total sequence length in the batch.
        softcap: optional logit soft-capping value.

    Returns:
        Attention output with the same shape as ``query``.

    Raises:
        RuntimeError: if ``softcap`` is set on the vLLM paged path.
    """
    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
    # Copyright 2023 The vLLM team. All rights
    # reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    #

    # value_cache => [num_blocks, num_heads, head_size, block_size]
    # block_size = value_cache.shape[3]
    block_size = BLOCK_SIZE
    num_seqs, num_heads, head_size = query.shape
    # Ceil-divide: number of _PARTITION_SIZE chunks needed to cover max_s.
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE

    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    if ATTENTION == "flashinfer":
        # Imported lazily: flashinfer is an optional dependency.
        from text_generation_server.layers.attention.flashinfer import decode_state

        return decode_state.get().forward(
            # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
            query.contiguous(),
            paged_kv_cache=(kv_cache.key, kv_cache.value),
            logits_soft_cap=softcap,
            sm_scale=softmax_scale,
        )
    elif ATTENTION == "flashdecoding":
        # Decode: one new query token per sequence; keys span the full cache.
        max_q = 1
        max_k = max_s
        import flash_attn_2_cuda

        # TODO fixme when flash contains the fix.
        # Number of splits is not correctly handled
        # by the current path
        # https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
        # This fails because we're using causal, therefore window_right is set to 0 and the split logic is never applied.
        if softcap is None:
            # The CUDA kernel treats 0.0 as "no softcapping".
            softcap = 0.0
        # NOTE: varlen_fwd takes positional arguments; the order below must
        # match the extension's signature exactly.
        out = flash_attn_2_cuda.varlen_fwd(
            query,
            kv_cache.key,
            kv_cache.value,
            None,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
            None,  # pad_k
            None,
            block_tables,
            None,
            max_q,
            max_k,
            0.0,  # dropout
            softmax_scale,
            False,  # zero_tensors
            True,  # causal
            -1,  # Window_left
            -1,  # Window right
            softcap,
            False,  # return softmax
            None,  # generator
        )
        # varlen_fwd returns a tuple; the attention output is element 0.
        return out[0]
    else:
        if softcap is not None:
            raise RuntimeError("Paged attention doesn't support softcapping")
        # Total sequence length per batch entry = new tokens + cached tokens.
        input_lengths = seqlen.input_lengths + seqlen.cache_lengths
        from vllm._C import ops

        out = torch.empty_like(query)

        # Heuristic described in the NOTE(woosuk) comment above.
        use_v1 = max_s <= 8192 and (
            max_num_partitions == 1 or num_seqs * num_heads > 512
        )
        if use_v1:
            ops.paged_attention_v1(
                out,
                query,
                kv_cache.key,
                kv_cache.value,
                kv_head_mapping,
                softmax_scale,
                block_tables,
                input_lengths,
                block_size,
                max_s,
                None,
                "auto",
                1.0,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            # Scratch buffers for the two-pass (partition + reduce) kernel.
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=out.dtype,
                device=out.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=out.device,
            )
            max_logits = torch.empty_like(exp_sums)

            ops.paged_attention_v2(
                out,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                kv_cache.key,
                kv_cache.value,
                kv_head_mapping,
                softmax_scale,
                block_tables,
                input_lengths,
                block_size,
                max_s,
                None,
                "auto",
                1.0,
            )
        return out
|
2024-05-31 09:57:01 -06:00
|
|
|
|
|
|
|
|
|
|
|
# Detect which flash-attention CUDA extension is installed. V2 requires an
# Ampere (sm80) or newer GPU; V1 is the fallback for older architectures.
# Sets the module-level flag `V2` used by `attention` below.
try:
    # Compute-capability minors are never negative, so the original
    # `major >= 8 and minor >= 0` reduces to a plain major check.
    is_ampere_or_newer = major >= 8
    if not is_ampere_or_newer:
        raise ImportError("FlashAttention only supports Ampere GPUs or newer.")

    import flash_attn_2_cuda

    V2 = True
except ImportError:
    try:
        import flash_attn_cuda

        V2 = False
    except ImportError as e:
        # Neither extension is available: raise the most helpful message we
        # can, based on the GPU architecture.
        if major >= 8:
            architecture_suffix = f"-{SYSTEM}"
            raise ImportError(
                "Flash Attention V2 is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
            ) from e
        elif is_sm75:
            raise ImportError(
                "Flash Attention is not installed.\n"
                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                "or install flash attention with `cd server && make install install-flash-attention`"
            ) from e
        else:
            raise ImportError(
                f"GPU with CUDA capability {major} {minor} is not supported"
            ) from e
|
|
|
|
|
|
|
|
|
2024-10-17 02:42:52 -06:00
|
|
|
# The flashdecoding decode path calls flash_attn_2_cuda directly, so it is
# unusable when only flash-attn v1 is installed: fail fast at import time.
if ATTENTION == "flashdecoding" and not V2:
    raise ValueError("Flash decoding requires Flash Attention V2")

# Sliding-window attention (window_size_left) is only wired through the
# v2 kernel; the v1 path in `attention` raises NotImplementedError for it.
SUPPORTS_WINDOWING = V2
|
2024-08-01 09:03:28 -06:00
|
|
|
|
2024-10-17 02:42:52 -06:00
|
|
|
|
|
|
|
def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    """Prefill attention, dispatched on the configured backend.

    * ``ATTENTION == "flashinfer"``: flashinfer prefill over the paged cache.
    * flash-attn V2 installed: ``flash_attn_2_cuda.varlen_fwd`` (with the
      paged cache when flashdecoding, with raw K/V otherwise).
    * flash-attn V1 fallback: ``flash_attn_cuda.fwd`` after expanding K/V to
      the query head count (v1 has no native MQA/GQA support).

    Args (all keyword-only):
        query, key, value: packed varlen tensors; head dim is axis 1 per the
            ``shape[1]`` checks below.
        kv_cache: paged key/value cache (used by flashinfer/flashdecoding).
        seqlen: cumulative/max sequence-length metadata.
        block_tables: per-sequence cache block tables (flashdecoding only).
        softmax_scale: scaling applied to QK^T inside the kernel.
        window_size_left: sliding-window size; -1 disables windowing.
        causal: whether to apply a causal mask.
        softcap: optional logit soft-capping value.

    Returns:
        Attention output with the same shape as ``query``.

    Raises:
        ValueError: invalid ``window_size_left`` on the v2 path.
        NotImplementedError: windowing or softcap requested on the v1 path.
    """
    if ATTENTION == "flashinfer":
        # Imported lazily: flashinfer is an optional dependency.
        from text_generation_server.layers.attention.flashinfer import (
            prefill_with_paged_kv_state,
        )

        if softcap is None:
            # 0.0 disables softcapping in the kernel.
            softcap = 0.0

        return prefill_with_paged_kv_state.get().forward(
            # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
            query.contiguous(),
            causal=causal,
            paged_kv_cache=(kv_cache.key, kv_cache.value),
            logits_soft_cap=softcap,
            sm_scale=softmax_scale,
            window_left=window_size_left,
        )

    # If we are using flashdecoding or paged, we always use flash-attn for
    # the prefill. We have to branch on whether we use flash-attn v1 or v2.
    elif V2:
        out = torch.empty_like(query)
        # Only strictly-positive windows or the -1 sentinel are valid.
        if window_size_left <= 0 and window_size_left != -1:
            raise ValueError("`window_size_left` must be > 0 or -1")

        if softcap is None:
            softcap = 0.0

        # NOTE: varlen_fwd takes positional arguments; the order below must
        # match the extension's signature exactly.
        return flash_attn_2_cuda.varlen_fwd(
            query,
            # flashdecoding: pass the KV caches, paged: pass the KV.
            kv_cache.key if ATTENTION == "flashdecoding" else key,
            kv_cache.value if ATTENTION == "flashdecoding" else value,
            out,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_k,
            None,
            None,
            block_tables if ATTENTION == "flashdecoding" else None,
            None,
            seqlen.max_q,
            seqlen.max_k,
            0.0,  # dropout
            softmax_scale,
            False,  # zero_tensors
            causal,
            window_size_left,
            0,  # window_right
            softcap,
            False,  # return softmax
            None,  # generator
        )[0]

    else:
        if window_size_left != -1:
            raise NotImplementedError(
                "window_size_left is only available with flash attn v2"
            )
        if softcap is not None:
            raise NotImplementedError("softcap is not available in flash attn v1")

        # Flash attention v1 requires q, k and v to have the same number of heads
        if key.shape[1] != query.shape[1]:
            # MQA expand
            if key.shape[1] == 1:
                key = key.expand(-1, query.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = key.shape
                key = (
                    key.unsqueeze(2)
                    .expand(-1, -1, query.shape[1] // key.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )
        if value.shape[1] != query.shape[1]:
            # MQA expand
            if value.shape[1] == 1:
                value = value.expand(-1, query.shape[1], -1)
            # Grouped attention reshape
            else:
                original_shape = value.shape
                value = (
                    value.unsqueeze(2)
                    .expand(-1, -1, query.shape[1] // value.shape[1], -1)
                    .reshape(original_shape[0], -1, original_shape[2])
                )

        out = torch.empty_like(query)
        # NOTE(review): cu_seqlen_q is passed for both q and k below —
        # presumably intentional for prefill without a cache (q and k lengths
        # coincide), but worth confirming against the v1 extension signature.
        flash_attn_cuda.fwd(
            query,
            key,
            value,
            out,
            seqlen.cu_seqlen_q,
            seqlen.cu_seqlen_q,
            seqlen.max_q,
            seqlen.max_k,
            0.0,  # dropout
            softmax_scale,
            False,  # zero_tensors
            causal,
            False,  # return softmax
            0,  # num_splits
            None,  # generator
        )
        return out
|
2024-09-27 08:19:42 -06:00
|
|
|
|
2024-10-04 09:51:48 -06:00
|
|
|
|
|
|
|
# Public API of this backend module.
__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
]
|