feat: add triton kernels to decrease latency of large batches (#2687)

* feat: add triton kernels to decrease latency of large batches

* cast to int32

* fix kernel

* fix kernel

* disable triton on rocm

* fix speculation

* add slots filtering kernel
This commit is contained in:
OlivierDehaene 2024-10-25 23:10:00 +02:00 committed by GitHub
parent 0f346a3296
commit 6f88bd9390
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 649 additions and 194 deletions

View File

@ -71,6 +71,14 @@ from text_generation_server.utils.import_utils import (
synchronize,
get_free_memory,
)
from text_generation_server.models.metadata_kernels import (
has_triton,
copy_next_input_ids_inplace,
block_tables_to_ragged,
block_tables_to_padded,
prepare_position_slot_ids,
slots_filtering,
)
tracer = trace.get_tracer(__name__)
@ -147,8 +155,10 @@ class FlashCausalLMBatch(Batch):
# tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: torch.Tensor
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
slots: Optional[torch.Tensor]
slots: torch.Tensor
# list of length b + 1 containing the cumulative sequence slot lengths of the sequences in the batch
# used for filtering
cu_slots: torch.Tensor
max_input_length: int
max_current_length: int
@ -159,7 +169,7 @@ class FlashCausalLMBatch(Batch):
prefilling_mask: List[bool]
# Prefill metadata tensors to efficiently compute logprobs
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
# tensor of length b + 1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
cu_seqlen_prefill: Optional[torch.Tensor]
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
# as we only keep SLIDING_WINDOW values instead of the whole tensor
@ -257,6 +267,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids = []
all_postfix_ids = []
requests_idx_mapping = {}
slots = []
cu_slots = [0]
next_token_chooser_parameters = []
stopping_criterias = []
@ -268,7 +280,9 @@ class FlashCausalLMBatch(Batch):
max_length = 0
max_blocks = 0
cu_blocks = [0]
block_tables = []
block_tables_ragged = []
# Parse batch
for i, (r, tokenized_input) in enumerate(
@ -341,10 +355,21 @@ class FlashCausalLMBatch(Batch):
request_blocks = [
b for b in range(num_blocks, num_blocks + needed_blocks)
]
request_slots = [
s
for b in request_blocks
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
]
else:
request_blocks = r.blocks
request_slots = r.slots
block_tables.append(request_blocks)
block_tables_ragged.extend(request_blocks)
cu_blocks.append(len(block_tables_ragged))
slots.extend(request_slots)
cu_slots.append(len(slots))
cache_lengths.append(cache_length)
num_blocks += len(request_blocks)
@ -378,16 +403,34 @@ class FlashCausalLMBatch(Batch):
top_n_tokens, device=device, dtype=torch.int64
)
block_tables_tensor = torch.zeros(
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
block_tables_ragged = torch.tensor(
block_tables_ragged, device=device, dtype=torch.int32
)
for i, request_blocks in enumerate(block_tables):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
block_tables_tensor = block_tables_tensor.to(device)
cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64)
block_tables_tensor = torch.empty(
(len(block_tables), max_blocks),
device=device,
dtype=torch.int32,
)
# If the device supports Triton, we can use a fused kernel
if has_triton():
block_tables_to_padded(
max_blocks, cu_blocks, block_tables_tensor, block_tables_ragged
)
else:
for i, request_blocks in enumerate(block_tables):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(
request_blocks
)
prompt_lengths_tensor = torch.tensor(
prompt_lengths, dtype=torch.int32, device=device
)
slots = torch.tensor(slots, dtype=torch.int64, device=device)
cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
return cls(
batch_id=pb.id,
requests=pb.requests,
@ -420,7 +463,8 @@ class FlashCausalLMBatch(Batch):
cu_seqlen_prefill=None,
prefill_cache_indices=None,
slot_indices=None,
slots=None,
slots=slots,
cu_slots=cu_slots,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
@ -457,10 +501,11 @@ class FlashCausalLMBatch(Batch):
# Used to index into tensors
indices = []
# slots to keep after filtering
slot_filtering_indices = torch.zeros(
self.slots.shape[0], dtype=torch.bool, device=device
)
if not has_triton():
# slots to keep after filtering
slot_filtering_indices = torch.zeros(
self.slots.shape[0], dtype=torch.bool, device=device
)
# Create on CPU to only move to GPU once instead of at every copy
slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
@ -477,6 +522,7 @@ class FlashCausalLMBatch(Batch):
cache_lengths = []
prefix_offsets = []
read_offsets = []
cu_slots = [0]
prefilling_mask = []
prefill_logprob_tokens = []
@ -487,8 +533,8 @@ class FlashCausalLMBatch(Batch):
num_blocks = 0
max_blocks = 0
# Cumulative length
cumulative_max_length = 0
max_slots = 0
cumulative_slot_tokens = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
@ -531,29 +577,27 @@ class FlashCausalLMBatch(Batch):
num_blocks += len(request_block_table)
block_tables.append(request_block_table)
start_slot = self.cu_slots[idx]
end_slot = self.cu_slots[idx + 1]
slot_length = end_slot - start_slot
if not has_triton():
# Set slice
slot_filtering_indices[start_slot:end_slot] = True
cu_slots.append(cumulative_slot_tokens + slot_length)
# Input ids if the request was part of a prefilling batch
# If the batch was decoding we can index into the tensor directly later
if self.prefilling:
input_ids.append(self.input_ids[idx])
else:
# Copy to tensor (CPU)
slot_indices[i] = cumulative_max_length
remaining_tokens = (
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
# Set slice
slot_filtering_indices[
self.slot_indices[idx] : self.slot_indices[idx]
+ request_input_length
+ remaining_tokens
- 1
] = True
cumulative_max_length += request_input_length + remaining_tokens - 1
slot_indices[i] = cumulative_slot_tokens + request_cache_length
cumulative_slot_tokens += slot_length
max_blocks = max(max_blocks, len(request_block_table))
max_slots = max(max_slots, slot_length)
all_input_ids_tensor = self.all_input_ids_tensor[indices]
block_tables_tensor = self.block_tables_tensor[indices]
@ -564,11 +608,22 @@ class FlashCausalLMBatch(Batch):
)
prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
if not has_triton():
slots = self.slots[slot_filtering_indices]
else:
slots = self.slots.new_empty(cumulative_slot_tokens)
gpu_cu_slots = cu_slots.to(device)
slots_indexing_start = self.cu_slots.to(device)[indices]
slots_filtering(
max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start
)
if self.prefilling:
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids = None
slot_indices = None
slots = None
cache_lengths_tensor = None
input_lengths_tensor = None
adapter_meta = None
@ -578,7 +633,6 @@ class FlashCausalLMBatch(Batch):
position_ids = self.position_ids[indices]
adapter_indices = self.adapter_meta.adapter_indices[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
cache_lengths_tensor = self.cache_lengths_tensor[indices]
# Move to GPU now that we have the whole tensor
@ -607,6 +661,7 @@ class FlashCausalLMBatch(Batch):
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
slots=slots,
cu_slots=cu_slots,
max_input_length=max_input_length,
max_current_length=max_current_length,
prefilling=self.prefilling,
@ -653,9 +708,7 @@ class FlashCausalLMBatch(Batch):
for b in batches:
total_batch_size += len(b)
max_blocks = max(max_blocks, b.max_blocks)
# If `b` is prefilling and was just filtered, `b.slots` is None
# `total_slots` is not used if any of the batches is prefilling
total_slots += len(b.slots) if not b.prefilling else 0
total_slots += len(b.slots)
num_blocks += b.num_blocks
speculative_length = (
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
@ -675,11 +728,12 @@ class FlashCausalLMBatch(Batch):
)
prefilling = prefilling or b.prefilling
slots = batches[0].slots.new_empty(total_slots)
cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64)
if prefilling:
input_ids = []
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
position_ids = None
slots = None
slot_indices = None
cache_lengths_tensor = None
input_lengths_tensor = None
@ -688,7 +742,6 @@ class FlashCausalLMBatch(Batch):
else:
input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size)
slots = batches[0].slots.new_empty(total_slots)
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
total_batch_size
@ -764,13 +817,16 @@ class FlashCausalLMBatch(Batch):
] = batch.block_tables_tensor[:, :max_blocks]
prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
if not prefilling:
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
slots[slots_start_index:slots_end_index] = batch.slots
cu_slots[start_index + 1 : end_index + 1] = (
batch.cu_slots[1:] + cumulative_slots
)
if not prefilling:
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
slots[slots_start_index:slots_end_index] = batch.slots
slot_indices[start_index:end_index] = (
batch.slot_indices + cumulative_slots
)
@ -792,9 +848,6 @@ class FlashCausalLMBatch(Batch):
batch.adapter_meta.adapter_segments,
batch.adapter_meta.segment_indices,
)
# Update
cumulative_slots += len(batch.slots)
else:
if isinstance(batch.input_ids, torch.Tensor):
batch.input_ids = batch.input_ids.view(-1, 1).tolist()
@ -819,6 +872,7 @@ class FlashCausalLMBatch(Batch):
top_n_tokens.extend(batch.top_n_tokens)
# Update
cumulative_slots += len(batch.slots)
cumulative_batch_size += len(batch)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
@ -858,6 +912,7 @@ class FlashCausalLMBatch(Batch):
cache_lengths=cache_lengths,
cache_lengths_tensor=cache_lengths_tensor,
slots=slots,
cu_slots=cu_slots,
max_input_length=max_input_length,
max_current_length=max_current_length,
prefilling=prefilling,
@ -890,15 +945,50 @@ class FlashCausalLMBatch(Batch):
# it simplifies everything
assert self.speculative_ids is None
device = self.block_tables_tensor.device
if isinstance(self.input_ids, list):
if len(self) > 1:
input_ids = np.concatenate(self.input_ids, dtype=np.int64)
else:
input_ids = self.input_ids[0]
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
self.input_lengths_tensor = torch.tensor(
self.input_lengths, dtype=torch.int32, device=device
)
self.cu_seqlen_prefill = torch.nn.functional.pad(
torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)
).to(torch.int32)
self.cache_lengths_tensor = torch.tensor(
self.cache_lengths, dtype=torch.int32, device=device
)
# If the device supports Triton, we can use a fused kernel
if has_triton():
self.position_ids = torch.empty(
len(self.input_ids), dtype=torch.int32, device=device
)
self.slot_indices = torch.empty(
len(self.input_ids), dtype=torch.int64, device=device
)
cu_slots_gpu = self.cu_slots.to(device)
prepare_position_slot_ids(
self.max_input_length,
self.cache_lengths_tensor,
self.cu_seqlen_prefill,
cu_slots_gpu,
self.position_ids,
self.slot_indices,
)
sliding_window = get_sliding_windows()
position_ids = []
cu_seqlen_prefill = [0]
slot_indices = []
prefill_cache_indices = []
all_prefill_logprobs = True
no_prefill_logprobs = True
prefill_head_indices = []
prefill_next_token_indices = []
prefill_cu_outlens = [0]
# Cumulative length
@ -906,7 +996,6 @@ class FlashCausalLMBatch(Batch):
cumulative_slot_tokens = 0
prefill_out_cumulative_length = 0
slots = []
adapter_indices_list = []
adapter_set = set()
@ -928,30 +1017,33 @@ class FlashCausalLMBatch(Batch):
)
):
next_chunk_length = input_length
# Position ids
request_position_ids = torch.arange(
cache_length, cache_length + input_length, dtype=torch.int32
)
position_ids.append(request_position_ids)
# Add cumulative lengths of all previous inputs
cu_seqlen_prefill.append(cumulative_length + input_length)
if not has_triton():
# Position ids
request_position_ids = torch.arange(
cache_length, cache_length + input_length, dtype=torch.int32
)
position_ids.append(request_position_ids)
if not r.slots:
request_slots = [
s
for b in blocks
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
]
else:
request_slots = r.slots
if not r.slots:
request_slots = [
s
for b in blocks
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
]
else:
request_slots = r.slots
request_slots = request_slots[cache_length:]
request_slot_indices = torch.arange(
cumulative_slot_tokens,
cumulative_slot_tokens + input_length,
dtype=torch.int64,
)
request_slot_indices = torch.arange(
cache_length + cumulative_slot_tokens,
cache_length + cumulative_slot_tokens + input_length,
dtype=torch.int64,
)
slot_indices.append(request_slot_indices)
# Update
cumulative_slot_tokens += len(request_slots)
# Create tensor to slice into the kv tensor in prefill
if sliding_window is not None:
@ -968,83 +1060,102 @@ class FlashCausalLMBatch(Batch):
no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
if prefill_logprobs:
prefill_head_indices.append(
torch.arange(
cumulative_length,
cumulative_length + input_length,
dtype=torch.int64,
)
)
prefill_next_token_indices.append(
prefill_out_cumulative_length + input_length - 1
)
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
prefill_out_cumulative_length += input_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1],
dtype=torch.int64,
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
slots.extend(request_slots)
slot_indices.append(request_slot_indices)
if sliding_window is not None:
prefill_cache_indices.append(request_prefill_cache_indices)
ADAPTER_TO_INDEX = get_adapter_to_index()
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index))
adapter_set.add(adapter_index)
if ADAPTER_TO_INDEX:
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
adapter_indices_list.append(
torch.full((next_chunk_length,), adapter_index)
)
adapter_set.add(adapter_index)
# Update
cumulative_length += next_chunk_length
cumulative_slot_tokens += len(request_slots)
device = self.block_tables_tensor.device
if not all_prefill_logprobs and not no_prefill_logprobs:
prefill_head_indices = []
prefill_next_token_indices = []
if isinstance(self.input_ids, list):
if len(self) > 1:
input_ids = np.concatenate(self.input_ids, dtype=np.int64)
else:
input_ids = self.input_ids[0]
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
# Cumulative length
cumulative_length = 0
prefill_out_cumulative_length = 0
for i, (
r,
input_length,
request_prefilling,
) in enumerate(
zip(
self.requests,
self.input_lengths,
self.prefilling_mask,
)
):
# Prefill logprobs is ignored if the request is done prefilling
prefill_logprobs = r.prefill_logprobs and request_prefilling
if prefill_logprobs:
prefill_head_indices.append(
torch.arange(
cumulative_length,
cumulative_length + input_length,
dtype=torch.int64,
)
)
prefill_next_token_indices.append(
prefill_out_cumulative_length + input_length - 1
)
prefill_out_cumulative_length += input_length
else:
prefill_head_indices.append(
torch.tensor(
[cumulative_length + input_length - 1],
dtype=torch.int64,
)
)
prefill_next_token_indices.append(prefill_out_cumulative_length)
prefill_out_cumulative_length += 1
# Update
cumulative_length += input_length
if len(self) > 1:
position_ids = torch.cat(position_ids)
slot_indices = torch.cat(slot_indices)
if position_ids:
position_ids = torch.cat(position_ids)
if slot_indices:
slot_indices = torch.cat(slot_indices)
if sliding_window is not None:
prefill_cache_indices = torch.cat(prefill_cache_indices)
else:
position_ids = position_ids[0]
slot_indices = slot_indices[0]
if position_ids:
position_ids = position_ids[0]
if slot_indices:
slot_indices = slot_indices[0]
if sliding_window is not None:
prefill_cache_indices = prefill_cache_indices[0]
if not has_triton():
self.position_ids = position_ids.to(device)
self.slot_indices = slot_indices.to(device)
self.prefill_cu_outlens = prefill_cu_outlens
cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32
)
self.cu_seqlen_prefill = cu_seqlen_prefill
self.position_ids = position_ids.to(device)
self.slot_indices = slot_indices.to(device)
self.prefill_cache_indices = (
prefill_cache_indices.to(device) if sliding_window is not None else None
)
self.input_lengths_tensor = torch.tensor(
self.input_lengths, dtype=torch.int32, device=device
)
if all_prefill_logprobs:
prefill_head_indices = None
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1
elif no_prefill_logprobs:
prefill_head_indices = cu_seqlen_prefill[1:] - 1
prefill_head_indices = self.cu_seqlen_prefill[1:] - 1
prefill_next_token_indices = None
else:
prefill_head_indices = torch.cat(prefill_head_indices).to(device)
@ -1054,17 +1165,21 @@ class FlashCausalLMBatch(Batch):
self.prefill_head_indices = prefill_head_indices
self.prefill_next_token_indices = prefill_next_token_indices
self.slots = torch.tensor(slots, dtype=torch.int64, device=device)
self.cache_lengths_tensor = torch.tensor(
self.cache_lengths, dtype=torch.int32, device=device
)
adapter_indices = torch.cat(adapter_indices_list).to(
dtype=torch.int64, device=device
)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
if adapter_set:
adapter_indices = torch.cat(adapter_indices_list).to(
dtype=torch.int64, device=device
)
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
else:
adapter_indices = torch.zeros_like(self.input_ids)
adapter_segments = [0, len(adapter_indices)]
adapter_segment_indices = [len(adapter_indices) - 1]
adapter_segments = torch.tensor(
adapter_segments, dtype=torch.int32, device=device
)
self.adapter_meta = AdapterBatchMetadata(
adapter_indices=adapter_indices,
adapter_set=adapter_set,
@ -1288,6 +1403,9 @@ class FlashCausalLM(Model):
block_tables=block_tables,
input_lengths=input_lengths,
cache_lengths=cache_lengths,
input_lengths_tensor=input_lengths_tensor,
cache_lengths_tensor=cache_lengths_tensor,
max_current_length=max_s,
)
from text_generation_server.layers.attention.flashinfer import (
create_decode_state_cuda_graphs,
@ -1621,6 +1739,9 @@ class FlashCausalLM(Model):
block_tables=block_tables,
input_lengths=batch.input_lengths,
cache_lengths=batch.cache_lengths,
input_lengths_tensor=batch.input_lengths_tensor,
cache_lengths_tensor=batch.cache_lengths_tensor,
max_current_length=batch.max_current_length,
)
with self._forward_context(
block_tables=block_tables,
@ -1661,6 +1782,9 @@ class FlashCausalLM(Model):
block_tables=block_tables,
input_lengths=batch.input_lengths,
cache_lengths=batch.cache_lengths,
input_lengths_tensor=batch.input_lengths_tensor,
cache_lengths_tensor=batch.cache_lengths_tensor,
max_current_length=batch.max_current_length,
)
# assert block_tables.shape[0] >= slots.shape[0]
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
@ -1756,7 +1880,6 @@ class FlashCausalLM(Model):
else:
prefill_logprobs = None
next_token_logits = out
next_adapter_indices = batch.adapter_meta.adapter_indices
finished_prefilling = True
next_chunk_lengths = []
@ -1827,13 +1950,12 @@ class FlashCausalLM(Model):
# Since we are done prefilling, all the tensors that were concatenating values for all the requests
# instantly become of shape [BATCH_SIZE]
if prefill and finished_prefilling:
next_position_ids = batch.position_ids.new_empty(len(batch))
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
len(batch)
)
elif not prefill:
next_position_ids = batch.position_ids
indices = batch.cu_seqlen_prefill[1:] - 1
batch.position_ids = batch.position_ids[indices]
batch.slot_indices = batch.slot_indices[indices]
batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
indices
]
# Zipped iterator
iterator = zip(
@ -1852,8 +1974,10 @@ class FlashCausalLM(Model):
# It is faster if we delay this sync for the maximum amount of time
# For each member of the batch
index = 0
# Cumulative length
cu_accepted_ids = torch.nn.functional.pad(
torch.cumsum(accepted_ids, dim=0), (1, 0)
)
cumulative_length = 0
for i, (
request,
@ -1865,21 +1989,6 @@ class FlashCausalLM(Model):
request_was_prefilling,
request_is_prefilling,
) in enumerate(iterator):
if prefill and finished_prefilling:
# Indexing metadata
_start_index = cumulative_length
end_index = cumulative_length + input_length
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]
# Initialize adapter indices
# In decode, we only have one token per row in the batch, so grab last index
next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
end_index - 1
]
# Used to gather prefill logprobs
# Copy batch.all_input_ids_tensor to prefill_token_indices
if request.prefill_logprobs and request_was_prefilling:
@ -1898,25 +2007,39 @@ class FlashCausalLM(Model):
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = ids
if not request_is_prefilling:
# If the device does not support triton, we copy one by one
if not request_is_prefilling and not has_triton():
# Only save tokens if we are done prefilling for this request
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
next_input_ids[index + j]
)
index += n_accepted_ids
batch.all_input_ids_tensor[
i,
batch.cache_lengths_tensor[i]
+ batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+ batch.input_lengths[i]
+ accepted_ids[i],
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
cumulative_length += input_length
# If the device support triton, we can use a fused kernel
if has_triton():
copy_next_input_ids_inplace(
speculate + 1,
batch.all_input_ids_tensor,
batch.cache_lengths_tensor,
batch.input_lengths_tensor,
batch.prompt_lengths_tensor,
next_input_ids,
cu_accepted_ids,
)
# Update values
# These values can be updated without a GPU -> CPU sync
if not prefill or (prefill and finished_prefilling):
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + accepted_ids
batch.cache_lengths_tensor += batch.input_lengths_tensor
batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
batch.position_ids += accepted_ids
batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
batch.slot_indices += accepted_ids
batch.adapter_meta.adapter_indices = next_adapter_indices
if prefill and prefill_logprobs:
# Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size))
@ -2093,8 +2216,10 @@ class FlashCausalLM(Model):
# processing
stopped = False
new_input_length = next_chunk_lengths[i]
new_cache_length = cache_length + input_length
else:
new_input_length = n_accepted_ids
new_input_length = 1
new_cache_length = cache_length + input_length + n_accepted_ids - 1
# Append next token to all tokens
next_token_texts = []
left = 0
@ -2206,12 +2331,10 @@ class FlashCausalLM(Model):
# Update values
index += n_accepted_ids
current_cache_length = cache_length + input_length
batch.cache_lengths[i] = current_cache_length
current_input_length = new_input_length
batch.max_input_length = max(batch.max_input_length, current_input_length)
batch.input_lengths[i] = current_input_length
current_length = current_cache_length + current_input_length
batch.cache_lengths[i] = new_cache_length
batch.max_input_length = max(batch.max_input_length, new_input_length)
batch.input_lengths[i] = new_input_length
current_length = new_cache_length + new_input_length
batch.max_current_length = max(batch.max_current_length, current_length)
batch.prefix_offsets[i] = prefix_offset
@ -2258,11 +2381,6 @@ class FlashCausalLM(Model):
state=(
state if state is not None else self.prefill_with_paged_kv_state
),
# block_tables=block_tables_to_ragged(
# block_tables=block_tables,
# input_lengths=input_lengths,
# cache_lengths=cache_lengths,
# ),
block_tables=block_tables,
cu_seqlens=cu_seqlen_prefill,
input_lengths=input_lengths_tensor + cache_lengths_tensor,
@ -2287,23 +2405,3 @@ class FlashCausalLM(Model):
dtype=self.dtype,
window_left=self.sliding_window,
)
def block_tables_to_ragged(
*, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int]
) -> torch.Tensor:
"""Convert block table to ragged format compatible with FlashInfer."""
assert len(input_lengths) == len(cache_lengths)
total_len = sum(input_lengths) + sum(cache_lengths)
block_tables_ragged = torch.empty(
total_len, dtype=torch.int32, device=block_tables.device
)
offset = 0
for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)):
seq_len = cache_length + input_length
block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
offset += seq_len
return block_tables_ragged

