Purely refactors paged/attention into `layers/attention` and make hardware differences more obvious with 1 file per hardware. (#1986)
# What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
This commit is contained in:
parent
659bd67fec
commit
06edde9491
|
@ -70,7 +70,6 @@ impl Infer {
|
||||||
tokenizer_config: HubTokenizerConfig,
|
tokenizer_config: HubTokenizerConfig,
|
||||||
processor_config: HubProcessorConfig,
|
processor_config: HubProcessorConfig,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Infer shared state
|
|
||||||
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
let queue = Queue::new(requires_padding, 16, window_size, speculate);
|
||||||
let shared = Arc::new(Shared {
|
let shared = Arc::new(Shared {
|
||||||
batching_task: Notify::new(),
|
batching_task: Notify::new(),
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
|
flash_att_v2_commit_cuda := v2.5.8
|
||||||
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
|
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
|
||||||
|
|
||||||
|
|
||||||
flash-attention-v2-cuda:
|
flash-attention-v2-cuda:
|
||||||
# Clone flash attention
|
# Clone flash attention
|
||||||
pip install -U packaging ninja --no-cache-dir
|
pip install -U packaging ninja --no-cache-dir
|
||||||
git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
|
git clone https://github.com/Dao-AILab/flash-attention.git flash-attention-v2
|
||||||
|
|
||||||
build-flash-attention-v2-cuda: flash-attention-v2-cuda
|
build-flash-attention-v2-cuda: flash-attention-v2-cuda
|
||||||
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda)
|
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda)
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
import os
|
||||||
|
|
||||||
|
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||||
|
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||||
|
if SYSTEM == "cuda":
|
||||||
|
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||||
|
elif SYSTEM == "rocm":
|
||||||
|
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||||
|
elif SYSTEM == "xpu":
|
||||||
|
from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||||
|
else:
|
||||||
|
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
|
|
@ -0,0 +1,245 @@
|
||||||
|
import torch
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
|
major, minor = torch.cuda.get_device_capability()
|
||||||
|
is_sm75 = major == 7 and minor == 5
|
||||||
|
_PARTITION_SIZE = 512
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm._C import cache_ops
|
||||||
|
from vllm._C import ops
|
||||||
|
except Exception as e:
|
||||||
|
raise ImportError(
|
||||||
|
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_and_cache(
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
):
|
||||||
|
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def paged_attention(
|
||||||
|
out: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
kv_head_mapping: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
max_s: int,
|
||||||
|
):
|
||||||
|
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||||
|
# Copyright 2023 The vLLM team. All rights
|
||||||
|
# reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
||||||
|
block_size = value_cache.shape[3]
|
||||||
|
num_seqs, num_heads, head_size = query.shape
|
||||||
|
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||||
|
|
||||||
|
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||||
|
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||||
|
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||||
|
# sequences or heads is large, we use V1 since there is enough work
|
||||||
|
# to parallelize.
|
||||||
|
from vllm._C import ops
|
||||||
|
|
||||||
|
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
||||||
|
if use_v1:
|
||||||
|
ops.paged_attention_v1(
|
||||||
|
out,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
kv_head_mapping,
|
||||||
|
softmax_scale,
|
||||||
|
block_tables,
|
||||||
|
input_lengths,
|
||||||
|
block_size,
|
||||||
|
max_s,
|
||||||
|
None,
|
||||||
|
"auto",
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Run PagedAttention V2.
|
||||||
|
assert _PARTITION_SIZE % block_size == 0
|
||||||
|
tmp_output = torch.empty(
|
||||||
|
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||||
|
dtype=out.dtype,
|
||||||
|
device=out.device,
|
||||||
|
)
|
||||||
|
exp_sums = torch.empty(
|
||||||
|
size=(num_seqs, num_heads, max_num_partitions),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=out.device,
|
||||||
|
)
|
||||||
|
max_logits = torch.empty_like(exp_sums)
|
||||||
|
|
||||||
|
ops.paged_attention_v2(
|
||||||
|
out,
|
||||||
|
exp_sums,
|
||||||
|
max_logits,
|
||||||
|
tmp_output,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
kv_head_mapping,
|
||||||
|
softmax_scale,
|
||||||
|
block_tables,
|
||||||
|
input_lengths,
|
||||||
|
block_size,
|
||||||
|
max_s,
|
||||||
|
None,
|
||||||
|
"auto",
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import flash_attn_2_cuda
|
||||||
|
|
||||||
|
V2 = True
|
||||||
|
except ImportError:
|
||||||
|
try:
|
||||||
|
import flash_attn_cuda
|
||||||
|
|
||||||
|
V2 = False
|
||||||
|
except ImportError as e:
|
||||||
|
if major >= 8:
|
||||||
|
architecture_suffix = f"-{SYSTEM}"
|
||||||
|
raise ImportError(
|
||||||
|
"Flash Attention V2 is not installed.\n"
|
||||||
|
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||||
|
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
|
||||||
|
)
|
||||||
|
elif is_sm75:
|
||||||
|
raise ImportError(
|
||||||
|
"Flash Attention is not installed.\n"
|
||||||
|
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||||
|
"or install flash attention with `cd server && make install install-flash-attention`"
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
raise ImportError(
|
||||||
|
f"GPU with CUDA capability {major} {minor} is not supported"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
SUPPORTS_WINDOWING = V2
|
||||||
|
if V2:
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
|
causal=True,
|
||||||
|
):
|
||||||
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
causal,
|
||||||
|
window_size_left,
|
||||||
|
0,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
|
):
|
||||||
|
if window_size_left != -1:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"window_size_left is only available with flash attn v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||||
|
if k.shape[1] != q.shape[1]:
|
||||||
|
# MQA expand
|
||||||
|
if k.shape[1] == 1:
|
||||||
|
k = k.expand(-1, q.shape[1], -1)
|
||||||
|
# Grouped attention reshape
|
||||||
|
else:
|
||||||
|
original_shape = k.shape
|
||||||
|
k = (
|
||||||
|
k.unsqueeze(2)
|
||||||
|
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
||||||
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
|
)
|
||||||
|
if v.shape[1] != q.shape[1]:
|
||||||
|
# MQA expand
|
||||||
|
if v.shape[1] == 1:
|
||||||
|
v = v.expand(-1, q.shape[1], -1)
|
||||||
|
# Grouped attention reshape
|
||||||
|
else:
|
||||||
|
original_shape = v.shape
|
||||||
|
v = (
|
||||||
|
v.unsqueeze(2)
|
||||||
|
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
||||||
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
|
)
|
||||||
|
|
||||||
|
return flash_attn_cuda.fwd(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
)
|
|
@ -0,0 +1,295 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
major, minor = torch.cuda.get_device_capability()
|
||||||
|
is_sm75 = major == 7 and minor == 5
|
||||||
|
_PARTITION_SIZE = 512
|
||||||
|
|
||||||
|
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
|
||||||
|
ENGINE = "triton" if use_triton else "ck"
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm._C import cache_ops
|
||||||
|
from vllm._C import ops
|
||||||
|
except Exception as e:
|
||||||
|
raise ImportError(
|
||||||
|
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_and_cache(
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
):
|
||||||
|
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def paged_attention(
|
||||||
|
out: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
kv_head_mapping: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
max_s: int,
|
||||||
|
):
|
||||||
|
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||||
|
# Copyright 2023 The vLLM team. All rights
|
||||||
|
# reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
||||||
|
block_size = value_cache.shape[3]
|
||||||
|
num_seqs, num_heads, head_size = query.shape
|
||||||
|
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||||
|
|
||||||
|
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||||
|
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||||
|
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||||
|
# sequences or heads is large, we use V1 since there is enough work
|
||||||
|
# to parallelize.
|
||||||
|
from vllm._C import ops
|
||||||
|
|
||||||
|
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
||||||
|
if use_v1:
|
||||||
|
ops.paged_attention_v1(
|
||||||
|
out,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
kv_head_mapping,
|
||||||
|
softmax_scale,
|
||||||
|
block_tables,
|
||||||
|
input_lengths,
|
||||||
|
block_size,
|
||||||
|
max_s,
|
||||||
|
None,
|
||||||
|
"auto",
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Run PagedAttention V2.
|
||||||
|
assert _PARTITION_SIZE % block_size == 0
|
||||||
|
tmp_output = torch.empty(
|
||||||
|
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||||
|
dtype=out.dtype,
|
||||||
|
device=out.device,
|
||||||
|
)
|
||||||
|
exp_sums = torch.empty(
|
||||||
|
size=(num_seqs, num_heads, max_num_partitions),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=out.device,
|
||||||
|
)
|
||||||
|
max_logits = torch.empty_like(exp_sums)
|
||||||
|
|
||||||
|
ops.paged_attention_v2(
|
||||||
|
out,
|
||||||
|
exp_sums,
|
||||||
|
max_logits,
|
||||||
|
tmp_output,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
kv_head_mapping,
|
||||||
|
softmax_scale,
|
||||||
|
block_tables,
|
||||||
|
input_lengths,
|
||||||
|
block_size,
|
||||||
|
max_s,
|
||||||
|
None,
|
||||||
|
"auto",
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if ENGINE != "triton":
|
||||||
|
try:
|
||||||
|
import flash_attn_2_cuda
|
||||||
|
|
||||||
|
logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
|
||||||
|
except ImportError:
|
||||||
|
try:
|
||||||
|
import flash_attn_cuda
|
||||||
|
|
||||||
|
ENGINE = "v1"
|
||||||
|
logger.info("ROCm: using Flash Attention 1")
|
||||||
|
except ImportError as e:
|
||||||
|
if major >= 8:
|
||||||
|
architecture_suffix = f"-{SYSTEM}"
|
||||||
|
raise ImportError(
|
||||||
|
"Flash Attention V2 is not installed.\n"
|
||||||
|
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||||
|
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
|
||||||
|
)
|
||||||
|
elif is_sm75:
|
||||||
|
raise ImportError(
|
||||||
|
"Flash Attention is not installed.\n"
|
||||||
|
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||||
|
"or install flash attention with `cd server && make install install-flash-attention`"
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
|
||||||
|
for idx in range(torch.cuda.device_count()):
|
||||||
|
name = torch.cuda.get_device_name(idx)
|
||||||
|
if "MI210" not in name and "MI250" not in name:
|
||||||
|
raise ImportError(
|
||||||
|
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
|
||||||
|
)
|
||||||
|
raise ImportError(
|
||||||
|
f"AMD GPU with ROCm capability {major} {minor} is not supported"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
SUPPORTS_WINDOWING = ENGINE != "v1"
|
||||||
|
if ENGINE == "ck":
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
|
causal=True,
|
||||||
|
):
|
||||||
|
if window_size_left <= 0 and window_size_left != -1:
|
||||||
|
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||||
|
if window_size_left != -1:
|
||||||
|
raise ValueError(
|
||||||
|
f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
||||||
|
)
|
||||||
|
return flash_attn_2_cuda.varlen_fwd(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
causal,
|
||||||
|
window_size_left,
|
||||||
|
0,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif ENGINE == "triton":
|
||||||
|
from .flash_attn_triton import triton_attention
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
|
causal=True,
|
||||||
|
):
|
||||||
|
if window_size_left != -1:
|
||||||
|
raise ValueError(
|
||||||
|
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
||||||
|
)
|
||||||
|
output, _ = triton_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
causal,
|
||||||
|
softmax_scale,
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
|
):
|
||||||
|
if window_size_left != -1:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"window_size_left is only available with flash attn v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||||
|
if k.shape[1] != q.shape[1]:
|
||||||
|
# MQA expand
|
||||||
|
if k.shape[1] == 1:
|
||||||
|
k = k.expand(-1, q.shape[1], -1)
|
||||||
|
# Grouped attention reshape
|
||||||
|
else:
|
||||||
|
original_shape = k.shape
|
||||||
|
k = (
|
||||||
|
k.unsqueeze(2)
|
||||||
|
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
||||||
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
|
)
|
||||||
|
if v.shape[1] != q.shape[1]:
|
||||||
|
# MQA expand
|
||||||
|
if v.shape[1] == 1:
|
||||||
|
v = v.expand(-1, q.shape[1], -1)
|
||||||
|
# Grouped attention reshape
|
||||||
|
else:
|
||||||
|
original_shape = v.shape
|
||||||
|
v = (
|
||||||
|
v.unsqueeze(2)
|
||||||
|
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
||||||
|
.reshape(original_shape[0], -1, original_shape[2])
|
||||||
|
)
|
||||||
|
|
||||||
|
return flash_attn_cuda.fwd(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
)
|
|
@ -0,0 +1,76 @@
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
import torch
|
||||||
|
|
||||||
|
SUPPORTS_WINDOWING = False
|
||||||
|
|
||||||
|
|
||||||
|
def attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
softmax_scale,
|
||||||
|
window_size_left=-1,
|
||||||
|
):
|
||||||
|
if window_size_left != -1:
|
||||||
|
raise ValueError(
|
||||||
|
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
||||||
|
)
|
||||||
|
return ipex.llm.functional.varlen_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out,
|
||||||
|
cu_seqlens,
|
||||||
|
cu_seqlens,
|
||||||
|
max_s,
|
||||||
|
max_s,
|
||||||
|
0.0,
|
||||||
|
softmax_scale,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def reshape_and_cache(
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
slots: torch.Tensor,
|
||||||
|
):
|
||||||
|
ipex.llm.modules.PagedAttention.reshape_and_cache(
|
||||||
|
key, value, key_cache, value_cache, slots
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def paged_attention(
|
||||||
|
out: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
kv_head_mapping: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
max_s: int,
|
||||||
|
):
|
||||||
|
query = query.contiguous()
|
||||||
|
block_size = value_cache.shape[3]
|
||||||
|
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||||
|
out,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
kv_head_mapping,
|
||||||
|
softmax_scale,
|
||||||
|
block_tables,
|
||||||
|
input_lengths,
|
||||||
|
block_size,
|
||||||
|
max_s,
|
||||||
|
None,
|
||||||
|
)
|
|
@ -80,15 +80,11 @@ try:
|
||||||
from text_generation_server.models.flash_phi import FlashPhi
|
from text_generation_server.models.flash_phi import FlashPhi
|
||||||
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
|
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
|
||||||
from text_generation_server.models.flash_dbrx import FlashDbrx
|
from text_generation_server.models.flash_dbrx import FlashDbrx
|
||||||
from text_generation_server.utils.flash_attn import (
|
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
|
||||||
HAS_FLASH_ATTN_V2_CUDA,
|
|
||||||
HAS_FLASH_ATTN_V2_ROCM,
|
|
||||||
)
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning(f"Could not import Flash Attention enabled models: {e}")
|
logger.warning(f"Could not import Flash Attention enabled models: {e}")
|
||||||
|
SUPPORTS_WINDOWING = False
|
||||||
FLASH_ATTENTION = False
|
FLASH_ATTENTION = False
|
||||||
HAS_FLASH_ATTN_V2_CUDA = False
|
|
||||||
HAS_FLASH_ATTN_V2_ROCM = False
|
|
||||||
|
|
||||||
if FLASH_ATTENTION:
|
if FLASH_ATTENTION:
|
||||||
__all__.append(FlashGPT2)
|
__all__.append(FlashGPT2)
|
||||||
|
@ -262,6 +258,7 @@ def get_model(
|
||||||
dtype: Optional[str],
|
dtype: Optional[str],
|
||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
|
global FLASH_ATTENTION
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
if quantize in ["awq", "exl2", "gptq"]:
|
if quantize in ["awq", "exl2", "gptq"]:
|
||||||
# These quantizers only work with float16 params.
|
# These quantizers only work with float16 params.
|
||||||
|
@ -412,6 +409,12 @@ def get_model(
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Sharding is currently not supported with `exl2` quantization"
|
"Sharding is currently not supported with `exl2` quantization"
|
||||||
)
|
)
|
||||||
|
sliding_window = config_dict.get("sliding_window", -1)
|
||||||
|
if sliding_window != -1 and not SUPPORTS_WINDOWING:
|
||||||
|
logger.warning(
|
||||||
|
f"Flash attention is available, but doesn't support windowing which is required by model {model_id}"
|
||||||
|
)
|
||||||
|
FLASH_ATTENTION = False
|
||||||
|
|
||||||
if model_type == MAMBA:
|
if model_type == MAMBA:
|
||||||
return Mamba(
|
return Mamba(
|
||||||
|
@ -699,11 +702,7 @@ def get_model(
|
||||||
|
|
||||||
if model_type == MISTRAL:
|
if model_type == MISTRAL:
|
||||||
sliding_window = config_dict.get("sliding_window", -1)
|
sliding_window = config_dict.get("sliding_window", -1)
|
||||||
if (
|
if FLASH_ATTENTION:
|
||||||
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
|
|
||||||
or HAS_FLASH_ATTN_V2_CUDA
|
|
||||||
or HAS_FLASH_ATTN_V2_ROCM
|
|
||||||
):
|
|
||||||
return FlashMistral(
|
return FlashMistral(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
|
@ -726,11 +725,7 @@ def get_model(
|
||||||
|
|
||||||
if model_type == MIXTRAL:
|
if model_type == MIXTRAL:
|
||||||
sliding_window = config_dict.get("sliding_window", -1)
|
sliding_window = config_dict.get("sliding_window", -1)
|
||||||
if (
|
if FLASH_ATTENTION:
|
||||||
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
|
|
||||||
or HAS_FLASH_ATTN_V2_CUDA
|
|
||||||
or HAS_FLASH_ATTN_V2_ROCM
|
|
||||||
):
|
|
||||||
return FlashMixtral(
|
return FlashMixtral(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
|
@ -753,11 +748,7 @@ def get_model(
|
||||||
|
|
||||||
if model_type == STARCODER2:
|
if model_type == STARCODER2:
|
||||||
sliding_window = config_dict.get("sliding_window", -1)
|
sliding_window = config_dict.get("sliding_window", -1)
|
||||||
if (
|
if FLASH_ATTENTION:
|
||||||
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
|
|
||||||
or HAS_FLASH_ATTN_V2_CUDA
|
|
||||||
or HAS_FLASH_ATTN_V2_ROCM
|
|
||||||
):
|
|
||||||
return FlashStarcoder2(
|
return FlashStarcoder2(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
|
@ -781,11 +772,7 @@ def get_model(
|
||||||
|
|
||||||
if model_type == QWEN2:
|
if model_type == QWEN2:
|
||||||
sliding_window = config_dict.get("sliding_window", -1)
|
sliding_window = config_dict.get("sliding_window", -1)
|
||||||
if (
|
if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
|
||||||
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
|
|
||||||
or HAS_FLASH_ATTN_V2_CUDA
|
|
||||||
or HAS_FLASH_ATTN_V2_ROCM
|
|
||||||
):
|
|
||||||
return FlashQwen2(
|
return FlashQwen2(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
|
|
|
@ -25,7 +25,11 @@ from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
|
@ -281,7 +285,7 @@ class FlashCohereAttention(torch.nn.Module):
|
||||||
|
|
||||||
self.rotary_emb(query, key, cos, sin)
|
self.rotary_emb(query, key, cos, sin)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
|
@ -289,7 +293,7 @@ class FlashCohereAttention(torch.nn.Module):
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
|
@ -300,7 +304,7 @@ class FlashCohereAttention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
|
|
@ -27,7 +27,11 @@ from text_generation_server.utils.import_utils import SYSTEM
|
||||||
if SYSTEM != "xpu":
|
if SYSTEM != "xpu":
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
FastLinear,
|
FastLinear,
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
|
@ -424,9 +428,7 @@ class DbrxAttention(torch.nn.Module):
|
||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
|
@ -434,7 +436,7 @@ class DbrxAttention(torch.nn.Module):
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
@ -445,7 +447,7 @@ class DbrxAttention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
|
|
@ -26,7 +26,11 @@ from transformers.activations import ACT2FN
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
|
@ -221,9 +225,7 @@ class FlashGemmaAttention(torch.nn.Module):
|
||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
|
@ -231,7 +233,7 @@ class FlashGemmaAttention(torch.nn.Module):
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
@ -243,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
|
|
@ -25,7 +25,11 @@ from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
|
@ -213,7 +217,7 @@ class FlashGPT2Attention(torch.nn.Module):
|
||||||
key = key.view(-1, self.num_heads, self.head_size)
|
key = key.view(-1, self.num_heads, self.head_size)
|
||||||
value = value.view(-1, self.num_heads, self.head_size)
|
value = value.view(-1, self.num_heads, self.head_size)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
|
@ -221,7 +225,7 @@ class FlashGPT2Attention(torch.nn.Module):
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
|
@ -232,7 +236,7 @@ class FlashGPT2Attention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
|
|
@ -28,7 +28,11 @@ from transformers.activations import ACT2FN
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
|
@ -145,9 +149,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
|
@ -155,7 +157,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
@ -166,7 +168,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
|
|
@ -27,7 +27,11 @@ from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
|
@ -186,7 +190,7 @@ class MistralAttention(torch.nn.Module):
|
||||||
else:
|
else:
|
||||||
kv_to_cache = kv
|
kv_to_cache = kv
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(
|
||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -196,7 +200,7 @@ class MistralAttention(torch.nn.Module):
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
@ -208,7 +212,7 @@ class MistralAttention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
|
|
@ -33,7 +33,11 @@ from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
FastLinear,
|
FastLinear,
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
|
@ -265,7 +269,7 @@ class MixtralAttention(torch.nn.Module):
|
||||||
else:
|
else:
|
||||||
kv_to_cache = kv
|
kv_to_cache = kv
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(
|
||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -275,7 +279,7 @@ class MixtralAttention(torch.nn.Module):
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
@ -287,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
|
|
@ -27,8 +27,11 @@ from transformers.modeling_utils import PreTrainedModel
|
||||||
from transformers.models.gpt_neox import GPTNeoXConfig
|
from transformers.models.gpt_neox import GPTNeoXConfig
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
from text_generation_server.utils.flash_attn import attention
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
|
@ -146,9 +149,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||||
# Inplace rotary
|
# Inplace rotary
|
||||||
self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin)
|
self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
|
||||||
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
attn_output = torch.empty_like(qkv[:, 0])
|
attn_output = torch.empty_like(qkv[:, 0])
|
||||||
|
@ -156,7 +157,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
qkv[:, 0],
|
qkv[:, 0],
|
||||||
qkv[:, 1],
|
qkv[:, 1],
|
||||||
qkv[:, 2],
|
qkv[:, 2],
|
||||||
|
@ -167,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
qkv[:, 0],
|
qkv[:, 0],
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
|
|
@ -6,7 +6,11 @@ from transformers.activations import ACT2FN
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
|
@ -185,16 +189,14 @@ class FlashPhiAttention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reshape key and value and cache
|
# Reshape key and value and cache
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
@ -205,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
|
|
@ -5,7 +5,11 @@ from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
|
@ -142,7 +146,7 @@ class Qwen2Attention(torch.nn.Module):
|
||||||
else:
|
else:
|
||||||
kv_to_cache = kv
|
kv_to_cache = kv
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(
|
||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -152,7 +156,7 @@ class Qwen2Attention(torch.nn.Module):
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
@ -164,7 +168,7 @@ class Qwen2Attention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
|
|
@ -15,7 +15,11 @@ from text_generation_server.layers import (
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.layernorm import FastLayerNorm
|
from text_generation_server.layers.layernorm import FastLayerNorm
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.utils import flash_attn, paged_attention
|
from text_generation_server.layers.attention import (
|
||||||
|
attention,
|
||||||
|
paged_attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_row(config, prefix: str, weights, bias: bool):
|
def load_row(config, prefix: str, weights, bias: bool):
|
||||||
|
@ -194,9 +198,7 @@ class FlashRWAttention(torch.nn.Module):
|
||||||
# Inplace rotary
|
# Inplace rotary
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
|
||||||
)
|
|
||||||
|
|
||||||
# output
|
# output
|
||||||
attn_output = torch.empty_like(query)
|
attn_output = torch.empty_like(query)
|
||||||
|
@ -204,7 +206,7 @@ class FlashRWAttention(torch.nn.Module):
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
@ -215,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
@ -313,7 +315,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||||
# Inplace rotary
|
# Inplace rotary
|
||||||
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(
|
||||||
kv[:, :, 0].contiguous(),
|
kv[:, :, 0].contiguous(),
|
||||||
kv[:, :, 1].contiguous(),
|
kv[:, :, 1].contiguous(),
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
@ -327,7 +329,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=2, index=0),
|
torch.select(kv, dim=2, index=0),
|
||||||
torch.select(kv, dim=2, index=1),
|
torch.select(kv, dim=2, index=1),
|
||||||
|
@ -338,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
|
|
@ -6,7 +6,11 @@ from transformers.activations import ACT2FN
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.layers.gptq import GPTQWeight
|
from text_generation_server.layers.gptq import GPTQWeight
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
|
@ -276,7 +280,7 @@ class FlashMQAttention(torch.nn.Module):
|
||||||
query = query.view(-1, self.num_heads, self.head_size)
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
key_value = key_value.view(-1, 2, 1, self.head_size)
|
key_value = key_value.view(-1, 2, 1, self.head_size)
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(
|
||||||
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
|
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -286,7 +290,7 @@ class FlashMQAttention(torch.nn.Module):
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(key_value, dim=1, index=0),
|
torch.select(key_value, dim=1, index=0),
|
||||||
torch.select(key_value, dim=1, index=1),
|
torch.select(key_value, dim=1, index=1),
|
||||||
|
@ -297,7 +301,7 @@ class FlashMQAttention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
|
|
@ -26,7 +26,11 @@ from transformers.activations import ACT2FN
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.layers.attention import (
|
||||||
|
paged_attention,
|
||||||
|
attention,
|
||||||
|
reshape_and_cache,
|
||||||
|
)
|
||||||
from text_generation_server.layers import (
|
from text_generation_server.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
TensorParallelColumnLinear,
|
TensorParallelColumnLinear,
|
||||||
|
@ -229,7 +233,7 @@ class Starcoder2Attention(torch.nn.Module):
|
||||||
else:
|
else:
|
||||||
kv_to_cache = kv
|
kv_to_cache = kv
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
reshape_and_cache(
|
||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -239,7 +243,7 @@ class Starcoder2Attention(torch.nn.Module):
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
attention(
|
||||||
query,
|
query,
|
||||||
torch.select(kv, dim=1, index=0),
|
torch.select(kv, dim=1, index=0),
|
||||||
torch.select(kv, dim=1, index=1),
|
torch.select(kv, dim=1, index=1),
|
||||||
|
@ -251,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
paged_attention.attention(
|
paged_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
query,
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||||
# This is overridden by the cli
|
# This is overridden by the cli
|
||||||
|
|
|
@ -1,293 +0,0 @@
|
||||||
import os
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
import math
|
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
if SYSTEM != "xpu":
|
|
||||||
from text_generation_server.utils.flash_attn_triton import triton_attention
|
|
||||||
|
|
||||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
|
||||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
|
||||||
HAS_FLASH_ATTN = False
|
|
||||||
HAS_FLASH_ATTN_V2_CUDA = False
|
|
||||||
HAS_FLASH_ATTN_V2_ROCM = False
|
|
||||||
ROCM_USE_FLASH_ATTN_V2_CK = False
|
|
||||||
ROCM_USE_FLASH_ATTN_V2_TRITON = False
|
|
||||||
|
|
||||||
|
|
||||||
if SYSTEM in {"cuda", "rocm"}:
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
raise ImportError("CUDA is not available")
|
|
||||||
|
|
||||||
major, minor = torch.cuda.get_device_capability()
|
|
||||||
is_sm75 = major == 7 and minor == 5
|
|
||||||
is_sm8x = major == 8 and minor >= 0
|
|
||||||
is_sm90 = major == 9 and minor == 0
|
|
||||||
is_sm94 = major == 9 and minor == 4
|
|
||||||
|
|
||||||
if SYSTEM == "rocm":
|
|
||||||
if (
|
|
||||||
os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
|
|
||||||
or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1"
|
|
||||||
):
|
|
||||||
ROCM_USE_FLASH_ATTN_V2_TRITON = True
|
|
||||||
logger.info("ROCm: using Flash Attention 2 Triton implementation.")
|
|
||||||
else:
|
|
||||||
ROCM_USE_FLASH_ATTN_V2_CK = True
|
|
||||||
logger.info(
|
|
||||||
"ROCm: using Flash Attention 2 Composable Kernel implementation."
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
try:
|
|
||||||
import flash_attn_2_cuda
|
|
||||||
except ImportError:
|
|
||||||
architecture_suffix = f"-{SYSTEM}"
|
|
||||||
raise ImportError(
|
|
||||||
"Flash Attention V2 is not installed.\n"
|
|
||||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
|
||||||
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
|
|
||||||
)
|
|
||||||
if SYSTEM == "cuda" and not (is_sm8x or is_sm90):
|
|
||||||
raise ImportError(
|
|
||||||
f"GPU with CUDA capability {major} {minor} is not supported for "
|
|
||||||
"Flash Attention V2"
|
|
||||||
)
|
|
||||||
elif SYSTEM == "rocm" and not (is_sm8x or is_sm90 or is_sm94):
|
|
||||||
raise ImportError(
|
|
||||||
f"AMD GPU with compute capability {major} {minor} is not supported for "
|
|
||||||
"Flash Attention V2"
|
|
||||||
)
|
|
||||||
HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
|
|
||||||
HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
|
|
||||||
except ImportError as e:
|
|
||||||
try:
|
|
||||||
import flash_attn_cuda
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"Flash Attention is not installed.\n"
|
|
||||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
|
||||||
"or install flash attention with `cd server && make install install-flash-attention`"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90):
|
|
||||||
raise ImportError(
|
|
||||||
f"GPU with CUDA capability {major} {minor} is not supported"
|
|
||||||
) from e
|
|
||||||
elif SYSTEM == "rocm":
|
|
||||||
for idx in range(torch.cuda.device_count()):
|
|
||||||
if "MI210" not in torch.cuda.get_device_name(
|
|
||||||
idx
|
|
||||||
) and "MI250" not in torch.cuda.get_device_name(idx):
|
|
||||||
raise ImportError(
|
|
||||||
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.warning(f"Unable to use Flash Attention V2: {e}")
|
|
||||||
HAS_FLASH_ATTN = True
|
|
||||||
|
|
||||||
if SYSTEM == "xpu":
|
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
):
|
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
|
||||||
|
|
||||||
if window_size_left != -1:
|
|
||||||
raise ValueError(
|
|
||||||
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
|
||||||
)
|
|
||||||
return ipex.llm.functional.varlen_attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
max_s,
|
|
||||||
0.0,
|
|
||||||
softmax_scale,
|
|
||||||
False,
|
|
||||||
True,
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif HAS_FLASH_ATTN_V2_CUDA:
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
causal=True,
|
|
||||||
):
|
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
cu_seqlens,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
max_s,
|
|
||||||
max_s,
|
|
||||||
0.0,
|
|
||||||
softmax_scale,
|
|
||||||
False,
|
|
||||||
causal,
|
|
||||||
window_size_left,
|
|
||||||
0,
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
causal=True,
|
|
||||||
):
|
|
||||||
if window_size_left <= 0 and window_size_left != -1:
|
|
||||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
|
||||||
if window_size_left != -1:
|
|
||||||
raise ValueError(
|
|
||||||
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
|
|
||||||
)
|
|
||||||
|
|
||||||
# RoCm flash API does not take the window_size_left and window_size_right arguments.
|
|
||||||
return flash_attn_2_cuda.varlen_fwd(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
max_s,
|
|
||||||
0.0,
|
|
||||||
softmax_scale,
|
|
||||||
False,
|
|
||||||
causal,
|
|
||||||
False,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
causal=True,
|
|
||||||
):
|
|
||||||
output, _ = triton_attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
max_s,
|
|
||||||
causal,
|
|
||||||
softmax_scale,
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
||||||
elif HAS_FLASH_ATTN:
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
softmax_scale,
|
|
||||||
window_size_left=-1,
|
|
||||||
):
|
|
||||||
if window_size_left != -1:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"window_size_left is only available with flash attn v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
|
||||||
if k.shape[1] != q.shape[1]:
|
|
||||||
# MQA expand
|
|
||||||
if k.shape[1] == 1:
|
|
||||||
k = k.expand(-1, q.shape[1], -1)
|
|
||||||
# Grouped attention reshape
|
|
||||||
else:
|
|
||||||
original_shape = k.shape
|
|
||||||
k = (
|
|
||||||
k.unsqueeze(2)
|
|
||||||
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
|
||||||
.reshape(original_shape[0], -1, original_shape[2])
|
|
||||||
)
|
|
||||||
if v.shape[1] != q.shape[1]:
|
|
||||||
# MQA expand
|
|
||||||
if v.shape[1] == 1:
|
|
||||||
v = v.expand(-1, q.shape[1], -1)
|
|
||||||
# Grouped attention reshape
|
|
||||||
else:
|
|
||||||
original_shape = v.shape
|
|
||||||
v = (
|
|
||||||
v.unsqueeze(2)
|
|
||||||
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
|
||||||
.reshape(original_shape[0], -1, original_shape[2])
|
|
||||||
)
|
|
||||||
|
|
||||||
return flash_attn_cuda.fwd(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out,
|
|
||||||
cu_seqlens,
|
|
||||||
cu_seqlens,
|
|
||||||
max_s,
|
|
||||||
max_s,
|
|
||||||
0.0,
|
|
||||||
softmax_scale,
|
|
||||||
False,
|
|
||||||
True,
|
|
||||||
False,
|
|
||||||
0,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("flash attention is not installed")
|
|
|
@ -1,4 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
def is_xpu_available():
|
def is_xpu_available():
|
||||||
|
@ -48,3 +49,4 @@ else:
|
||||||
empty_cache = noop
|
empty_cache = noop
|
||||||
synchronize = noop
|
synchronize = noop
|
||||||
get_free_memory = noop
|
get_free_memory = noop
|
||||||
|
logger.info(f"Detected system {SYSTEM}")
|
||||||
|
|
|
@ -1,137 +0,0 @@
|
||||||
import torch
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
|
||||||
|
|
||||||
_PARTITION_SIZE = 512
|
|
||||||
|
|
||||||
if SYSTEM == "xpu":
|
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
from vllm._C import cache_ops
|
|
||||||
from vllm._C import ops
|
|
||||||
except Exception as e:
|
|
||||||
raise ImportError(
|
|
||||||
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def reshape_and_cache(
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
slots: torch.Tensor,
|
|
||||||
):
|
|
||||||
if SYSTEM == "xpu":
|
|
||||||
ipex.llm.modules.PagedAttention.reshape_and_cache(
|
|
||||||
key, value, key_cache, value_cache, slots
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cache_ops.reshape_and_cache(
|
|
||||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def attention(
|
|
||||||
out: torch.Tensor,
|
|
||||||
query: torch.Tensor,
|
|
||||||
key_cache: torch.Tensor,
|
|
||||||
value_cache: torch.Tensor,
|
|
||||||
kv_head_mapping: torch.Tensor,
|
|
||||||
softmax_scale: float,
|
|
||||||
block_tables: torch.Tensor,
|
|
||||||
input_lengths: torch.Tensor,
|
|
||||||
max_s: int,
|
|
||||||
):
|
|
||||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
|
||||||
# Copyright 2023 The vLLM team. All rights
|
|
||||||
# reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
#
|
|
||||||
|
|
||||||
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
|
||||||
block_size = value_cache.shape[3]
|
|
||||||
num_seqs, num_heads, head_size = query.shape
|
|
||||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
|
||||||
if SYSTEM == "xpu":
|
|
||||||
query = query.contiguous()
|
|
||||||
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
|
||||||
out,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
kv_head_mapping,
|
|
||||||
softmax_scale,
|
|
||||||
block_tables,
|
|
||||||
input_lengths,
|
|
||||||
block_size,
|
|
||||||
max_s,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
|
||||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
|
||||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
|
||||||
# sequences or heads is large, we use V1 since there is enough work
|
|
||||||
# to parallelize.
|
|
||||||
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
|
||||||
if use_v1:
|
|
||||||
ops.paged_attention_v1(
|
|
||||||
out,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
kv_head_mapping,
|
|
||||||
softmax_scale,
|
|
||||||
block_tables,
|
|
||||||
input_lengths,
|
|
||||||
block_size,
|
|
||||||
max_s,
|
|
||||||
None,
|
|
||||||
"auto",
|
|
||||||
1.0,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Run PagedAttention V2.
|
|
||||||
assert _PARTITION_SIZE % block_size == 0
|
|
||||||
tmp_output = torch.empty(
|
|
||||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
|
||||||
dtype=out.dtype,
|
|
||||||
device=out.device,
|
|
||||||
)
|
|
||||||
exp_sums = torch.empty(
|
|
||||||
size=(num_seqs, num_heads, max_num_partitions),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=out.device,
|
|
||||||
)
|
|
||||||
max_logits = torch.empty_like(exp_sums)
|
|
||||||
|
|
||||||
ops.paged_attention_v2(
|
|
||||||
out,
|
|
||||||
exp_sums,
|
|
||||||
max_logits,
|
|
||||||
tmp_output,
|
|
||||||
query,
|
|
||||||
key_cache,
|
|
||||||
value_cache,
|
|
||||||
kv_head_mapping,
|
|
||||||
softmax_scale,
|
|
||||||
block_tables,
|
|
||||||
input_lengths,
|
|
||||||
block_size,
|
|
||||||
max_s,
|
|
||||||
None,
|
|
||||||
"auto",
|
|
||||||
1.0,
|
|
||||||
)
|
|
Loading…
Reference in New Issue