hf_text-generation-inference/server/text_generation_server/models/metadata_kernels.py

348 lines
8.9 KiB
Python

import torch
import triton
import triton.language as tl
from loguru import logger
from typing import List, Optional
from torch.utils._triton import has_triton as has_triton_torch
from text_generation_server.utils.import_utils import (
SYSTEM,
)
from text_generation_server.utils.log import log_master
_HAS_TRITON: Optional[bool] = None
def has_triton():
global _HAS_TRITON
if _HAS_TRITON is None:
# FIXME: it seems that has_triton_torch is bugged on RocM
# For now, only accept cuda
_HAS_TRITON = has_triton_torch() if SYSTEM == "cuda" else False
if _HAS_TRITON:
log_master(logger.info, "Using optimized Triton indexing kernels.")
return _HAS_TRITON
def block_tables_to_padded(
max_blocks: int,
cu_seqlen: torch.Tensor,
block_tables: torch.Tensor,
block_tables_ragged: torch.Tensor,
):
def grid(meta):
return (
triton.cdiv(max_blocks, meta["BLOCK_SIZE"]),
len(block_tables),
)
triton_block_tables_to_padded[grid](
cu_seqlen,
block_tables,
block_tables_ragged,
block_tables.shape[1],
BLOCK_SIZE=256,
)
def block_tables_to_ragged(
*,
block_tables: torch.Tensor,
input_lengths: List[int],
cache_lengths: List[int],
input_lengths_tensor: torch.Tensor,
cache_lengths_tensor: torch.Tensor,
max_current_length: int,
) -> torch.Tensor:
"""Convert block table to ragged format compatible with FlashInfer."""
assert len(input_lengths) == len(cache_lengths)
total_len = sum(input_lengths) + sum(cache_lengths)
block_tables_ragged = torch.empty(
total_len, dtype=torch.int32, device=block_tables.device
)
if has_triton():
cu_seqlen = torch.nn.functional.pad(
torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0)
)
def grid(meta):
return (
triton.cdiv(max_current_length, meta["BLOCK_SIZE"]),
len(cache_lengths),
)
triton_block_tables_to_ragged[grid](
cu_seqlen,
block_tables,
block_tables_ragged,
block_tables.shape[1],
BLOCK_SIZE=256,
)
else:
offset = 0
for i, (input_length, cache_length) in enumerate(
zip(input_lengths, cache_lengths)
):
seq_len = cache_length + input_length
block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
offset += seq_len
return block_tables_ragged
def copy_next_input_ids_inplace(
max_next_input_ids: int,
all_input_ids: torch.Tensor,
cache_lengths: torch.Tensor,
input_lengths: torch.Tensor,
prompt_lengths: torch.Tensor,
next_input_ids: torch.Tensor,
cu_accepted_ids: torch.Tensor,
):
def grid(meta):
return (
triton.cdiv(max_next_input_ids, meta["BLOCK_SIZE"]),
len(all_input_ids),
)
triton_copy_next_input_ids_inplace[grid](
all_input_ids,
cache_lengths,
input_lengths,
prompt_lengths,
next_input_ids,
cu_accepted_ids,
all_input_ids.shape[1],
BLOCK_SIZE=16,
)
def prepare_position_slot_ids(
max_input_length: int,
cache_lengths: torch.Tensor,
cu_seqlen: torch.Tensor,
cu_slots: torch.Tensor,
position_ids: torch.Tensor,
slot_indices: torch.Tensor,
):
def grid(meta):
return (
triton.cdiv(max_input_length, meta["BLOCK_SIZE"]),
len(cache_lengths),
)
triton_prepare_position_slot_ids[grid](
cache_lengths, cu_seqlen, cu_slots, position_ids, slot_indices, BLOCK_SIZE=256
)
def slots_filtering(
max_slots: int,
slots: torch.Tensor,
filtered_slots: torch.Tensor,
cu_slots: torch.Tensor,
slots_start: torch.Tensor,
):
def grid(meta):
return (
triton.cdiv(max_slots, meta["BLOCK_SIZE"]),
len(slots_start),
)
triton_slots_filtering[grid](
slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256
)
@triton.jit
def triton_slots_filtering(
# Inputs
slots_ptr,
filtered_slots_ptr,
slots_start_ptr,
cu_slots_ptr,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
# Position in block_tables_ragged.numel() / BLOCK_SIZE
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
filter_start = tl.load(slots_start_ptr + bid)
slot_start = tl.load(cu_slots_ptr + bid)
slot_end = tl.load(cu_slots_ptr + bid + 1)
mask = (slot_start + block_arange) < slot_end
slots = tl.load(slots_ptr + filter_start + block_arange, mask=mask)
tl.store(filtered_slots_ptr + slot_start + block_arange, slots, mask=mask)
@triton.jit
def triton_block_tables_to_padded(
# Inputs
cu_seqlen_ptr,
# Outputs
block_tables_ptr,
block_tables_ragged_ptr,
# Stride
stride_block_tables,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
# Position in block_tables_ragged.numel() / BLOCK_SIZE
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
seq_start = tl.load(cu_seqlen_ptr + bid)
seq_end = tl.load(cu_seqlen_ptr + bid + 1)
mask = (seq_start + block_arange) < seq_end
blocks = tl.load(block_tables_ragged_ptr + seq_start + block_arange, mask=mask)
tl.store(
block_tables_ptr + bid * stride_block_tables + block_arange, blocks, mask=mask
)
@triton.jit
def triton_block_tables_to_ragged(
# Inputs
cu_seqlen_ptr,
# Outputs
block_tables_ptr,
block_tables_ragged_ptr,
# Stride
stride_block_tables,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
# Position in block_tables_ragged.numel() / BLOCK_SIZE
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
seq_start = tl.load(cu_seqlen_ptr + bid)
seq_end = tl.load(cu_seqlen_ptr + bid + 1)
mask = (seq_start + block_arange) < seq_end
blocks = tl.load(
block_tables_ptr + bid * stride_block_tables + block_arange, mask=mask
)
tl.store(block_tables_ragged_ptr + seq_start + block_arange, blocks, mask=mask)
@triton.jit
def triton_copy_next_input_ids_inplace(
# Inputs
all_input_ids_ptr,
cache_lengths_ptr,
input_lengths_ptr,
prompt_lengths_ptr,
next_input_ids_ptr,
cu_accepted_ids_ptr,
# Stride
stride_all_input_ids,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
# Position in max_accepted_ids / BLOCK_SIZE
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
# Used for correctly indexing in all_input_ids
cache_length = tl.load(cache_lengths_ptr + bid)
input_length = tl.load(input_lengths_ptr + bid)
prompt_length = tl.load(prompt_lengths_ptr + bid)
# Start/End of next_input_ids for this request
next_input_ids_start = tl.load(cu_accepted_ids_ptr + bid)
next_input_ids_end = tl.load(cu_accepted_ids_ptr + bid + 1)
# Mask values out of range
mask = (next_input_ids_start + block_arange) < next_input_ids_end
# Mask values for request still prefilling
decode_mask = (cache_length + input_length + block_arange) >= prompt_length
mask = mask & decode_mask
# Load this request next input ids
next_input_ids = tl.load(
next_input_ids_ptr + next_input_ids_start + block_arange, mask=mask
)
# Store in all_input_ids, since it is a 2D tensor, apply stride * bid
tl.store(
all_input_ids_ptr
+ stride_all_input_ids * bid
+ cache_length
+ input_length
+ block_arange,
next_input_ids,
mask=mask,
)
@triton.jit
def triton_prepare_position_slot_ids(
# Inputs
cache_lengths_ptr,
cu_seqlen_ptr,
cu_slots_ptr,
# Outputs
position_ids_ptr,
slot_indices_ptr,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
# Position in max_input_length / BLOCK_SIZE
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
cache_length = tl.load(cache_lengths_ptr + bid)
seq_start = tl.load(cu_seqlen_ptr + bid)
seq_end = tl.load(cu_seqlen_ptr + bid + 1)
slot_start = tl.load(cu_slots_ptr + bid)
mask = (seq_start + block_arange) < seq_end
tl.store(
position_ids_ptr + seq_start + block_arange,
cache_length + block_arange,
mask=mask,
)
tl.store(
slot_indices_ptr + seq_start + block_arange,
slot_start + cache_length + block_arange,
mask=mask,
)