feat: paged attention v2 (#1183)
commit 12590fdcce
parent 63fa534612
@@ -1,4 +1,4 @@
-flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c
+flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
 
 flash-attention-v2:
 	# Clone flash attention

@@ -1,8 +1,8 @@
-vllm_commit := 25dbff97d5a8f2ba331847237b458b2692e9ae78
+vllm_commit := f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
 
 vllm:
 	# Clone vllm
-	git clone https://github.com/OlivierDehaene/vllm.git
+	git clone https://github.com/vllm-project/vllm.git
 
 build-vllm: vllm
 	cd vllm && git fetch && git checkout $(vllm_commit)

@@ -29,11 +29,7 @@ from typing import Optional, List, Tuple
 # Flash attention imports
 import dropout_layer_norm
 
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
-from text_generation_server.utils.flash_attn import attention
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -269,7 +265,7 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
@@ -279,7 +275,7 @@ class FlashLlamaAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -290,9 +286,7 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -301,7 +295,6 @@ class FlashLlamaAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

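Note on the decode-path change above: callers no longer compute or pass block_size, because the new paged_attention.attention wrapper reads it from the value cache layout (value_cache => [num_blocks, num_heads, head_size, block_size]; see the new utils module at the end of this diff). A minimal sketch of that derivation; the concrete sizes are illustrative, not taken from the diff:

import torch

# Illustrative cache shape only: [num_blocks, num_heads, head_size, block_size]
value_cache = torch.empty(128, 32, 128, 16)
block_size = value_cache.shape[3]  # 16, now computed inside paged_attention.attention
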
@@ -29,10 +29,7 @@ from typing import Optional, List, Tuple
 # Flash attention imports
 import dropout_layer_norm
 
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -272,7 +269,7 @@ class MistralAttention(torch.nn.Module):
         else:
             kv_to_cache = kv
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
@@ -282,7 +279,7 @@ class MistralAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -294,9 +291,7 @@ class MistralAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -305,7 +300,6 @@ class MistralAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

@@ -27,10 +27,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
 from typing import Optional, List, Tuple
 
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -141,7 +138,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.rotary_emb(qkv[:, 0], cos, sin)
         self.rotary_emb(qkv[:, 1], cos, sin)
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
         )
 
@@ -151,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 qkv[:, 0],
                 qkv[:, 1],
                 qkv[:, 2],
@@ -162,9 +159,7 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 qkv[:, 0],
                 kv_cache[0],
@@ -173,7 +168,6 @@ class FlashNeoxAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

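In the NeoX path the projections stay packed in a single qkv tensor and are indexed as qkv[:, 0] / qkv[:, 1] / qkv[:, 2] above. A small shape sketch of that packed layout, with made-up sizes:

import torch

num_tokens, num_heads, head_size = 4, 8, 64
qkv = torch.randn(num_tokens, 3, num_heads, head_size)
query, key, value = qkv[:, 0], qkv[:, 1], qkv[:, 2]  # each [num_tokens, num_heads, head_size]
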
@@ -6,10 +6,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -191,7 +188,7 @@ class FlashRWAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
@@ -201,7 +198,7 @@ class FlashRWAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=1, index=0),
                 torch.select(kv, dim=1, index=1),
@@ -212,9 +209,7 @@ class FlashRWAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -223,7 +218,6 @@ class FlashRWAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )
 
@@ -310,7 +304,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         self.rotary_emb(query, cos, sin)
         self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             kv[:, :, 0].contiguous(),
             kv[:, :, 1].contiguous(),
             kv_cache[0],
@@ -324,7 +318,7 @@ class FlashRWLargeAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(kv, dim=2, index=0),
                 torch.select(kv, dim=2, index=1),
@@ -335,9 +329,7 @@ class FlashRWLargeAttention(torch.nn.Module):
            )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -346,7 +338,6 @@ class FlashRWLargeAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

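The .contiguous() calls kept in the FlashRWLargeAttention hunk matter because indexing kv along an inner dimension produces a strided view; presumably the cache kernel expects densely packed key/value tensors. A minimal illustration with made-up shapes:

import torch

kv = torch.randn(4, 2, 2, 64)  # e.g. [num_tokens, num_groups, 2 (key/value), head_size]
key_view = kv[:, :, 0]         # a non-contiguous view into kv
print(key_view.is_contiguous())               # False
print(key_view.contiguous().is_contiguous())  # True
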
@@ -5,10 +5,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional, List, Tuple
 
-# vllm imports
-import vllm_cache_ops
-import vllm_attention_ops
-
+from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.flash_attn import attention
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -18,7 +15,6 @@ from text_generation_server.utils.layers import (
     FastLayerNorm,
     get_linear,
 )
-from safetensors import SafetensorError
 
 
 def load_multi_mqa(
@@ -258,7 +254,7 @@ class FlashMQAttention(torch.nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         key_value = key_value.view(-1, 2, 1, self.head_size)
 
-        vllm_cache_ops.reshape_and_cache(
+        paged_attention.reshape_and_cache(
             key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
@@ -268,7 +264,7 @@ class FlashMQAttention(torch.nn.Module):
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
-            attention(
+            flash_attn.attention(
                 query,
                 torch.select(key_value, dim=1, index=0),
                 torch.select(key_value, dim=1, index=1),
@@ -279,9 +275,7 @@ class FlashMQAttention(torch.nn.Module):
             )
         # Decode
         else:
-            # kv_cache[1] => [num_blocks, 1, head_size, block_size]
-            block_size = kv_cache[1].shape[3]
-            vllm_attention_ops.single_query_cached_kv_attention(
+            paged_attention.attention(
                 attn_output,
                 query,
                 kv_cache[0],
@@ -290,7 +284,6 @@ class FlashMQAttention(torch.nn.Module):
                 self.softmax_scale,
                 block_tables,
                 input_lengths,
-                block_size,
                 max_s,
             )

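The multi-query path above keeps a single shared key/value head: key_value = key_value.view(-1, 2, 1, self.head_size), so the cached keys and values each carry one head per token. A short shape sketch (sizes are illustrative):

import torch

num_tokens, num_heads, head_size = 4, 8, 64
query = torch.randn(num_tokens, num_heads, head_size)
key_value = torch.randn(num_tokens, 2, 1, head_size)  # dim 1: key/value, dim 2: single KV head
key, value = key_value[:, 0], key_value[:, 1]          # each [num_tokens, 1, head_size]
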
@@ -0,0 +1,100 @@
+import torch
+
+# vllm imports
+from vllm import cache_ops
+from vllm import attention_ops
+
+_PARTITION_SIZE = 512
+
+
+def reshape_and_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slots: torch.Tensor,
+):
+    cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
+
+
+def attention(
+    out: torch.Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    kv_head_mapping: torch.Tensor,
+    softmax_scale: float,
+    block_tables: torch.Tensor,
+    input_lengths: torch.Tensor,
+    max_s: int,
+):
+    # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
+    # Copyright 2023 The vLLM team. All rights reserved.
+    #
+    # Licensed under the Apache License, Version 2.0 (the "License");
+    # you may not use this file except in compliance with the License.
+    # You may obtain a copy of the License at
+    #
+    #     http://www.apache.org/licenses/LICENSE-2.0
+    #
+    # Unless required by applicable law or agreed to in writing, software
+    # distributed under the License is distributed on an "AS IS" BASIS,
+    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    # See the License for the specific language governing permissions and
+    # limitations under the License.
+    #
+
+    # value_cache => [num_blocks, num_heads, head_size, block_size]
+    block_size = value_cache.shape[3]
+    num_seqs, num_heads, head_size = query.shape
+    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
+    # NOTE(woosuk): We use a simple heuristic to decide whether to use
+    # PagedAttention V1 or V2. If the number of partitions is 1, we use
+    # V1 to avoid the overhead of reduction. Also, if the number of
+    # sequences or heads is large, we use V1 since there is enough work
+    # to parallelize.
+    use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+    if use_v1:
+        attention_ops.paged_attention_v1(
+            out,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+        )
+    else:
+        # Run PagedAttention V2.
+        assert _PARTITION_SIZE % block_size == 0
+        tmp_output = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions, head_size),
+            dtype=out.dtype,
+            device=out.device,
+        )
+        exp_sums = torch.empty(
+            size=(num_seqs, num_heads, max_num_partitions),
+            dtype=torch.float32,
+            device=out.device,
+        )
+        max_logits = torch.empty_like(exp_sums)
+        attention_ops.paged_attention_v2(
+            out,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            kv_head_mapping,
+            softmax_scale,
+            block_tables,
+            input_lengths,
+            block_size,
+            max_s,
+            None,
+        )
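To make the V1/V2 dispatch heuristic in the new module concrete, here is a small standalone sketch (pure Python, no CUDA kernels) that reproduces only the decision logic; the example shapes are made up:

_PARTITION_SIZE = 512

def uses_v1(max_s: int, num_seqs: int, num_heads: int) -> bool:
    # Same heuristic as above: V1 when there is a single partition,
    # or when num_seqs * num_heads already offers enough parallel work.
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    return max_num_partitions == 1 or num_seqs * num_heads > 512

print(uses_v1(max_s=256, num_seqs=1, num_heads=32))    # True: one partition -> V1
print(uses_v1(max_s=4096, num_seqs=1, num_heads=32))   # False: 8 partitions, little work -> V2
print(uses_v1(max_s=4096, num_seqs=64, num_heads=32))  # True: plenty of parallel work -> V1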