Break cycle between the attention implementations and KV cache (#2627)
This commit is contained in: parent 5f32dea1e2, commit 8ec57558cd
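The cycle being broken: `layers/attention/__init__.py` imports the backend modules (`cuda.py`, `rocm.py`, `ipex.py`) and `KVCache` from `kv_cache.py`, the backends use `KVCache` in their `paged_attention` signatures, and `kv_cache.py` in turn imported `reshape_and_cache` back from the attention package when writing into a paged cache. The diff below removes that back-edge by giving `kv_cache.py` its own `paged_reshape_and_cache` that dispatches on `SYSTEM` directly. A rough sketch of the resulting call path follows; the `_sketch` name and the free-standing `system` argument are illustrative only, while the kernel calls are the same ones visible in the diff:

# Sketch only: condensed from the new paged_reshape_and_cache in kv_cache.py.
# The vendor imports stay inside the branches, so nothing here needs
# layers/attention/__init__.py at import time.
import torch


def paged_reshape_and_cache_sketch(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
    system: str = "cuda",  # stand-in for text_generation_server.utils.import_utils.SYSTEM
) -> None:
    if system == "cuda":
        from vllm._C import cache_ops  # same kernel the diff calls

        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
    elif system == "rocm":
        import vllm._custom_ops as ops

        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
    elif system == "ipex":
        import intel_extension_for_pytorch as ipex

        ipex.llm.modules.PagedAttention.reshape_and_cache(key, value, key_cache, value_cache, slots)
    else:
        raise NotImplementedError(f"system '{system}' not supported")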
text_generation_server/layers/attention/__init__.py

@@ -11,21 +11,18 @@ if SYSTEM == "cuda":
        SUPPORTS_WINDOWING,
        attention,
        paged_attention,
        reshape_and_cache,
    )
elif SYSTEM == "rocm":
    from .rocm import (
        SUPPORTS_WINDOWING,
        attention,
        paged_attention,
        reshape_and_cache,
    )
elif SYSTEM == "ipex":
    from .ipex import (
        SUPPORTS_WINDOWING,
        attention,
        paged_attention,
        reshape_and_cache,
    )
else:
    raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

@@ -36,7 +33,6 @@ from .kv_cache import KVCache
__all__ = [
    "attention",
    "paged_attention",
    "reshape_and_cache",
    "SUPPORTS_WINDOWING",
    "KVCache",
    "Seqlen",
text_generation_server/layers/attention/cuda.py

@@ -12,30 +12,6 @@ major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512

try:
    from vllm._C import cache_ops
except Exception as e:
    raise ImportError(
        f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
    )


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
    if ATTENTION in {"flashdecoding", "flashinfer"}:
        shape = key_cache.shape
        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
    else:
        cache_ops.reshape_and_cache(
            key, value, key_cache, value_cache, slots, "auto", 1.0
        )


def paged_attention(
    query: torch.Tensor,

@@ -346,5 +322,4 @@ __all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
    "reshape_and_cache",
]
text_generation_server/layers/attention/ipex.py

@@ -47,18 +47,6 @@ def attention(
    return out


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
    ipex.llm.modules.PagedAttention.reshape_and_cache(
        key, value, key_cache, value_cache, slots
    )


def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,

@@ -94,5 +82,4 @@ __all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
    "reshape_and_cache",
]
text_generation_server/layers/attention/kv_cache.py

@@ -115,6 +115,41 @@ class KVCache:
            key_cache.view(-1, shape[-2], shape[-1])[slots] = key
            value_cache.view(-1, shape[-2], shape[-1])[slots] = value
        else:
            from text_generation_server.layers.attention import reshape_and_cache
            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)

            reshape_and_cache(key, value, key_cache, value_cache, slots)


def paged_reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
    if SYSTEM == "cuda":
        try:
            from vllm._C import cache_ops
        except Exception as e:
            raise ImportError(
                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
            )
        cache_ops.reshape_and_cache(
            key, value, key_cache, value_cache, slots, "auto", 1.0
        )
    elif SYSTEM == "rocm":
        try:
            import vllm._custom_ops as ops
        except Exception as e:
            raise ImportError(
                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
            )
        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
    elif SYSTEM == "ipex":
        import intel_extension_for_pytorch as ipex

        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slots
        )
    else:
        raise NotImplementedError(
            f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
        )
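For the flashdecoding/flashinfer layout handled by the two `view(...)[slots] = ...` lines above, `slots` holds one flat index per incoming token and `view(-1, shape[-2], shape[-1])` folds every dimension before `(num_heads, head_size)` into a single slot axis. A torch-only toy with made-up shapes (8 slots total, 1 KV head, head size 8) to illustrate the scatter:

import torch

key_cache = torch.zeros(2, 4, 1, 8)  # last two dims: (num_heads, head_size)
value_cache = torch.zeros(2, 4, 1, 8)
key = torch.randn(3, 1, 8)           # (num_tokens, num_heads, head_size)
value = torch.randn(3, 1, 8)
slots = torch.tensor([0, 1, 5])      # flat slot index for each token

# Same statements as in the hunk above: flatten to (num_slots, num_heads,
# head_size) and scatter each token's K/V into its slot.
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value

assert torch.equal(key_cache.view(-1, 1, 8)[5], key[2])  # token 2 landed in slot 5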
text_generation_server/layers/attention/rocm.py

@@ -3,7 +3,6 @@ from typing import Optional
import torch
from text_generation_server.layers.attention.kv_cache import KVCache
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master
from loguru import logger

@@ -28,28 +27,6 @@ except ImportError as e:
    )
    use_rocm_custom_paged_attn = False

try:
    import vllm._custom_ops as ops
except Exception as e:
    raise ImportError(
        f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
    )


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
):
    if ATTENTION == "flashdecoding":
        shape = key_cache.shape
        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
    else:
        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)


def paged_attention(
    query: torch.Tensor,

@@ -305,5 +282,4 @@ __all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
    "reshape_and_cache",
]