Purely refactors paged/attention into `layers/attention` and make hardware differences more obvious with 1 file per hardware. (#1986)

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
This commit is contained in:
Nicolas Patry 2024-05-31 17:57:01 +02:00 committed by GitHub
parent 659bd67fec
commit 06edde9491
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 754 additions and 527 deletions

View File

@ -70,7 +70,6 @@ impl Infer {
tokenizer_config: HubTokenizerConfig,
processor_config: HubProcessorConfig,
) -> Self {
// Infer shared state
let queue = Queue::new(requires_padding, 16, window_size, speculate);
let shared = Arc::new(Shared {
batching_task: Notify::new(),

View File

@ -1,11 +1,11 @@
flash_att_v2_commit_cuda := 23e8fa5a263d1c7122bc46a86ef32030ee7130f9
flash_att_v2_commit_cuda := v2.5.8
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
flash-attention-v2-cuda:
# Clone flash attention
pip install -U packaging ninja --no-cache-dir
git clone https://github.com/HazyResearch/flash-attention.git flash-attention-v2
git clone https://github.com/Dao-AILab/flash-attention.git flash-attention-v2
build-flash-attention-v2-cuda: flash-attention-v2-cuda
cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_cuda)

View File

@ -0,0 +1,13 @@
from text_generation_server.utils.import_utils import SYSTEM
import os
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "xpu":
from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

View File

@ -0,0 +1,245 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
try:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
def paged_attention(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
from vllm._C import ops
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
try:
import flash_attn_2_cuda
V2 = True
except ImportError:
try:
import flash_attn_cuda
V2 = False
except ImportError as e:
if major >= 8:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
elif is_sm75:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
else:
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
SUPPORTS_WINDOWING = V2
if V2:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
None,
None,
None,
max_s,
max_s,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
False,
None,
)
else:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
return flash_attn_cuda.fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
0,
None,
)

View File

@ -0,0 +1,295 @@
import os
import torch
from text_generation_server.utils.import_utils import SYSTEM
from loguru import logger
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
try:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
def paged_attention(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
from vllm._C import ops
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
if ENGINE != "triton":
try:
import flash_attn_2_cuda
logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
except ImportError:
try:
import flash_attn_cuda
ENGINE = "v1"
logger.info("ROCm: using Flash Attention 1")
except ImportError as e:
if major >= 8:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
elif is_sm75:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
else:
for idx in range(torch.cuda.device_count()):
name = torch.cuda.get_device_name(idx)
if "MI210" not in name and "MI250" not in name:
raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
)
raise ImportError(
f"AMD GPU with ROCm capability {major} {minor} is not supported"
) from e
SUPPORTS_WINDOWING = ENGINE != "v1"
if ENGINE == "ck":
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
if window_size_left != -1:
raise ValueError(
f"ROCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
None,
None,
None,
max_s,
max_s,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
False,
None,
)
elif ENGINE == "triton":
from .flash_attn_triton import triton_attention
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
if window_size_left != -1:
raise ValueError(
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
output, _ = triton_attention(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
causal,
softmax_scale,
)
return output
else:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
return flash_attn_cuda.fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
0,
None,
)

View File

@ -0,0 +1,76 @@
import intel_extension_for_pytorch as ipex
import torch
SUPPORTS_WINDOWING = False
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left != -1:
raise ValueError(
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
return ipex.llm.functional.varlen_attention(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
None,
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
def paged_attention(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
):
query = query.contiguous()
block_size = value_cache.shape[3]
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
)

View File

@ -80,15 +80,11 @@ try:
from text_generation_server.models.flash_phi import FlashPhi
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
from text_generation_server.models.flash_dbrx import FlashDbrx
from text_generation_server.utils.flash_attn import (
HAS_FLASH_ATTN_V2_CUDA,
HAS_FLASH_ATTN_V2_ROCM,
)
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}")
SUPPORTS_WINDOWING = False
FLASH_ATTENTION = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
if FLASH_ATTENTION:
__all__.append(FlashGPT2)
@ -262,6 +258,7 @@ def get_model(
dtype: Optional[str],
trust_remote_code: bool,
) -> Model:
global FLASH_ATTENTION
if dtype is None:
if quantize in ["awq", "exl2", "gptq"]:
# These quantizers only work with float16 params.
@ -412,6 +409,12 @@ def get_model(
raise RuntimeError(
"Sharding is currently not supported with `exl2` quantization"
)
sliding_window = config_dict.get("sliding_window", -1)
if sliding_window != -1 and not SUPPORTS_WINDOWING:
logger.warning(
f"Flash attention is available, but doesn't support windowing which is required by model {model_id}"
)
FLASH_ATTENTION = False
if model_type == MAMBA:
return Mamba(
@ -699,11 +702,7 @@ def get_model(
if model_type == MISTRAL:
sliding_window = config_dict.get("sliding_window", -1)
if (
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
or HAS_FLASH_ATTN_V2_CUDA
or HAS_FLASH_ATTN_V2_ROCM
):
if FLASH_ATTENTION:
return FlashMistral(
model_id,
revision,
@ -726,11 +725,7 @@ def get_model(
if model_type == MIXTRAL:
sliding_window = config_dict.get("sliding_window", -1)
if (
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
or HAS_FLASH_ATTN_V2_CUDA
or HAS_FLASH_ATTN_V2_ROCM
):
if FLASH_ATTENTION:
return FlashMixtral(
model_id,
revision,
@ -753,11 +748,7 @@ def get_model(
if model_type == STARCODER2:
sliding_window = config_dict.get("sliding_window", -1)
if (
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
or HAS_FLASH_ATTN_V2_CUDA
or HAS_FLASH_ATTN_V2_ROCM
):
if FLASH_ATTENTION:
return FlashStarcoder2(
model_id,
revision,
@ -781,11 +772,7 @@ def get_model(
if model_type == QWEN2:
sliding_window = config_dict.get("sliding_window", -1)
if (
((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
or HAS_FLASH_ATTN_V2_CUDA
or HAS_FLASH_ATTN_V2_ROCM
):
if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
return FlashQwen2(
model_id,
revision,

View File

@ -25,7 +25,11 @@ from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
TensorParallelRowLinear,
@ -281,7 +285,7 @@ class FlashCohereAttention(torch.nn.Module):
self.rotary_emb(query, key, cos, sin)
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
@ -289,7 +293,7 @@ class FlashCohereAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
key,
value,
@ -300,7 +304,7 @@ class FlashCohereAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -27,7 +27,11 @@ from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "xpu":
from vllm.model_executor.layers.fused_moe import fused_moe
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
FastLinear,
TensorParallelRowLinear,
@ -424,9 +428,7 @@ class DbrxAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
@ -434,7 +436,7 @@ class DbrxAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -445,7 +447,7 @@ class DbrxAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -26,7 +26,11 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -221,9 +225,7 @@ class FlashGemmaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
@ -231,7 +233,7 @@ class FlashGemmaAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -243,7 +245,7 @@ class FlashGemmaAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -25,7 +25,11 @@ from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -213,7 +217,7 @@ class FlashGPT2Attention(torch.nn.Module):
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
@ -221,7 +225,7 @@ class FlashGPT2Attention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
key,
value,
@ -232,7 +236,7 @@ class FlashGPT2Attention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -28,7 +28,11 @@ from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -145,9 +149,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
@ -155,7 +157,7 @@ class FlashLlamaAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -166,7 +168,7 @@ class FlashLlamaAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -27,7 +27,11 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -186,7 +190,7 @@ class MistralAttention(torch.nn.Module):
else:
kv_to_cache = kv
paged_attention.reshape_and_cache(
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
@ -196,7 +200,7 @@ class MistralAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -208,7 +212,7 @@ class MistralAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -33,7 +33,11 @@ from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from loguru import logger
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
FastLinear,
TensorParallelRowLinear,
@ -265,7 +269,7 @@ class MixtralAttention(torch.nn.Module):
else:
kv_to_cache = kv
paged_attention.reshape_and_cache(
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
@ -275,7 +279,7 @@ class MixtralAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -287,7 +291,7 @@ class MixtralAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -27,8 +27,11 @@ from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -146,9 +149,7 @@ class FlashNeoxAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin)
paged_attention.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
)
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(qkv[:, 0])
@ -156,7 +157,7 @@ class FlashNeoxAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
@ -167,7 +168,7 @@ class FlashNeoxAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
qkv[:, 0],
kv_cache[0],

View File

@ -6,7 +6,11 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -185,16 +189,14 @@ class FlashPhiAttention(torch.nn.Module):
)
# Reshape key and value and cache
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -205,7 +207,7 @@ class FlashPhiAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -5,7 +5,11 @@ from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -142,7 +146,7 @@ class Qwen2Attention(torch.nn.Module):
else:
kv_to_cache = kv
paged_attention.reshape_and_cache(
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
@ -152,7 +156,7 @@ class Qwen2Attention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -164,7 +168,7 @@ class Qwen2Attention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -15,7 +15,11 @@ from text_generation_server.layers import (
)
from text_generation_server.layers.layernorm import FastLayerNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils import flash_attn, paged_attention
from text_generation_server.layers.attention import (
attention,
paged_attention,
reshape_and_cache,
)
def load_row(config, prefix: str, weights, bias: bool):
@ -194,9 +198,7 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
# output
attn_output = torch.empty_like(query)
@ -204,7 +206,7 @@ class FlashRWAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -215,7 +217,7 @@ class FlashRWAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],
@ -313,7 +315,7 @@ class FlashRWLargeAttention(torch.nn.Module):
# Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
paged_attention.reshape_and_cache(
reshape_and_cache(
kv[:, :, 0].contiguous(),
kv[:, :, 1].contiguous(),
kv_cache[0],
@ -327,7 +329,7 @@ class FlashRWLargeAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
@ -338,7 +340,7 @@ class FlashRWLargeAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -6,7 +6,11 @@ from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -276,7 +280,7 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)
paged_attention.reshape_and_cache(
reshape_and_cache(
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
)
@ -286,7 +290,7 @@ class FlashMQAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
@ -297,7 +301,7 @@ class FlashMQAttention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -26,7 +26,11 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers.attention import (
paged_attention,
attention,
reshape_and_cache,
)
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -229,7 +233,7 @@ class Starcoder2Attention(torch.nn.Module):
else:
kv_to_cache = kv
paged_attention.reshape_and_cache(
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
@ -239,7 +243,7 @@ class Starcoder2Attention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -251,7 +255,7 @@ class Starcoder2Attention(torch.nn.Module):
)
# Decode
else:
paged_attention.attention(
paged_attention(
attn_output,
query,
kv_cache[0],

View File

@ -1,5 +1,6 @@
import torch
import os
from loguru import logger
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli

View File

@ -1,293 +0,0 @@
import os
import torch
from loguru import logger
import math
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM != "xpu":
from text_generation_server.utils.flash_attn_triton import triton_attention
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
ROCM_USE_FLASH_ATTN_V2_CK = False
ROCM_USE_FLASH_ATTN_V2_TRITON = False
if SYSTEM in {"cuda", "rocm"}:
if not torch.cuda.is_available():
raise ImportError("CUDA is not available")
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
is_sm94 = major == 9 and minor == 4
if SYSTEM == "rocm":
if (
os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1"
):
ROCM_USE_FLASH_ATTN_V2_TRITON = True
logger.info("ROCm: using Flash Attention 2 Triton implementation.")
else:
ROCM_USE_FLASH_ATTN_V2_CK = True
logger.info(
"ROCm: using Flash Attention 2 Composable Kernel implementation."
)
try:
try:
import flash_attn_2_cuda
except ImportError:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
if SYSTEM == "cuda" and not (is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported for "
"Flash Attention V2"
)
elif SYSTEM == "rocm" and not (is_sm8x or is_sm90 or is_sm94):
raise ImportError(
f"AMD GPU with compute capability {major} {minor} is not supported for "
"Flash Attention V2"
)
HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
except ImportError as e:
try:
import flash_attn_cuda
except ImportError:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
elif SYSTEM == "rocm":
for idx in range(torch.cuda.device_count()):
if "MI210" not in torch.cuda.get_device_name(
idx
) and "MI250" not in torch.cuda.get_device_name(idx):
raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
)
logger.warning(f"Unable to use Flash Attention V2: {e}")
HAS_FLASH_ATTN = True
if SYSTEM == "xpu":
import intel_extension_for_pytorch as ipex
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
if window_size_left != -1:
raise ValueError(
f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
return ipex.llm.functional.varlen_attention(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
None,
)
elif HAS_FLASH_ATTN_V2_CUDA:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
None,
None,
None,
max_s,
max_s,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
False,
None,
)
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
if window_size_left != -1:
raise ValueError(
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
)
# RoCm flash API does not take the window_size_left and window_size_right arguments.
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
causal,
False,
None,
)
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
output, _ = triton_attention(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
causal,
softmax_scale,
)
return output
elif HAS_FLASH_ATTN:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
return flash_attn_cuda.fwd(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
softmax_scale,
False,
True,
False,
0,
None,
)
else:
raise NotImplementedError("flash attention is not installed")

View File

@ -1,4 +1,5 @@
import torch
from loguru import logger
def is_xpu_available():
@ -48,3 +49,4 @@ else:
empty_cache = noop
synchronize = noop
get_free_memory = noop
logger.info(f"Detected system {SYSTEM}")

View File

@ -1,137 +0,0 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
_PARTITION_SIZE = 512
if SYSTEM == "xpu":
import intel_extension_for_pytorch as ipex
else:
try:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if SYSTEM == "xpu":
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
def attention(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
if SYSTEM == "xpu":
query = query.contiguous()
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)