View File

@ -0,0 +1,347 @@
import torch
import triton
import triton.language as tl
from loguru import logger
from typing import List, Optional
from torch.utils._triton import has_triton as has_triton_torch
from text_generation_server.utils.import_utils import (
SYSTEM,
)
from text_generation_server.utils.log import log_master
_HAS_TRITON: Optional[bool] = None
def has_triton():
global _HAS_TRITON
if _HAS_TRITON is None:
# FIXME: it seems that has_triton_torch is bugged on RocM
# For now, only accept cuda
_HAS_TRITON = has_triton_torch() if SYSTEM == "cuda" else False
if _HAS_TRITON:
log_master(logger.info, "Using optimized Triton indexing kernels.")
return _HAS_TRITON
def block_tables_to_padded(
max_blocks: int,
cu_seqlen: torch.Tensor,
block_tables: torch.Tensor,
block_tables_ragged: torch.Tensor,
):
def grid(meta):
return (
triton.cdiv(max_blocks, meta["BLOCK_SIZE"]),
len(block_tables),
)
triton_block_tables_to_padded[grid](
cu_seqlen,
block_tables,
block_tables_ragged,
block_tables.shape[1],
BLOCK_SIZE=256,
)
def block_tables_to_ragged(
*,
block_tables: torch.Tensor,
input_lengths: List[int],
cache_lengths: List[int],
input_lengths_tensor: torch.Tensor,
cache_lengths_tensor: torch.Tensor,
max_current_length: int
) -> torch.Tensor:
"""Convert block table to ragged format compatible with FlashInfer."""
assert len(input_lengths) == len(cache_lengths)
total_len = sum(input_lengths) + sum(cache_lengths)
block_tables_ragged = torch.empty(
total_len, dtype=torch.int32, device=block_tables.device
)
if has_triton():
cu_seqlen = torch.nn.functional.pad(
torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0)
)
def grid(meta):
return (
triton.cdiv(max_current_length, meta["BLOCK_SIZE"]),
len(cache_lengths),
)
triton_block_tables_to_ragged[grid](
cu_seqlen,
block_tables,
block_tables_ragged,
block_tables.shape[1],
BLOCK_SIZE=256,
)
else:
offset = 0
for i, (input_length, cache_length) in enumerate(
zip(input_lengths, cache_lengths)
):
seq_len = cache_length + input_length
block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
offset += seq_len
return block_tables_ragged
def copy_next_input_ids_inplace(
max_next_input_ids: int,
all_input_ids: torch.Tensor,
cache_lengths: torch.Tensor,
input_lengths: torch.Tensor,
prompt_lengths: torch.Tensor,
next_input_ids: torch.Tensor,
cu_accepted_ids: torch.Tensor,
):
def grid(meta):
return (
triton.cdiv(max_next_input_ids, meta["BLOCK_SIZE"]),
len(all_input_ids),
)
triton_copy_next_input_ids_inplace[grid](
all_input_ids,
cache_lengths,
input_lengths,
prompt_lengths,
next_input_ids,
cu_accepted_ids,
all_input_ids.shape[1],
BLOCK_SIZE=16,
)
def prepare_position_slot_ids(
max_input_length: int,
cache_lengths: torch.Tensor,
cu_seqlen: torch.Tensor,
cu_slots: torch.Tensor,
position_ids: torch.Tensor,
slot_indices: torch.Tensor,
):
def grid(meta):
return (
triton.cdiv(max_input_length, meta["BLOCK_SIZE"]),
len(cache_lengths),
)
triton_prepare_position_slot_ids[grid](
cache_lengths, cu_seqlen, cu_slots, position_ids, slot_indices, BLOCK_SIZE=256
)
def slots_filtering(
max_slots: int,
slots: torch.Tensor,
filtered_slots: torch.Tensor,
cu_slots: torch.Tensor,
slots_start: torch.Tensor,
):
def grid(meta):
return (
triton.cdiv(max_slots, meta["BLOCK_SIZE"]),
len(slots_start),
)
triton_slots_filtering[grid](
slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256
)
@triton.jit
def triton_slots_filtering(
# Inputs
slots_ptr,
filtered_slots_ptr,
slots_start_ptr,
cu_slots_ptr,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
# Position in block_tables_ragged.numel() / BLOCK_SIZE
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
filter_start = tl.load(slots_start_ptr + bid)
slot_start = tl.load(cu_slots_ptr + bid)
slot_end = tl.load(cu_slots_ptr + bid + 1)
mask = (slot_start + block_arange) < slot_end
slots = tl.load(slots_ptr + filter_start + block_arange, mask=mask)
tl.store(filtered_slots_ptr + slot_start + block_arange, slots, mask=mask)
@triton.jit
def triton_block_tables_to_padded(
# Inputs
cu_seqlen_ptr,
# Outputs
block_tables_ptr,
block_tables_ragged_ptr,
# Stride
stride_block_tables,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
# Position in block_tables_ragged.numel() / BLOCK_SIZE
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
seq_start = tl.load(cu_seqlen_ptr + bid)
seq_end = tl.load(cu_seqlen_ptr + bid + 1)
mask = (seq_start + block_arange) < seq_end
blocks = tl.load(block_tables_ragged_ptr + seq_start + block_arange, mask=mask)
tl.store(
block_tables_ptr + bid * stride_block_tables + block_arange, blocks, mask=mask
)
@triton.jit
def triton_block_tables_to_ragged(
# Inputs
cu_seqlen_ptr,
# Outputs
block_tables_ptr,
block_tables_ragged_ptr,
# Stride
stride_block_tables,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
# Position in block_tables_ragged.numel() / BLOCK_SIZE
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
seq_start = tl.load(cu_seqlen_ptr + bid)
seq_end = tl.load(cu_seqlen_ptr + bid + 1)
mask = (seq_start + block_arange) < seq_end
blocks = tl.load(
block_tables_ptr + bid * stride_block_tables + block_arange, mask=mask
)
tl.store(block_tables_ragged_ptr + seq_start + block_arange, blocks, mask=mask)
@triton.jit
def triton_copy_next_input_ids_inplace(
# Inputs
all_input_ids_ptr,
cache_lengths_ptr,
input_lengths_ptr,
prompt_lengths_ptr,
next_input_ids_ptr,
cu_accepted_ids_ptr,
# Stride
stride_all_input_ids,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
# Position in max_accepted_ids / BLOCK_SIZE
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
# Used for correctly indexing in all_input_ids
cache_length = tl.load(cache_lengths_ptr + bid)
input_length = tl.load(input_lengths_ptr + bid)
prompt_length = tl.load(prompt_lengths_ptr + bid)
# Start/End of next_input_ids for this request
next_input_ids_start = tl.load(cu_accepted_ids_ptr + bid)
next_input_ids_end = tl.load(cu_accepted_ids_ptr + bid + 1)
# Mask values out of range
mask = (next_input_ids_start + block_arange) < next_input_ids_end
# Mask values for request still prefilling
decode_mask = (cache_length + input_length + block_arange) >= prompt_length
mask = mask & decode_mask
# Load this request next input ids
next_input_ids = tl.load(
next_input_ids_ptr + next_input_ids_start + block_arange, mask=mask
)
# Store in all_input_ids, since it is a 2D tensor, apply stride * bid
tl.store(
all_input_ids_ptr
+ stride_all_input_ids * bid
+ cache_length
+ input_length
+ block_arange,
next_input_ids,
mask=mask,
)
@triton.jit
def triton_prepare_position_slot_ids(
# Inputs
cache_lengths_ptr,
cu_seqlen_ptr,
cu_slots_ptr,
# Outputs
position_ids_ptr,
slot_indices_ptr,
# Const values
BLOCK_SIZE: "tl.constexpr",
):
# Position in max_input_length / BLOCK_SIZE
pid = tl.program_id(axis=0)
# Position in batch
bid = tl.program_id(axis=1)
block_start = pid * BLOCK_SIZE
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
cache_length = tl.load(cache_lengths_ptr + bid)
seq_start = tl.load(cu_seqlen_ptr + bid)
seq_end = tl.load(cu_seqlen_ptr + bid + 1)
slot_start = tl.load(cu_slots_ptr + bid)
mask = (seq_start + block_arange) < seq_end
tl.store(
position_ids_ptr + seq_start + block_arange,
cache_length + block_arange,
mask=mask,
)
tl.store(
slot_indices_ptr + seq_start + block_arange,
slot_start + cache_length + block_arange,
mask=mask,
)

View File

@ -14,11 +14,9 @@ from transformers import (
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
block_tables_to_ragged,
)
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.metadata_kernels import block_tables_to_ragged
tracer = trace.get_tracer(__name__)
@ -283,6 +281,9 @@ class MllamaCausalLM(VlmCausalLM):
block_tables=block_tables,
input_lengths=batch.input_lengths,
cache_lengths=batch.cache_lengths,
input_lengths_tensor=batch.input_lengths_tensor,
cache_lengths_tensor=batch.cache_lengths_tensor,
max_current_length=batch.max_current_length,
)
with self._forward_context(
block_tables=block_tables,
@ -338,6 +339,9 @@ class MllamaCausalLM(VlmCausalLM):
block_tables=block_tables,
input_lengths=batch.input_lengths,
cache_lengths=batch.cache_lengths,
input_lengths_tensor=batch.input_lengths_tensor,
cache_lengths_tensor=batch.cache_lengths_tensor,
max_current_length=batch.max_current_length,
)
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
else:

View File

@ -11,12 +11,12 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
block_tables_to_ragged,
)
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.metadata_kernels import block_tables_to_ragged
tracer = trace.get_tracer(__name__)
@ -363,6 +363,9 @@ class VlmCausalLM(FlashCausalLM):
block_tables=block_tables,
input_lengths=batch.input_lengths,
cache_lengths=batch.cache_lengths,
input_lengths_tensor=batch.input_lengths_tensor,
cache_lengths_tensor=batch.cache_lengths_tensor,
max_current_length=batch.max_current_length,
)
with self._forward_context(
block_tables=block_tables,
@ -411,6 +414,9 @@ class VlmCausalLM(FlashCausalLM):
block_tables=block_tables,
input_lengths=batch.input_lengths,
cache_lengths=batch.cache_lengths,
input_lengths_tensor=batch.input_lengths_tensor,
cache_lengths_tensor=batch.cache_lengths_tensor,
max_current_length=batch.max_current_length,
)
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
else: