136 lines
4.2 KiB
Python
136 lines
4.2 KiB
Python
import math
|
|
import torch
|
|
|
|
from typing import Optional, List, Tuple
|
|
|
|
# Number of cache slots owned by each paged KV-cache block.
BLOCK_SIZE: int = 16

# Module-wide CacheManager instance. Starts out as None; it will be set in
# warmup through set_cache_manager() and read back via get_cache_manager().
CACHE_MANAGER: Optional["CacheManager"] = None
|
|
|
|
|
|
class CacheManager:
    """Paged key/value cache storage plus a free-list of fixed-size blocks.

    Block ``b`` owns the contiguous slot ids
    ``[b * block_size, (b + 1) * block_size)``. Block availability is tracked
    in a CPU ``int32`` mask where 1 means free and 0 means allocated.
    """

    def __init__(
        self,
        num_blocks: int,
        num_layers: int,
        num_heads: int,
        head_size: int,
        repeat_slots: bool,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Allocate the per-layer cache tensors and the block bookkeeping.

        Args:
            num_blocks: Number of cache blocks to allocate.
            num_layers: Number of (key, value) tensor pairs to create.
            num_heads: Size of the heads dimension of each cache tensor.
            head_size: Size of each head.
            repeat_slots: When True, ``allocate`` cycles through a sequence's
                slots if it needs more slots than its blocks provide
                (context sliding window).
            dtype: Element dtype of the cache tensors.
            device: Device on which the cache tensors are allocated.
        """
        self.block_size = BLOCK_SIZE
        self.num_blocks = num_blocks
        self.repeat_slots = repeat_slots

        # The key cache packs ``x`` elements into its innermost dimension.
        # NOTE(review): ``x`` is derived from the block size here; kernels using
        # this layout commonly expect ``16 // element_size`` (16 bytes). The two
        # only agree while BLOCK_SIZE == 16 -- confirm before changing BLOCK_SIZE.
        element_size = torch.tensor([], dtype=dtype).element_size()
        x = self.block_size // element_size

        # One (key_cache, value_cache) pair per layer; contents are uninitialized.
        #   key:   (num_blocks, num_heads, head_size // x, block_size, x)
        #   value: (num_blocks, num_heads, head_size, block_size)
        self.kv_cache = [
            (
                torch.empty(
                    (num_blocks, num_heads, head_size // x, self.block_size, x),
                    dtype=dtype,
                    device=device,
                ),
                torch.empty(
                    (num_blocks, num_heads, head_size, self.block_size),
                    dtype=dtype,
                    device=device,
                ),
            )
            for _ in range(num_layers)
        ]
        # 1 = free block, 0 = allocated block.
        self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
        # slots[b] holds the flat slot ids owned by block b.
        self.slots = torch.arange(
            0, num_blocks * self.block_size, dtype=torch.int32
        ).view(num_blocks, self.block_size)

    def allocate(
        self,
        needed_blocks_slots: List[Tuple[int, int]],
        blocks: int,
        max_blocks: int,
        device: torch.device,
    ):
        """Reserve cache blocks and slots for a batch of sequences.

        Args:
            needed_blocks_slots: One ``(needed_blocks, needed_slots)`` pair per
                sequence, in batch order.
            blocks: Total number of blocks to reserve. Blocks are handed out to
                sequences in order, so this is expected to equal the sum of
                ``needed_blocks``.
            max_blocks: Width of the padded block table.
            device: Device for the returned tensors.

        Returns:
            ``(block_tables, block_tables_tensor, slots)``:
            ``block_tables`` is a list of per-sequence block-id lists;
            ``block_tables_tensor`` is an ``int32`` tensor of shape
            ``(len(needed_blocks_slots), max_blocks)``, zero-padded on the right;
            ``slots`` is the concatenation of every sequence's slot ids (an
            empty ``int32`` tensor for an empty batch).

        Raises:
            AssertionError: If fewer than ``blocks`` blocks are free. It is raised
                explicitly (not via ``assert``) so the check survives ``python -O``.
        """
        # Indices of free blocks (mask entries that are not 0), as a flat tensor.
        free_block_indices = self.free_block_mask.nonzero().flatten()
        if len(free_block_indices) < blocks:
            raise AssertionError(
                f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
            )

        # Take the first ``blocks`` free blocks for this batch.
        block_indices = free_block_indices[:blocks]

        # Padded block tables
        block_tables_tensor = torch.zeros(
            (len(needed_blocks_slots), max_blocks), dtype=torch.int32
        )

        cumulative_blocks = 0
        slots = []
        block_tables = []
        for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
            # This sequence's contiguous share of the reserved blocks.
            allocated_blocks = block_indices[
                cumulative_blocks : cumulative_blocks + needed_blocks
            ]
            all_slots = self.slots[allocated_blocks].flatten()

            # Context sliding window: reuse the sequence's slots cyclically when
            # it needs more slots than its blocks provide.
            if needed_slots > len(all_slots) and self.repeat_slots:
                repeats = math.ceil(needed_slots / len(all_slots))
                all_slots = all_slots.repeat(repeats)

            slots.append(all_slots[:needed_slots])
            block_tables.append(allocated_blocks.tolist())
            block_tables_tensor[i, :needed_blocks] = allocated_blocks
            cumulative_blocks += needed_blocks

        block_tables_tensor = block_tables_tensor.to(device)
        if slots:
            slots = torch.concat(slots).to(device)
        else:
            # torch.concat rejects an empty sequence; an empty batch gets an
            # empty slot tensor of the usual dtype instead.
            slots = torch.empty(0, dtype=self.slots.dtype, device=device)

        # Mark the reserved blocks as in use only after everything succeeded.
        self.free_block_mask[block_indices] = 0

        return block_tables, block_tables_tensor, slots

    def free(self, block_indices: Optional[List[int]]):
        """Return blocks to the free pool.

        Args:
            block_indices: Block ids to release. ``None`` or an empty list is a
                no-op.
        """
        if block_indices:
            # Reset mask
            self.free_block_mask[block_indices] = 1
|
|
|
|
|
|
def set_cache_manager(
    num_blocks: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    repeat_slots: bool,
    dtype: torch.dtype,
    device: torch.device,
) -> CacheManager:
    """Create the module-wide CacheManager, replacing any existing one.

    All arguments are forwarded unchanged to ``CacheManager``.

    Returns:
        The newly created manager, which is also stored in ``CACHE_MANAGER``.
    """
    global CACHE_MANAGER

    if CACHE_MANAGER is not None:
        # Release this module's reference to the old manager before allocating
        # the new one, so its cache tensors can be reclaimed. Rebinding to None
        # (instead of ``del``) keeps the global defined: if construction below
        # fails, get_cache_manager() raises its RuntimeError, not NameError.
        CACHE_MANAGER = None
        torch.cuda.empty_cache()

    CACHE_MANAGER = CacheManager(
        num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
    )
    return CACHE_MANAGER
|
|
|
|
|
|
def get_cache_manager() -> CacheManager:
    """Return the module-wide CacheManager.

    Raises:
        RuntimeError: If ``CACHE_MANAGER`` is still ``None`` (i.e.
            ``set_cache_manager`` has not produced a manager yet).
    """
    # Read-only access to the module global, so no ``global`` declaration needed.
    manager = CACHE_MANAGER
    if manager is None:
        raise RuntimeError("cache manager was not initialized")
    return manager
|