Break cycle between the attention implementations and KV cache (#2627)
Parent: 5f32dea1e2
Commit: 8ec57558cd
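Judging by their context lines, the hunks below touch the attention package's __init__, the CUDA, ROCm and IPEX backend modules, and the shared KV-cache module. Each backend previously defined its own reshape_and_cache, and the KV cache imported that helper back from the attention package; that is the cycle the title refers to. The change drops the per-backend helpers and gives the KV-cache module a single paged_reshape_and_cache that defers every backend import to call time. A minimal, self-contained sketch of that pattern follows (not the repository's code; the function name, the system argument and the "cpu" fallback are illustrative additions so the sketch runs anywhere):

import torch


def store_paged(key, value, key_cache, value_cache, slots, system):
    """Write new key/value vectors into a paged cache for the given system."""
    if system == "cuda":
        # Imported inside the branch: the cache module has no backend dependency
        # at import time, so there is no cycle with the attention backends.
        from vllm._C import cache_ops  # assumes vLLM is installed on CUDA systems

        cache_ops.reshape_and_cache(
            key, value, key_cache, value_cache, slots, "auto", 1.0
        )
    elif system == "cpu":
        # Plain tensor scatter; included only so the sketch is runnable anywhere.
        shape = key_cache.shape
        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
    else:
        raise NotImplementedError(f"system '{system}' not supported in this sketch")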
@@ -11,21 +11,18 @@ if SYSTEM == "cuda":
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
-        reshape_and_cache,
     )
 elif SYSTEM == "rocm":
     from .rocm import (
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
-        reshape_and_cache,
     )
 elif SYSTEM == "ipex":
     from .ipex import (
         SUPPORTS_WINDOWING,
         attention,
         paged_attention,
-        reshape_and_cache,
     )
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
@@ -36,7 +33,6 @@ from .kv_cache import KVCache
 __all__ = [
     "attention",
     "paged_attention",
-    "reshape_and_cache",
     "SUPPORTS_WINDOWING",
     "KVCache",
     "Seqlen",
@@ -12,30 +12,6 @@ major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
 _PARTITION_SIZE = 512
 
-try:
-    from vllm._C import cache_ops
-except Exception as e:
-    raise ImportError(
-        f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
-    )
-
-
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    if ATTENTION in {"flashdecoding", "flashinfer"}:
-        shape = key_cache.shape
-        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
-    else:
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
-
 
 def paged_attention(
     query: torch.Tensor,
@@ -346,5 +322,4 @@ __all__ = [
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",
-    "reshape_and_cache",
 ]
@@ -47,18 +47,6 @@ def attention(
     return out
 
 
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    ipex.llm.modules.PagedAttention.reshape_and_cache(
-        key, value, key_cache, value_cache, slots
-    )
-
-
 def paged_attention(
     query: torch.Tensor,
     kv_cache: KVCache,
@@ -94,5 +82,4 @@ __all__ = [
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",
-    "reshape_and_cache",
 ]
@@ -115,6 +115,41 @@ class KVCache:
             key_cache.view(-1, shape[-2], shape[-1])[slots] = key
             value_cache.view(-1, shape[-2], shape[-1])[slots] = value
         else:
-            from text_generation_server.layers.attention import reshape_and_cache
-
-            reshape_and_cache(key, value, key_cache, value_cache, slots)
+            paged_reshape_and_cache(key, value, key_cache, value_cache, slots)
+
+
+def paged_reshape_and_cache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    slots: torch.Tensor,
+):
+    if SYSTEM == "cuda":
+        try:
+            from vllm._C import cache_ops
+        except Exception as e:
+            raise ImportError(
+                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
+            )
+        cache_ops.reshape_and_cache(
+            key, value, key_cache, value_cache, slots, "auto", 1.0
+        )
+    elif SYSTEM == "rocm":
+        try:
+            import vllm._custom_ops as ops
+        except Exception as e:
+            raise ImportError(
+                f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
+            )
+        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
+    elif SYSTEM == "ipex":
+        import intel_extension_for_pytorch as ipex
+
+        ipex.llm.modules.PagedAttention.reshape_and_cache(
+            key, value, key_cache, value_cache, slots
+        )
+    else:
+        raise NotImplementedError(
+            f"Cannot reshape and cache for paged attention, system '{SYSTEM}' not supported"
+        )
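For the flashdecoding/flashinfer layout, the store shown in the KVCache hunk above (and previously duplicated in each backend) is plain tensor indexing: the block and block-position dimensions are flattened into a single slot dimension and the new vectors are scattered by slot index. A small CPU demonstration of exactly that indexing (the cache shape here is an illustrative assumption, not the layout used in the repository):

import torch

# Illustrative shapes: 4 blocks of 16 slots, 2 KV heads, head size 8.
num_blocks, block_size, num_heads, head_size = 4, 16, 2, 8
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)
value_cache = torch.zeros_like(key_cache)

# Three new tokens written at flat slot indices 0, 1 and 17
# (slot 17 = block 1, position 1).
key = torch.randn(3, num_heads, head_size)
value = torch.randn(3, num_heads, head_size)
slots = torch.tensor([0, 1, 17])

# Same indexing as the flashdecoding/flashinfer branch above:
# flatten (num_blocks, block_size) into one slot dimension, then scatter by slot.
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value

assert torch.equal(key_cache[1, 1], key[2])  # block 1, position 1 holds token 2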
@@ -3,7 +3,6 @@ from typing import Optional
 import torch
 from text_generation_server.layers.attention.kv_cache import KVCache
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import ATTENTION
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils.log import log_master
 from loguru import logger
@@ -28,28 +27,6 @@ except ImportError as e:
     )
     use_rocm_custom_paged_attn = False
 
-try:
-    import vllm._custom_ops as ops
-except Exception as e:
-    raise ImportError(
-        f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
-    )
-
-
-def reshape_and_cache(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slots: torch.Tensor,
-):
-    if ATTENTION == "flashdecoding":
-        shape = key_cache.shape
-        key_cache.view(-1, shape[-2], shape[-1])[slots] = key
-        value_cache.view(-1, shape[-2], shape[-1])[slots] = value
-    else:
-        ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
-
 
 def paged_attention(
     query: torch.Tensor,
@@ -305,5 +282,4 @@ __all__ = [
     "SUPPORTS_WINDOWING",
     "attention",
     "paged_attention",
-    "reshape_and_cache",
 ]
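Note for downstream code: anything that previously did from text_generation_server.layers.attention import reshape_and_cache will no longer find that name, since the helper is removed from every backend and from each module's __all__. Judging by the KVCache hunk, cache writes are now expected to go through the KVCache object, with paged_reshape_and_cache in the kv_cache module serving as the low-level, per-system helper.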