Break cycle between the attention implementations and KV cache (#2627)

Daniël de Kok 2024-10-17 14:54:22 +02:00 committed by GitHub
parent 5f32dea1e2
commit 8ec57558cd
5 changed files with 37 additions and 68 deletions
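
For context, a minimal sketch of the cache store path after this change, paraphrased from the kv_cache.py hunk below: the method name store, the keyword-only signature, and the class scaffolding are assumptions, only the branch bodies are taken from the diff. The point is that kv_cache.py no longer imports reshape_and_cache back from text_generation_server.layers.attention (which pulls in the cuda/rocm/ipex backends and closes the import cycle); it dispatches on SYSTEM itself via the new paged_reshape_and_cache helper defined in the same module.

    from text_generation_server.models.globals import ATTENTION
    from text_generation_server.utils.import_utils import SYSTEM


    class KVCache:
        # ... construction of self.kv_cache (key_cache, value_cache) elided ...

        def store(self, *, key, value, slots):  # method name assumed, not shown in this diff
            key_cache, value_cache = self.kv_cache  # attribute name assumed
            if ATTENTION in {"flashdecoding", "flashinfer"}:
                # Flat layouts: scatter each token's K/V directly into its slot.
                shape = key_cache.shape
                key_cache.view(-1, shape[-2], shape[-1])[slots] = key
                value_cache.view(-1, shape[-2], shape[-1])[slots] = value
            else:
                # Previously this branch imported reshape_and_cache from
                # text_generation_server.layers.attention, re-entering the package
                # that imports KVCache. Now the dispatch lives in kv_cache.py itself
                # (paged_reshape_and_cache is the module-level helper added below).
                paged_reshape_and_cache(key, value, key_cache, value_cache, slots)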

server/text_generation_server/layers/attention/__init__.py

@@ -11,21 +11,18 @@ if SYSTEM == "cuda":
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
-        reshape_and_cache,
     )
 elif SYSTEM == "rocm":
     from .rocm import (
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
-        reshape_and_cache,
     )
 elif SYSTEM == "ipex":
     from .ipex import (
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
-        reshape_and_cache,
     )
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
@@ -36,7 +33,6 @@ from .kv_cache import KVCache
 __all__ = [
     "attention",
     "paged_attention",
-    "reshape_and_cache",
     "SUPPORTS_WINDOWING",
     "KVCache",
     "Seqlen",

server/text_generation_server/layers/attention/cuda.py

@@ -12,30 +12,6 @@ major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
 _PARTITION_SIZE = 512
 
-try:
-    from vllm._C import cache_ops
-except Exception as e:
-    raise ImportError(
-        f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
-    )
-
-
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    if ATTENTION in {"flashdecoding", "flashinfer"}:
-        shape = key_cache.shape
-        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
-    else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
-
 
 def paged_attention(
     query: torch.Tensor,
@@ -346,5 +322,4 @@ __all__ = [
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",
-    "reshape_and_cache",
 ]

server/text_generation_server/layers/attention/ipex.py

@@ -47,18 +47,6 @@ def attention(
     return out
 
 
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    ipex.llm.modules.PagedAttention.reshape_and_cache(
-        key, value, key_cache, value_cache, slots
-    )
-
-
 def paged_attention(
     query: torch.Tensor,
     kv_cache: KVCache,
@@ -94,5 +82,4 @@ __all__ = [
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",
-    "reshape_and_cache",
 ]

server/text_generation_server/layers/attention/kv_cache.py

@@ -115,6 +115,41 @@ class KVCache:
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
         else:
-            from text_generation_server.layers.attention import reshape_and_cache
-            reshape_and_cache(key, value, key_cache, value_cache, slots)
+            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
+
+
+def paged_reshape_and_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slots: torch.Tensor,
+):
+    if SYSTEM == "cuda":
+        try:
+            from vllm._C import cache_ops
+        except Exception as e:
+            raise ImportError(
+                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
+            )
+        cache_ops.reshape_and_cache(
+            key, value, key_cache, value_cache, slots, "auto", 1.0
+        )
+    elif SYSTEM == "rocm":
+        try:
+            import vllm._custom_ops as ops
+        except Exception as e:
+            raise ImportError(
+                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
+            )
+        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+    elif SYSTEM == "ipex":
+        import intel_extension_for_pytorch as ipex
+        ipex.llm.modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache, slots
+        )
+    else:
+        raise NotImplementedError(
+            f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
+        )
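
A hedged usage sketch of the new helper on a CUDA system, standalone and not part of the commit: the cache shapes follow the vLLM "auto" paged-KV layout that cache_ops.reshape_and_cache expects, and all sizes, dtypes, and slot values are illustrative assumptions; in TGI the caches are allocated by KVCache and this helper is only reached from KVCache, not called directly.

    import torch

    from text_generation_server.layers.attention.kv_cache import paged_reshape_and_cache

    num_tokens, num_heads, head_size = 4, 8, 128
    num_blocks, block_size = 16, 16
    dtype, device = torch.float16, "cuda"

    # New K/V for the current tokens: [num_tokens, num_heads, head_size].
    key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
    value = torch.randn_like(key)

    # Assumed vLLM "auto" layout: keys are packed in groups of x elements per head dim.
    x = 16 // key.element_size()
    key_cache = torch.zeros(
        num_blocks, num_heads, head_size // x, block_size, x, dtype=dtype, device=device
    )
    value_cache = torch.zeros(
        num_blocks, num_heads, head_size, block_size, dtype=dtype, device=device
    )

    # One flat slot per token: slot = block_id * block_size + offset_in_block.
    slots = torch.tensor([0, 1, 16, 17], dtype=torch.int64, device=device)

    paged_reshape_and_cache(key, value, key_cache, value_cache, slots)

On ROCm the same call goes through vllm._custom_ops and on ipex through ipex.llm.modules.PagedAttention; the flashdecoding/flashinfer layouts never reach this helper, since KVCache handles them with a plain slot scatter.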

server/text_generation_server/layers/attention/rocm.py

@@ -3,7 +3,6 @@ from typing import Optional
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import ATTENTION
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils.log import log_master
 from loguru import logger
@@ -28,28 +27,6 @@ except ImportError as e:
     )
     use_rocm_custom_paged_attn = False
 
-try:
-    import vllm._custom_ops as ops
-except Exception as e:
-    raise ImportError(
-        f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
-    )
-
-
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    if ATTENTION == "flashdecoding":
-        shape = key_cache.shape
-        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
-    else:
-        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
-
 
 def paged_attention(
     query: torch.Tensor,
@@ -305,5 +282,4 @@ __all__ = [
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",
-    "reshape_and_cache",
 ]