feat: paged attention v2 (#1183)

This commit is contained in:
OlivierDehaene 2023-10-23 12:29:25 +02:00 committed by GitHub
parent 63fa534612
commit 12590fdcce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 126 additions and 61 deletions

View File

@ -1,4 +1,4 @@
flash_att_v2_commit := 601b4dc48dbe9d87c468daa2b4c0c8388b83753c
flash_att_v2_commit := 02ac572f3ffc4f402e4183aaa6824b45859d3ed3
flash-attention-v2:
# Clone flash attention

View File

@ -1,8 +1,8 @@
vllm_commit := 25dbff97d5a8f2ba331847237b458b2692e9ae78
vllm_commit := f8a1e39fae05ca610be8d5a78be9d40f5274e5fc
vllm:
# Clone vllm
git clone https://github.com/OlivierDehaene/vllm.git
git clone https://github.com/vllm-project/vllm.git
build-vllm: vllm
cd vllm && git fetch && git checkout $(vllm_commit)

View File

@ -29,11 +29,7 @@ from typing import Optional, List, Tuple
# Flash attention imports
import dropout_layer_norm
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
@ -269,7 +265,7 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
vllm_cache_ops.reshape_and_cache(
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
@ -279,7 +275,7 @@ class FlashLlamaAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -290,9 +286,7 @@ class FlashLlamaAttention(torch.nn.Module):
)
# Decode
else:
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
paged_attention.attention(
attn_output,
query,
kv_cache[0],
@ -301,7 +295,6 @@ class FlashLlamaAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
)

View File

@ -29,10 +29,7 @@ from typing import Optional, List, Tuple
# Flash attention imports
import dropout_layer_norm
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
@ -272,7 +269,7 @@ class MistralAttention(torch.nn.Module):
else:
kv_to_cache = kv
vllm_cache_ops.reshape_and_cache(
paged_attention.reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
)
@ -282,7 +279,7 @@ class MistralAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -294,9 +291,7 @@ class MistralAttention(torch.nn.Module):
)
# Decode
else:
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
paged_attention.attention(
attn_output,
query,
kv_cache[0],
@ -305,7 +300,6 @@ class MistralAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
)

View File

@ -27,10 +27,7 @@ from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt_neox import GPTNeoXConfig
from typing import Optional, List, Tuple
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
@ -141,7 +138,7 @@ class FlashNeoxAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
vllm_cache_ops.reshape_and_cache(
paged_attention.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
)
@ -151,7 +148,7 @@ class FlashNeoxAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
flash_attn.attention(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
@ -162,9 +159,7 @@ class FlashNeoxAttention(torch.nn.Module):
)
# Decode
else:
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
paged_attention.attention(
attn_output,
qkv[:, 0],
kv_cache[0],
@ -173,7 +168,6 @@ class FlashNeoxAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
)

View File

@ -6,10 +6,7 @@ from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
@ -191,7 +188,7 @@ class FlashRWAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
vllm_cache_ops.reshape_and_cache(
paged_attention.reshape_and_cache(
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
@ -201,7 +198,7 @@ class FlashRWAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
flash_attn.attention(
query,
torch.select(kv, dim=1, index=0),
torch.select(kv, dim=1, index=1),
@ -212,9 +209,7 @@ class FlashRWAttention(torch.nn.Module):
)
# Decode
else:
# kv_cache[1] => [num_blocks, num_heads_kv, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
paged_attention.attention(
attn_output,
query,
kv_cache[0],
@ -223,7 +218,6 @@ class FlashRWAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
)
@ -310,7 +304,7 @@ class FlashRWLargeAttention(torch.nn.Module):
self.rotary_emb(query, cos, sin)
self.rotary_emb(torch.select(kv, dim=2, index=0), cos, sin)
vllm_cache_ops.reshape_and_cache(
paged_attention.reshape_and_cache(
kv[:, :, 0].contiguous(),
kv[:, :, 1].contiguous(),
kv_cache[0],
@ -324,7 +318,7 @@ class FlashRWLargeAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
flash_attn.attention(
query,
torch.select(kv, dim=2, index=0),
torch.select(kv, dim=2, index=1),
@ -335,9 +329,7 @@ class FlashRWLargeAttention(torch.nn.Module):
)
# Decode
else:
# kv_cache[1] => [num_blocks, num_groups, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
paged_attention.attention(
attn_output,
query,
kv_cache[0],
@ -346,7 +338,6 @@ class FlashRWLargeAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
)

View File

@ -5,10 +5,7 @@ from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
# vllm imports
import vllm_cache_ops
import vllm_attention_ops
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.flash_attn import attention
from text_generation_server.utils.layers import (
TensorParallelRowLinear,
@ -18,7 +15,6 @@ from text_generation_server.utils.layers import (
FastLayerNorm,
get_linear,
)
from safetensors import SafetensorError
def load_multi_mqa(
@ -258,7 +254,7 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size)
vllm_cache_ops.reshape_and_cache(
paged_attention.reshape_and_cache(
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
)
@ -268,7 +264,7 @@ class FlashMQAttention(torch.nn.Module):
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
flash_attn.attention(
query,
torch.select(key_value, dim=1, index=0),
torch.select(key_value, dim=1, index=1),
@ -279,9 +275,7 @@ class FlashMQAttention(torch.nn.Module):
)
# Decode
else:
# kv_cache[1] => [num_blocks, 1, head_size, block_size]
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
paged_attention.attention(
attn_output,
query,
kv_cache[0],
@ -290,7 +284,6 @@ class FlashMQAttention(torch.nn.Module):
self.softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
)

View File

@ -0,0 +1,100 @@
import torch
# vllm imports
from vllm import cache_ops
from vllm import attention_ops
_PARTITION_SIZE = 512
def reshape_and_cache(key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor,
slots: torch.Tensor):
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
def attention(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (
(max_s + _PARTITION_SIZE - 1) //
_PARTITION_SIZE)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
if use_v1:
attention_ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
attention_ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
)