feat: add triton kernels to decrease latency of large batches (#2687)
* feat: add triton kernels to decrease latency of large batches * cast to int32 * fix kernel * fix kernel * disable triton on rocm * fix speculation * add slots filtering kernel
This commit is contained in:
parent
0f346a3296
commit
6f88bd9390
|
@ -71,6 +71,14 @@ from text_generation_server.utils.import_utils import (
|
|||
synchronize,
|
||||
get_free_memory,
|
||||
)
|
||||
from text_generation_server.models.metadata_kernels import (
|
||||
has_triton,
|
||||
copy_next_input_ids_inplace,
|
||||
block_tables_to_ragged,
|
||||
block_tables_to_padded,
|
||||
prepare_position_slot_ids,
|
||||
slots_filtering,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
@ -147,8 +155,10 @@ class FlashCausalLMBatch(Batch):
|
|||
# tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences
|
||||
block_tables_tensor: torch.Tensor
|
||||
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
|
||||
# Will be set by `generate_token` and reset after each prefill forward before staying set in decode
|
||||
slots: Optional[torch.Tensor]
|
||||
slots: torch.Tensor
|
||||
# list of length b + 1 containing the cumulative sequence slot lengths of the sequences in the batch
|
||||
# used for filtering
|
||||
cu_slots: torch.Tensor
|
||||
|
||||
max_input_length: int
|
||||
max_current_length: int
|
||||
|
@ -159,7 +169,7 @@ class FlashCausalLMBatch(Batch):
|
|||
prefilling_mask: List[bool]
|
||||
|
||||
# Prefill metadata tensors to efficiently compute logprobs
|
||||
# tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
|
||||
# tensor of length b + 1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
|
||||
cu_seqlen_prefill: Optional[torch.Tensor]
|
||||
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
|
||||
# as we only keep SLIDING_WINDOW values instead of the whole tensor
|
||||
|
@ -257,6 +267,8 @@ class FlashCausalLMBatch(Batch):
|
|||
all_input_ids = []
|
||||
all_postfix_ids = []
|
||||
requests_idx_mapping = {}
|
||||
slots = []
|
||||
cu_slots = [0]
|
||||
|
||||
next_token_chooser_parameters = []
|
||||
stopping_criterias = []
|
||||
|
@ -268,7 +280,9 @@ class FlashCausalLMBatch(Batch):
|
|||
max_length = 0
|
||||
max_blocks = 0
|
||||
|
||||
cu_blocks = [0]
|
||||
block_tables = []
|
||||
block_tables_ragged = []
|
||||
|
||||
# Parse batch
|
||||
for i, (r, tokenized_input) in enumerate(
|
||||
|
@ -341,10 +355,21 @@ class FlashCausalLMBatch(Batch):
|
|||
request_blocks = [
|
||||
b for b in range(num_blocks, num_blocks + needed_blocks)
|
||||
]
|
||||
request_slots = [
|
||||
s
|
||||
for b in request_blocks
|
||||
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
|
||||
]
|
||||
else:
|
||||
request_blocks = r.blocks
|
||||
request_slots = r.slots
|
||||
|
||||
block_tables.append(request_blocks)
|
||||
block_tables_ragged.extend(request_blocks)
|
||||
cu_blocks.append(len(block_tables_ragged))
|
||||
|
||||
slots.extend(request_slots)
|
||||
cu_slots.append(len(slots))
|
||||
|
||||
cache_lengths.append(cache_length)
|
||||
num_blocks += len(request_blocks)
|
||||
|
@ -378,16 +403,34 @@ class FlashCausalLMBatch(Batch):
|
|||
top_n_tokens, device=device, dtype=torch.int64
|
||||
)
|
||||
|
||||
block_tables_tensor = torch.zeros(
|
||||
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
|
||||
block_tables_ragged = torch.tensor(
|
||||
block_tables_ragged, device=device, dtype=torch.int32
|
||||
)
|
||||
for i, request_blocks in enumerate(block_tables):
|
||||
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
|
||||
block_tables_tensor = block_tables_tensor.to(device)
|
||||
cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64)
|
||||
block_tables_tensor = torch.empty(
|
||||
(len(block_tables), max_blocks),
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
# If the device supports Triton, we can use a fused kernel
|
||||
if has_triton():
|
||||
block_tables_to_padded(
|
||||
max_blocks, cu_blocks, block_tables_tensor, block_tables_ragged
|
||||
)
|
||||
else:
|
||||
for i, request_blocks in enumerate(block_tables):
|
||||
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(
|
||||
request_blocks
|
||||
)
|
||||
|
||||
prompt_lengths_tensor = torch.tensor(
|
||||
prompt_lengths, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
slots = torch.tensor(slots, dtype=torch.int64, device=device)
|
||||
cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
requests=pb.requests,
|
||||
|
@ -420,7 +463,8 @@ class FlashCausalLMBatch(Batch):
|
|||
cu_seqlen_prefill=None,
|
||||
prefill_cache_indices=None,
|
||||
slot_indices=None,
|
||||
slots=None,
|
||||
slots=slots,
|
||||
cu_slots=cu_slots,
|
||||
prefill_head_indices=None,
|
||||
prefill_next_token_indices=None,
|
||||
prefill_cu_outlens=None,
|
||||
|
@ -457,10 +501,11 @@ class FlashCausalLMBatch(Batch):
|
|||
# Used to index into tensors
|
||||
indices = []
|
||||
|
||||
# slots to keep after filtering
|
||||
slot_filtering_indices = torch.zeros(
|
||||
self.slots.shape[0], dtype=torch.bool, device=device
|
||||
)
|
||||
if not has_triton():
|
||||
# slots to keep after filtering
|
||||
slot_filtering_indices = torch.zeros(
|
||||
self.slots.shape[0], dtype=torch.bool, device=device
|
||||
)
|
||||
|
||||
# Create on CPU to only move to GPU once instead of at every copy
|
||||
slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
|
||||
|
@ -477,6 +522,7 @@ class FlashCausalLMBatch(Batch):
|
|||
cache_lengths = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
cu_slots = [0]
|
||||
|
||||
prefilling_mask = []
|
||||
prefill_logprob_tokens = []
|
||||
|
@ -487,8 +533,8 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
num_blocks = 0
|
||||
max_blocks = 0
|
||||
# Cumulative length
|
||||
cumulative_max_length = 0
|
||||
max_slots = 0
|
||||
cumulative_slot_tokens = 0
|
||||
|
||||
for i, request_id in enumerate(request_ids):
|
||||
idx = self.requests_idx_mapping[request_id]
|
||||
|
@ -531,29 +577,27 @@ class FlashCausalLMBatch(Batch):
|
|||
num_blocks += len(request_block_table)
|
||||
block_tables.append(request_block_table)
|
||||
|
||||
start_slot = self.cu_slots[idx]
|
||||
end_slot = self.cu_slots[idx + 1]
|
||||
slot_length = end_slot - start_slot
|
||||
|
||||
if not has_triton():
|
||||
# Set slice
|
||||
slot_filtering_indices[start_slot:end_slot] = True
|
||||
|
||||
cu_slots.append(cumulative_slot_tokens + slot_length)
|
||||
|
||||
# Input ids if the request was part of a prefilling batch
|
||||
# If the batch was decoding we can index into the tensor directly later
|
||||
if self.prefilling:
|
||||
input_ids.append(self.input_ids[idx])
|
||||
else:
|
||||
# Copy to tensor (CPU)
|
||||
slot_indices[i] = cumulative_max_length
|
||||
|
||||
remaining_tokens = (
|
||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||
)
|
||||
|
||||
# Set slice
|
||||
slot_filtering_indices[
|
||||
self.slot_indices[idx] : self.slot_indices[idx]
|
||||
+ request_input_length
|
||||
+ remaining_tokens
|
||||
- 1
|
||||
] = True
|
||||
|
||||
cumulative_max_length += request_input_length + remaining_tokens - 1
|
||||
slot_indices[i] = cumulative_slot_tokens + request_cache_length
|
||||
|
||||
cumulative_slot_tokens += slot_length
|
||||
max_blocks = max(max_blocks, len(request_block_table))
|
||||
max_slots = max(max_slots, slot_length)
|
||||
|
||||
all_input_ids_tensor = self.all_input_ids_tensor[indices]
|
||||
block_tables_tensor = self.block_tables_tensor[indices]
|
||||
|
@ -564,11 +608,22 @@ class FlashCausalLMBatch(Batch):
|
|||
)
|
||||
prompt_lengths_tensor = self.prompt_lengths_tensor[indices]
|
||||
|
||||
cu_slots = torch.tensor(cu_slots, dtype=torch.int64)
|
||||
|
||||
if not has_triton():
|
||||
slots = self.slots[slot_filtering_indices]
|
||||
else:
|
||||
slots = self.slots.new_empty(cumulative_slot_tokens)
|
||||
gpu_cu_slots = cu_slots.to(device)
|
||||
slots_indexing_start = self.cu_slots.to(device)[indices]
|
||||
slots_filtering(
|
||||
max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start
|
||||
)
|
||||
|
||||
if self.prefilling:
|
||||
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
|
||||
position_ids = None
|
||||
slot_indices = None
|
||||
slots = None
|
||||
cache_lengths_tensor = None
|
||||
input_lengths_tensor = None
|
||||
adapter_meta = None
|
||||
|
@ -578,7 +633,6 @@ class FlashCausalLMBatch(Batch):
|
|||
position_ids = self.position_ids[indices]
|
||||
adapter_indices = self.adapter_meta.adapter_indices[indices]
|
||||
input_lengths_tensor = self.input_lengths_tensor[indices]
|
||||
slots = self.slots[slot_filtering_indices]
|
||||
cache_lengths_tensor = self.cache_lengths_tensor[indices]
|
||||
|
||||
# Move to GPU now that we have the whole tensor
|
||||
|
@ -607,6 +661,7 @@ class FlashCausalLMBatch(Batch):
|
|||
block_tables=block_tables,
|
||||
block_tables_tensor=block_tables_tensor,
|
||||
slots=slots,
|
||||
cu_slots=cu_slots,
|
||||
max_input_length=max_input_length,
|
||||
max_current_length=max_current_length,
|
||||
prefilling=self.prefilling,
|
||||
|
@ -653,9 +708,7 @@ class FlashCausalLMBatch(Batch):
|
|||
for b in batches:
|
||||
total_batch_size += len(b)
|
||||
max_blocks = max(max_blocks, b.max_blocks)
|
||||
# If `b` is prefilling and was just filtered, `b.slots` is None
|
||||
# `total_slots` is not used if any of the batches is prefilling
|
||||
total_slots += len(b.slots) if not b.prefilling else 0
|
||||
total_slots += len(b.slots)
|
||||
num_blocks += b.num_blocks
|
||||
speculative_length = (
|
||||
b.speculative_ids.shape[1] if b.speculative_ids is not None else 0
|
||||
|
@ -675,11 +728,12 @@ class FlashCausalLMBatch(Batch):
|
|||
)
|
||||
prefilling = prefilling or b.prefilling
|
||||
|
||||
slots = batches[0].slots.new_empty(total_slots)
|
||||
cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64)
|
||||
if prefilling:
|
||||
input_ids = []
|
||||
# These values will be set by `FlashCausalLMBatch.prepare_for_prefill`
|
||||
position_ids = None
|
||||
slots = None
|
||||
slot_indices = None
|
||||
cache_lengths_tensor = None
|
||||
input_lengths_tensor = None
|
||||
|
@ -688,7 +742,6 @@ class FlashCausalLMBatch(Batch):
|
|||
else:
|
||||
input_ids = batches[0].input_ids.new_empty(total_batch_size)
|
||||
position_ids = batches[0].position_ids.new_empty(total_batch_size)
|
||||
slots = batches[0].slots.new_empty(total_slots)
|
||||
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
|
||||
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
|
||||
total_batch_size
|
||||
|
@ -764,13 +817,16 @@ class FlashCausalLMBatch(Batch):
|
|||
] = batch.block_tables_tensor[:, :max_blocks]
|
||||
prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor
|
||||
|
||||
if not prefilling:
|
||||
slots_start_index = cumulative_slots
|
||||
slots_end_index = cumulative_slots + len(batch.slots)
|
||||
slots_start_index = cumulative_slots
|
||||
slots_end_index = cumulative_slots + len(batch.slots)
|
||||
slots[slots_start_index:slots_end_index] = batch.slots
|
||||
cu_slots[start_index + 1 : end_index + 1] = (
|
||||
batch.cu_slots[1:] + cumulative_slots
|
||||
)
|
||||
|
||||
if not prefilling:
|
||||
input_ids[start_index:end_index] = batch.input_ids
|
||||
position_ids[start_index:end_index] = batch.position_ids
|
||||
slots[slots_start_index:slots_end_index] = batch.slots
|
||||
slot_indices[start_index:end_index] = (
|
||||
batch.slot_indices + cumulative_slots
|
||||
)
|
||||
|
@ -792,9 +848,6 @@ class FlashCausalLMBatch(Batch):
|
|||
batch.adapter_meta.adapter_segments,
|
||||
batch.adapter_meta.segment_indices,
|
||||
)
|
||||
|
||||
# Update
|
||||
cumulative_slots += len(batch.slots)
|
||||
else:
|
||||
if isinstance(batch.input_ids, torch.Tensor):
|
||||
batch.input_ids = batch.input_ids.view(-1, 1).tolist()
|
||||
|
@ -819,6 +872,7 @@ class FlashCausalLMBatch(Batch):
|
|||
top_n_tokens.extend(batch.top_n_tokens)
|
||||
|
||||
# Update
|
||||
cumulative_slots += len(batch.slots)
|
||||
cumulative_batch_size += len(batch)
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
|
@ -858,6 +912,7 @@ class FlashCausalLMBatch(Batch):
|
|||
cache_lengths=cache_lengths,
|
||||
cache_lengths_tensor=cache_lengths_tensor,
|
||||
slots=slots,
|
||||
cu_slots=cu_slots,
|
||||
max_input_length=max_input_length,
|
||||
max_current_length=max_current_length,
|
||||
prefilling=prefilling,
|
||||
|
@ -890,15 +945,50 @@ class FlashCausalLMBatch(Batch):
|
|||
# it simplifies everything
|
||||
assert self.speculative_ids is None
|
||||
|
||||
device = self.block_tables_tensor.device
|
||||
|
||||
if isinstance(self.input_ids, list):
|
||||
if len(self) > 1:
|
||||
input_ids = np.concatenate(self.input_ids, dtype=np.int64)
|
||||
else:
|
||||
input_ids = self.input_ids[0]
|
||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
|
||||
self.input_lengths_tensor = torch.tensor(
|
||||
self.input_lengths, dtype=torch.int32, device=device
|
||||
)
|
||||
self.cu_seqlen_prefill = torch.nn.functional.pad(
|
||||
torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)
|
||||
).to(torch.int32)
|
||||
self.cache_lengths_tensor = torch.tensor(
|
||||
self.cache_lengths, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# If the device supports Triton, we can use a fused kernel
|
||||
if has_triton():
|
||||
self.position_ids = torch.empty(
|
||||
len(self.input_ids), dtype=torch.int32, device=device
|
||||
)
|
||||
self.slot_indices = torch.empty(
|
||||
len(self.input_ids), dtype=torch.int64, device=device
|
||||
)
|
||||
cu_slots_gpu = self.cu_slots.to(device)
|
||||
|
||||
prepare_position_slot_ids(
|
||||
self.max_input_length,
|
||||
self.cache_lengths_tensor,
|
||||
self.cu_seqlen_prefill,
|
||||
cu_slots_gpu,
|
||||
self.position_ids,
|
||||
self.slot_indices,
|
||||
)
|
||||
|
||||
sliding_window = get_sliding_windows()
|
||||
position_ids = []
|
||||
cu_seqlen_prefill = [0]
|
||||
slot_indices = []
|
||||
prefill_cache_indices = []
|
||||
all_prefill_logprobs = True
|
||||
no_prefill_logprobs = True
|
||||
prefill_head_indices = []
|
||||
prefill_next_token_indices = []
|
||||
prefill_cu_outlens = [0]
|
||||
|
||||
# Cumulative length
|
||||
|
@ -906,7 +996,6 @@ class FlashCausalLMBatch(Batch):
|
|||
cumulative_slot_tokens = 0
|
||||
prefill_out_cumulative_length = 0
|
||||
|
||||
slots = []
|
||||
adapter_indices_list = []
|
||||
adapter_set = set()
|
||||
|
||||
|
@ -928,30 +1017,33 @@ class FlashCausalLMBatch(Batch):
|
|||
)
|
||||
):
|
||||
next_chunk_length = input_length
|
||||
# Position ids
|
||||
request_position_ids = torch.arange(
|
||||
cache_length, cache_length + input_length, dtype=torch.int32
|
||||
)
|
||||
position_ids.append(request_position_ids)
|
||||
|
||||
# Add cumulative lengths of all previous inputs
|
||||
cu_seqlen_prefill.append(cumulative_length + input_length)
|
||||
if not has_triton():
|
||||
# Position ids
|
||||
request_position_ids = torch.arange(
|
||||
cache_length, cache_length + input_length, dtype=torch.int32
|
||||
)
|
||||
position_ids.append(request_position_ids)
|
||||
|
||||
if not r.slots:
|
||||
request_slots = [
|
||||
s
|
||||
for b in blocks
|
||||
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
|
||||
]
|
||||
else:
|
||||
request_slots = r.slots
|
||||
if not r.slots:
|
||||
request_slots = [
|
||||
s
|
||||
for b in blocks
|
||||
for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)
|
||||
]
|
||||
else:
|
||||
request_slots = r.slots
|
||||
|
||||
request_slots = request_slots[cache_length:]
|
||||
request_slot_indices = torch.arange(
|
||||
cumulative_slot_tokens,
|
||||
cumulative_slot_tokens + input_length,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
request_slot_indices = torch.arange(
|
||||
cache_length + cumulative_slot_tokens,
|
||||
cache_length + cumulative_slot_tokens + input_length,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
|
||||
slot_indices.append(request_slot_indices)
|
||||
|
||||
# Update
|
||||
cumulative_slot_tokens += len(request_slots)
|
||||
|
||||
# Create tensor to slice into the kv tensor in prefill
|
||||
if sliding_window is not None:
|
||||
|
@ -968,83 +1060,102 @@ class FlashCausalLMBatch(Batch):
|
|||
no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs
|
||||
|
||||
if prefill_logprobs:
|
||||
prefill_head_indices.append(
|
||||
torch.arange(
|
||||
cumulative_length,
|
||||
cumulative_length + input_length,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
)
|
||||
prefill_next_token_indices.append(
|
||||
prefill_out_cumulative_length + input_length - 1
|
||||
)
|
||||
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
|
||||
prefill_out_cumulative_length += input_length
|
||||
else:
|
||||
prefill_head_indices.append(
|
||||
torch.tensor(
|
||||
[cumulative_length + input_length - 1],
|
||||
dtype=torch.int64,
|
||||
)
|
||||
)
|
||||
prefill_next_token_indices.append(prefill_out_cumulative_length)
|
||||
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
|
||||
prefill_out_cumulative_length += 1
|
||||
|
||||
slots.extend(request_slots)
|
||||
slot_indices.append(request_slot_indices)
|
||||
|
||||
if sliding_window is not None:
|
||||
prefill_cache_indices.append(request_prefill_cache_indices)
|
||||
|
||||
ADAPTER_TO_INDEX = get_adapter_to_index()
|
||||
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
|
||||
adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index))
|
||||
adapter_set.add(adapter_index)
|
||||
if ADAPTER_TO_INDEX:
|
||||
adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
|
||||
adapter_indices_list.append(
|
||||
torch.full((next_chunk_length,), adapter_index)
|
||||
)
|
||||
adapter_set.add(adapter_index)
|
||||
|
||||
# Update
|
||||
cumulative_length += next_chunk_length
|
||||
cumulative_slot_tokens += len(request_slots)
|
||||
|
||||
device = self.block_tables_tensor.device
|
||||
if not all_prefill_logprobs and not no_prefill_logprobs:
|
||||
prefill_head_indices = []
|
||||
prefill_next_token_indices = []
|
||||
|
||||
if isinstance(self.input_ids, list):
|
||||
if len(self) > 1:
|
||||
input_ids = np.concatenate(self.input_ids, dtype=np.int64)
|
||||
else:
|
||||
input_ids = self.input_ids[0]
|
||||
self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
prefill_out_cumulative_length = 0
|
||||
|
||||
for i, (
|
||||
r,
|
||||
input_length,
|
||||
request_prefilling,
|
||||
) in enumerate(
|
||||
zip(
|
||||
self.requests,
|
||||
self.input_lengths,
|
||||
self.prefilling_mask,
|
||||
)
|
||||
):
|
||||
# Prefill logprobs is ignored if the request is done prefilling
|
||||
prefill_logprobs = r.prefill_logprobs and request_prefilling
|
||||
|
||||
if prefill_logprobs:
|
||||
prefill_head_indices.append(
|
||||
torch.arange(
|
||||
cumulative_length,
|
||||
cumulative_length + input_length,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
)
|
||||
prefill_next_token_indices.append(
|
||||
prefill_out_cumulative_length + input_length - 1
|
||||
)
|
||||
prefill_out_cumulative_length += input_length
|
||||
else:
|
||||
prefill_head_indices.append(
|
||||
torch.tensor(
|
||||
[cumulative_length + input_length - 1],
|
||||
dtype=torch.int64,
|
||||
)
|
||||
)
|
||||
prefill_next_token_indices.append(prefill_out_cumulative_length)
|
||||
prefill_out_cumulative_length += 1
|
||||
|
||||
# Update
|
||||
cumulative_length += input_length
|
||||
|
||||
if len(self) > 1:
|
||||
position_ids = torch.cat(position_ids)
|
||||
slot_indices = torch.cat(slot_indices)
|
||||
if position_ids:
|
||||
position_ids = torch.cat(position_ids)
|
||||
if slot_indices:
|
||||
slot_indices = torch.cat(slot_indices)
|
||||
if sliding_window is not None:
|
||||
prefill_cache_indices = torch.cat(prefill_cache_indices)
|
||||
else:
|
||||
position_ids = position_ids[0]
|
||||
slot_indices = slot_indices[0]
|
||||
if position_ids:
|
||||
position_ids = position_ids[0]
|
||||
if slot_indices:
|
||||
slot_indices = slot_indices[0]
|
||||
if sliding_window is not None:
|
||||
prefill_cache_indices = prefill_cache_indices[0]
|
||||
|
||||
if not has_triton():
|
||||
self.position_ids = position_ids.to(device)
|
||||
self.slot_indices = slot_indices.to(device)
|
||||
|
||||
self.prefill_cu_outlens = prefill_cu_outlens
|
||||
cu_seqlen_prefill = torch.tensor(
|
||||
cu_seqlen_prefill, device=device, dtype=torch.int32
|
||||
)
|
||||
self.cu_seqlen_prefill = cu_seqlen_prefill
|
||||
self.position_ids = position_ids.to(device)
|
||||
self.slot_indices = slot_indices.to(device)
|
||||
self.prefill_cache_indices = (
|
||||
prefill_cache_indices.to(device) if sliding_window is not None else None
|
||||
)
|
||||
self.input_lengths_tensor = torch.tensor(
|
||||
self.input_lengths, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
if all_prefill_logprobs:
|
||||
prefill_head_indices = None
|
||||
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
|
||||
prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1
|
||||
elif no_prefill_logprobs:
|
||||
prefill_head_indices = cu_seqlen_prefill[1:] - 1
|
||||
prefill_head_indices = self.cu_seqlen_prefill[1:] - 1
|
||||
prefill_next_token_indices = None
|
||||
else:
|
||||
prefill_head_indices = torch.cat(prefill_head_indices).to(device)
|
||||
|
@ -1054,17 +1165,21 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
self.prefill_head_indices = prefill_head_indices
|
||||
self.prefill_next_token_indices = prefill_next_token_indices
|
||||
self.slots = torch.tensor(slots, dtype=torch.int64, device=device)
|
||||
self.cache_lengths_tensor = torch.tensor(
|
||||
self.cache_lengths, dtype=torch.int32, device=device
|
||||
)
|
||||
adapter_indices = torch.cat(adapter_indices_list).to(
|
||||
dtype=torch.int64, device=device
|
||||
)
|
||||
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
|
||||
|
||||
if adapter_set:
|
||||
adapter_indices = torch.cat(adapter_indices_list).to(
|
||||
dtype=torch.int64, device=device
|
||||
)
|
||||
adapter_segments, adapter_segment_indices = find_segments(adapter_indices)
|
||||
else:
|
||||
adapter_indices = torch.zeros_like(self.input_ids)
|
||||
adapter_segments = [0, len(adapter_indices)]
|
||||
adapter_segment_indices = [len(adapter_indices) - 1]
|
||||
|
||||
adapter_segments = torch.tensor(
|
||||
adapter_segments, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
self.adapter_meta = AdapterBatchMetadata(
|
||||
adapter_indices=adapter_indices,
|
||||
adapter_set=adapter_set,
|
||||
|
@ -1288,6 +1403,9 @@ class FlashCausalLM(Model):
|
|||
block_tables=block_tables,
|
||||
input_lengths=input_lengths,
|
||||
cache_lengths=cache_lengths,
|
||||
input_lengths_tensor=input_lengths_tensor,
|
||||
cache_lengths_tensor=cache_lengths_tensor,
|
||||
max_current_length=max_s,
|
||||
)
|
||||
from text_generation_server.layers.attention.flashinfer import (
|
||||
create_decode_state_cuda_graphs,
|
||||
|
@ -1621,6 +1739,9 @@ class FlashCausalLM(Model):
|
|||
block_tables=block_tables,
|
||||
input_lengths=batch.input_lengths,
|
||||
cache_lengths=batch.cache_lengths,
|
||||
input_lengths_tensor=batch.input_lengths_tensor,
|
||||
cache_lengths_tensor=batch.cache_lengths_tensor,
|
||||
max_current_length=batch.max_current_length,
|
||||
)
|
||||
with self._forward_context(
|
||||
block_tables=block_tables,
|
||||
|
@ -1661,6 +1782,9 @@ class FlashCausalLM(Model):
|
|||
block_tables=block_tables,
|
||||
input_lengths=batch.input_lengths,
|
||||
cache_lengths=batch.cache_lengths,
|
||||
input_lengths_tensor=batch.input_lengths_tensor,
|
||||
cache_lengths_tensor=batch.cache_lengths_tensor,
|
||||
max_current_length=batch.max_current_length,
|
||||
)
|
||||
# assert block_tables.shape[0] >= slots.shape[0]
|
||||
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
||||
|
@ -1756,7 +1880,6 @@ class FlashCausalLM(Model):
|
|||
else:
|
||||
prefill_logprobs = None
|
||||
next_token_logits = out
|
||||
next_adapter_indices = batch.adapter_meta.adapter_indices
|
||||
|
||||
finished_prefilling = True
|
||||
next_chunk_lengths = []
|
||||
|
@ -1827,13 +1950,12 @@ class FlashCausalLM(Model):
|
|||
# Since we are done prefilling, all the tensors that were concatenating values for all the requests
|
||||
# instantly become of shape [BATCH_SIZE]
|
||||
if prefill and finished_prefilling:
|
||||
next_position_ids = batch.position_ids.new_empty(len(batch))
|
||||
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
|
||||
next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty(
|
||||
len(batch)
|
||||
)
|
||||
elif not prefill:
|
||||
next_position_ids = batch.position_ids
|
||||
indices = batch.cu_seqlen_prefill[1:] - 1
|
||||
batch.position_ids = batch.position_ids[indices]
|
||||
batch.slot_indices = batch.slot_indices[indices]
|
||||
batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
|
||||
indices
|
||||
]
|
||||
|
||||
# Zipped iterator
|
||||
iterator = zip(
|
||||
|
@ -1852,8 +1974,10 @@ class FlashCausalLM(Model):
|
|||
# It is faster if we delay this sync for the maximum amount of time
|
||||
|
||||
# For each member of the batch
|
||||
index = 0
|
||||
# Cumulative length
|
||||
cu_accepted_ids = torch.nn.functional.pad(
|
||||
torch.cumsum(accepted_ids, dim=0), (1, 0)
|
||||
)
|
||||
cumulative_length = 0
|
||||
for i, (
|
||||
request,
|
||||
|
@ -1865,21 +1989,6 @@ class FlashCausalLM(Model):
|
|||
request_was_prefilling,
|
||||
request_is_prefilling,
|
||||
) in enumerate(iterator):
|
||||
if prefill and finished_prefilling:
|
||||
# Indexing metadata
|
||||
_start_index = cumulative_length
|
||||
end_index = cumulative_length + input_length
|
||||
|
||||
# Initialize position_ids
|
||||
# In decode, we do not need this as we can just increment position ids
|
||||
next_position_ids[i] = batch.position_ids[end_index - 1]
|
||||
|
||||
# Initialize adapter indices
|
||||
# In decode, we only have one token per row in the batch, so grab last index
|
||||
next_adapter_indices[i] = batch.adapter_meta.adapter_indices[
|
||||
end_index - 1
|
||||
]
|
||||
|
||||
# Used to gather prefill logprobs
|
||||
# Copy batch.all_input_ids_tensor to prefill_token_indices
|
||||
if request.prefill_logprobs and request_was_prefilling:
|
||||
|
@ -1898,25 +2007,39 @@ class FlashCausalLM(Model):
|
|||
# Set prefill_tokens_indices to the correct slice
|
||||
prefill_tokens_indices = ids
|
||||
|
||||
if not request_is_prefilling:
|
||||
# If the device does not support triton, we copy one by one
|
||||
if not request_is_prefilling and not has_triton():
|
||||
# Only save tokens if we are done prefilling for this request
|
||||
for j in range(n_accepted_ids):
|
||||
batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
|
||||
next_input_ids[index + j]
|
||||
)
|
||||
index += n_accepted_ids
|
||||
batch.all_input_ids_tensor[
|
||||
i,
|
||||
batch.cache_lengths_tensor[i]
|
||||
+ batch.input_lengths[i] : batch.cache_lengths_tensor[i]
|
||||
+ batch.input_lengths[i]
|
||||
+ accepted_ids[i],
|
||||
] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
|
||||
cumulative_length += input_length
|
||||
|
||||
# If the device support triton, we can use a fused kernel
|
||||
if has_triton():
|
||||
copy_next_input_ids_inplace(
|
||||
speculate + 1,
|
||||
batch.all_input_ids_tensor,
|
||||
batch.cache_lengths_tensor,
|
||||
batch.input_lengths_tensor,
|
||||
batch.prompt_lengths_tensor,
|
||||
next_input_ids,
|
||||
cu_accepted_ids,
|
||||
)
|
||||
|
||||
# Update values
|
||||
# These values can be updated without a GPU -> CPU sync
|
||||
if not prefill or (prefill and finished_prefilling):
|
||||
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
|
||||
batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
|
||||
batch.speculative_ids = speculative_ids
|
||||
batch.position_ids = next_position_ids + accepted_ids
|
||||
batch.cache_lengths_tensor += batch.input_lengths_tensor
|
||||
batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32)
|
||||
batch.position_ids += accepted_ids
|
||||
batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
|
||||
batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
|
||||
batch.slot_indices += accepted_ids
|
||||
batch.adapter_meta.adapter_indices = next_adapter_indices
|
||||
|
||||
if prefill and prefill_logprobs:
|
||||
# Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size))
|
||||
|
@ -2093,8 +2216,10 @@ class FlashCausalLM(Model):
|
|||
# processing
|
||||
stopped = False
|
||||
new_input_length = next_chunk_lengths[i]
|
||||
new_cache_length = cache_length + input_length
|
||||
else:
|
||||
new_input_length = n_accepted_ids
|
||||
new_input_length = 1
|
||||
new_cache_length = cache_length + input_length + n_accepted_ids - 1
|
||||
# Append next token to all tokens
|
||||
next_token_texts = []
|
||||
left = 0
|
||||
|
@ -2206,12 +2331,10 @@ class FlashCausalLM(Model):
|
|||
|
||||
# Update values
|
||||
index += n_accepted_ids
|
||||
current_cache_length = cache_length + input_length
|
||||
batch.cache_lengths[i] = current_cache_length
|
||||
current_input_length = new_input_length
|
||||
batch.max_input_length = max(batch.max_input_length, current_input_length)
|
||||
batch.input_lengths[i] = current_input_length
|
||||
current_length = current_cache_length + current_input_length
|
||||
batch.cache_lengths[i] = new_cache_length
|
||||
batch.max_input_length = max(batch.max_input_length, new_input_length)
|
||||
batch.input_lengths[i] = new_input_length
|
||||
current_length = new_cache_length + new_input_length
|
||||
batch.max_current_length = max(batch.max_current_length, current_length)
|
||||
|
||||
batch.prefix_offsets[i] = prefix_offset
|
||||
|
@ -2258,11 +2381,6 @@ class FlashCausalLM(Model):
|
|||
state=(
|
||||
state if state is not None else self.prefill_with_paged_kv_state
|
||||
),
|
||||
# block_tables=block_tables_to_ragged(
|
||||
# block_tables=block_tables,
|
||||
# input_lengths=input_lengths,
|
||||
# cache_lengths=cache_lengths,
|
||||
# ),
|
||||
block_tables=block_tables,
|
||||
cu_seqlens=cu_seqlen_prefill,
|
||||
input_lengths=input_lengths_tensor + cache_lengths_tensor,
|
||||
|
@ -2287,23 +2405,3 @@ class FlashCausalLM(Model):
|
|||
dtype=self.dtype,
|
||||
window_left=self.sliding_window,
|
||||
)
|
||||
|
||||
|
||||
def block_tables_to_ragged(
|
||||
*, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int]
|
||||
) -> torch.Tensor:
|
||||
"""Convert block table to ragged format compatible with FlashInfer."""
|
||||
assert len(input_lengths) == len(cache_lengths)
|
||||
|
||||
total_len = sum(input_lengths) + sum(cache_lengths)
|
||||
block_tables_ragged = torch.empty(
|
||||
total_len, dtype=torch.int32, device=block_tables.device
|
||||
)
|
||||
|
||||
offset = 0
|
||||
for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)):
|
||||
seq_len = cache_length + input_length
|
||||
block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
|
||||
offset += seq_len
|
||||
|
||||
return block_tables_ragged
|
||||
|
|
|
@ -0,0 +1,347 @@
|
|||
import torch
|
||||
import triton
|
||||
|
||||
import triton.language as tl
|
||||
|
||||
from loguru import logger
|
||||
from typing import List, Optional
|
||||
from torch.utils._triton import has_triton as has_triton_torch
|
||||
|
||||
from text_generation_server.utils.import_utils import (
|
||||
SYSTEM,
|
||||
)
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
_HAS_TRITON: Optional[bool] = None
|
||||
|
||||
|
||||
def has_triton():
|
||||
global _HAS_TRITON
|
||||
if _HAS_TRITON is None:
|
||||
# FIXME: it seems that has_triton_torch is bugged on RocM
|
||||
# For now, only accept cuda
|
||||
_HAS_TRITON = has_triton_torch() if SYSTEM == "cuda" else False
|
||||
if _HAS_TRITON:
|
||||
log_master(logger.info, "Using optimized Triton indexing kernels.")
|
||||
|
||||
return _HAS_TRITON
|
||||
|
||||
|
||||
def block_tables_to_padded(
|
||||
max_blocks: int,
|
||||
cu_seqlen: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
block_tables_ragged: torch.Tensor,
|
||||
):
|
||||
def grid(meta):
|
||||
return (
|
||||
triton.cdiv(max_blocks, meta["BLOCK_SIZE"]),
|
||||
len(block_tables),
|
||||
)
|
||||
|
||||
triton_block_tables_to_padded[grid](
|
||||
cu_seqlen,
|
||||
block_tables,
|
||||
block_tables_ragged,
|
||||
block_tables.shape[1],
|
||||
BLOCK_SIZE=256,
|
||||
)
|
||||
|
||||
|
||||
def block_tables_to_ragged(
|
||||
*,
|
||||
block_tables: torch.Tensor,
|
||||
input_lengths: List[int],
|
||||
cache_lengths: List[int],
|
||||
input_lengths_tensor: torch.Tensor,
|
||||
cache_lengths_tensor: torch.Tensor,
|
||||
max_current_length: int
|
||||
) -> torch.Tensor:
|
||||
"""Convert block table to ragged format compatible with FlashInfer."""
|
||||
assert len(input_lengths) == len(cache_lengths)
|
||||
|
||||
total_len = sum(input_lengths) + sum(cache_lengths)
|
||||
block_tables_ragged = torch.empty(
|
||||
total_len, dtype=torch.int32, device=block_tables.device
|
||||
)
|
||||
|
||||
if has_triton():
|
||||
cu_seqlen = torch.nn.functional.pad(
|
||||
torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0)
|
||||
)
|
||||
|
||||
def grid(meta):
|
||||
return (
|
||||
triton.cdiv(max_current_length, meta["BLOCK_SIZE"]),
|
||||
len(cache_lengths),
|
||||
)
|
||||
|
||||
triton_block_tables_to_ragged[grid](
|
||||
cu_seqlen,
|
||||
block_tables,
|
||||
block_tables_ragged,
|
||||
block_tables.shape[1],
|
||||
BLOCK_SIZE=256,
|
||||
)
|
||||
else:
|
||||
offset = 0
|
||||
for i, (input_length, cache_length) in enumerate(
|
||||
zip(input_lengths, cache_lengths)
|
||||
):
|
||||
seq_len = cache_length + input_length
|
||||
block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len]
|
||||
offset += seq_len
|
||||
|
||||
return block_tables_ragged
|
||||
|
||||
|
||||
def copy_next_input_ids_inplace(
|
||||
max_next_input_ids: int,
|
||||
all_input_ids: torch.Tensor,
|
||||
cache_lengths: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
prompt_lengths: torch.Tensor,
|
||||
next_input_ids: torch.Tensor,
|
||||
cu_accepted_ids: torch.Tensor,
|
||||
):
|
||||
def grid(meta):
|
||||
return (
|
||||
triton.cdiv(max_next_input_ids, meta["BLOCK_SIZE"]),
|
||||
len(all_input_ids),
|
||||
)
|
||||
|
||||
triton_copy_next_input_ids_inplace[grid](
|
||||
all_input_ids,
|
||||
cache_lengths,
|
||||
input_lengths,
|
||||
prompt_lengths,
|
||||
next_input_ids,
|
||||
cu_accepted_ids,
|
||||
all_input_ids.shape[1],
|
||||
BLOCK_SIZE=16,
|
||||
)
|
||||
|
||||
|
||||
def prepare_position_slot_ids(
|
||||
max_input_length: int,
|
||||
cache_lengths: torch.Tensor,
|
||||
cu_seqlen: torch.Tensor,
|
||||
cu_slots: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
slot_indices: torch.Tensor,
|
||||
):
|
||||
def grid(meta):
|
||||
return (
|
||||
triton.cdiv(max_input_length, meta["BLOCK_SIZE"]),
|
||||
len(cache_lengths),
|
||||
)
|
||||
|
||||
triton_prepare_position_slot_ids[grid](
|
||||
cache_lengths, cu_seqlen, cu_slots, position_ids, slot_indices, BLOCK_SIZE=256
|
||||
)
|
||||
|
||||
|
||||
def slots_filtering(
|
||||
max_slots: int,
|
||||
slots: torch.Tensor,
|
||||
filtered_slots: torch.Tensor,
|
||||
cu_slots: torch.Tensor,
|
||||
slots_start: torch.Tensor,
|
||||
):
|
||||
def grid(meta):
|
||||
return (
|
||||
triton.cdiv(max_slots, meta["BLOCK_SIZE"]),
|
||||
len(slots_start),
|
||||
)
|
||||
|
||||
triton_slots_filtering[grid](
|
||||
slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def triton_slots_filtering(
|
||||
# Inputs
|
||||
slots_ptr,
|
||||
filtered_slots_ptr,
|
||||
slots_start_ptr,
|
||||
cu_slots_ptr,
|
||||
# Const values
|
||||
BLOCK_SIZE: "tl.constexpr",
|
||||
):
|
||||
# Position in block_tables_ragged.numel() / BLOCK_SIZE
|
||||
pid = tl.program_id(axis=0)
|
||||
# Position in batch
|
||||
bid = tl.program_id(axis=1)
|
||||
|
||||
block_start = pid * BLOCK_SIZE
|
||||
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
filter_start = tl.load(slots_start_ptr + bid)
|
||||
|
||||
slot_start = tl.load(cu_slots_ptr + bid)
|
||||
slot_end = tl.load(cu_slots_ptr + bid + 1)
|
||||
|
||||
mask = (slot_start + block_arange) < slot_end
|
||||
|
||||
slots = tl.load(slots_ptr + filter_start + block_arange, mask=mask)
|
||||
tl.store(filtered_slots_ptr + slot_start + block_arange, slots, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def triton_block_tables_to_padded(
|
||||
# Inputs
|
||||
cu_seqlen_ptr,
|
||||
# Outputs
|
||||
block_tables_ptr,
|
||||
block_tables_ragged_ptr,
|
||||
# Stride
|
||||
stride_block_tables,
|
||||
# Const values
|
||||
BLOCK_SIZE: "tl.constexpr",
|
||||
):
|
||||
# Position in block_tables_ragged.numel() / BLOCK_SIZE
|
||||
pid = tl.program_id(axis=0)
|
||||
# Position in batch
|
||||
bid = tl.program_id(axis=1)
|
||||
|
||||
block_start = pid * BLOCK_SIZE
|
||||
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
seq_start = tl.load(cu_seqlen_ptr + bid)
|
||||
seq_end = tl.load(cu_seqlen_ptr + bid + 1)
|
||||
|
||||
mask = (seq_start + block_arange) < seq_end
|
||||
|
||||
blocks = tl.load(block_tables_ragged_ptr + seq_start + block_arange, mask=mask)
|
||||
tl.store(
|
||||
block_tables_ptr + bid * stride_block_tables + block_arange, blocks, mask=mask
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def triton_block_tables_to_ragged(
|
||||
# Inputs
|
||||
cu_seqlen_ptr,
|
||||
# Outputs
|
||||
block_tables_ptr,
|
||||
block_tables_ragged_ptr,
|
||||
# Stride
|
||||
stride_block_tables,
|
||||
# Const values
|
||||
BLOCK_SIZE: "tl.constexpr",
|
||||
):
|
||||
# Position in block_tables_ragged.numel() / BLOCK_SIZE
|
||||
pid = tl.program_id(axis=0)
|
||||
# Position in batch
|
||||
bid = tl.program_id(axis=1)
|
||||
|
||||
block_start = pid * BLOCK_SIZE
|
||||
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
seq_start = tl.load(cu_seqlen_ptr + bid)
|
||||
seq_end = tl.load(cu_seqlen_ptr + bid + 1)
|
||||
|
||||
mask = (seq_start + block_arange) < seq_end
|
||||
|
||||
blocks = tl.load(
|
||||
block_tables_ptr + bid * stride_block_tables + block_arange, mask=mask
|
||||
)
|
||||
tl.store(block_tables_ragged_ptr + seq_start + block_arange, blocks, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def triton_copy_next_input_ids_inplace(
|
||||
# Inputs
|
||||
all_input_ids_ptr,
|
||||
cache_lengths_ptr,
|
||||
input_lengths_ptr,
|
||||
prompt_lengths_ptr,
|
||||
next_input_ids_ptr,
|
||||
cu_accepted_ids_ptr,
|
||||
# Stride
|
||||
stride_all_input_ids,
|
||||
# Const values
|
||||
BLOCK_SIZE: "tl.constexpr",
|
||||
):
|
||||
# Position in max_accepted_ids / BLOCK_SIZE
|
||||
pid = tl.program_id(axis=0)
|
||||
# Position in batch
|
||||
bid = tl.program_id(axis=1)
|
||||
|
||||
block_start = pid * BLOCK_SIZE
|
||||
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
# Used for correctly indexing in all_input_ids
|
||||
cache_length = tl.load(cache_lengths_ptr + bid)
|
||||
input_length = tl.load(input_lengths_ptr + bid)
|
||||
prompt_length = tl.load(prompt_lengths_ptr + bid)
|
||||
|
||||
# Start/End of next_input_ids for this request
|
||||
next_input_ids_start = tl.load(cu_accepted_ids_ptr + bid)
|
||||
next_input_ids_end = tl.load(cu_accepted_ids_ptr + bid + 1)
|
||||
|
||||
# Mask values out of range
|
||||
mask = (next_input_ids_start + block_arange) < next_input_ids_end
|
||||
|
||||
# Mask values for request still prefilling
|
||||
decode_mask = (cache_length + input_length + block_arange) >= prompt_length
|
||||
|
||||
mask = mask & decode_mask
|
||||
|
||||
# Load this request next input ids
|
||||
next_input_ids = tl.load(
|
||||
next_input_ids_ptr + next_input_ids_start + block_arange, mask=mask
|
||||
)
|
||||
|
||||
# Store in all_input_ids, since it is a 2D tensor, apply stride * bid
|
||||
tl.store(
|
||||
all_input_ids_ptr
|
||||
+ stride_all_input_ids * bid
|
||||
+ cache_length
|
||||
+ input_length
|
||||
+ block_arange,
|
||||
next_input_ids,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def triton_prepare_position_slot_ids(
|
||||
# Inputs
|
||||
cache_lengths_ptr,
|
||||
cu_seqlen_ptr,
|
||||
cu_slots_ptr,
|
||||
# Outputs
|
||||
position_ids_ptr,
|
||||
slot_indices_ptr,
|
||||
# Const values
|
||||
BLOCK_SIZE: "tl.constexpr",
|
||||
):
|
||||
# Position in max_input_length / BLOCK_SIZE
|
||||
pid = tl.program_id(axis=0)
|
||||
# Position in batch
|
||||
bid = tl.program_id(axis=1)
|
||||
|
||||
block_start = pid * BLOCK_SIZE
|
||||
block_arange = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
cache_length = tl.load(cache_lengths_ptr + bid)
|
||||
|
||||
seq_start = tl.load(cu_seqlen_ptr + bid)
|
||||
seq_end = tl.load(cu_seqlen_ptr + bid + 1)
|
||||
|
||||
slot_start = tl.load(cu_slots_ptr + bid)
|
||||
|
||||
mask = (seq_start + block_arange) < seq_end
|
||||
|
||||
tl.store(
|
||||
position_ids_ptr + seq_start + block_arange,
|
||||
cache_length + block_arange,
|
||||
mask=mask,
|
||||
)
|
||||
tl.store(
|
||||
slot_indices_ptr + seq_start + block_arange,
|
||||
slot_start + cache_length + block_arange,
|
||||
mask=mask,
|
||||
)
|
|
@ -14,11 +14,9 @@ from transformers import (
|
|||
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.flash_causal_lm import (
|
||||
block_tables_to_ragged,
|
||||
)
|
||||
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from text_generation_server.models.metadata_kernels import block_tables_to_ragged
|
||||
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
@ -283,6 +281,9 @@ class MllamaCausalLM(VlmCausalLM):
|
|||
block_tables=block_tables,
|
||||
input_lengths=batch.input_lengths,
|
||||
cache_lengths=batch.cache_lengths,
|
||||
input_lengths_tensor=batch.input_lengths_tensor,
|
||||
cache_lengths_tensor=batch.cache_lengths_tensor,
|
||||
max_current_length=batch.max_current_length,
|
||||
)
|
||||
with self._forward_context(
|
||||
block_tables=block_tables,
|
||||
|
@ -338,6 +339,9 @@ class MllamaCausalLM(VlmCausalLM):
|
|||
block_tables=block_tables,
|
||||
input_lengths=batch.input_lengths,
|
||||
cache_lengths=batch.cache_lengths,
|
||||
input_lengths_tensor=batch.input_lengths_tensor,
|
||||
cache_lengths_tensor=batch.cache_lengths_tensor,
|
||||
max_current_length=batch.max_current_length,
|
||||
)
|
||||
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
||||
else:
|
||||
|
|
|
@ -11,12 +11,12 @@ from text_generation_server.pb import generate_pb2
|
|||
from text_generation_server.models.flash_causal_lm import (
|
||||
FlashCausalLMBatch,
|
||||
FlashCausalLM,
|
||||
block_tables_to_ragged,
|
||||
)
|
||||
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
|
||||
from text_generation_server.utils.log import log_master
|
||||
from transformers import AutoProcessor
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from text_generation_server.models.metadata_kernels import block_tables_to_ragged
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
@ -363,6 +363,9 @@ class VlmCausalLM(FlashCausalLM):
|
|||
block_tables=block_tables,
|
||||
input_lengths=batch.input_lengths,
|
||||
cache_lengths=batch.cache_lengths,
|
||||
input_lengths_tensor=batch.input_lengths_tensor,
|
||||
cache_lengths_tensor=batch.cache_lengths_tensor,
|
||||
max_current_length=batch.max_current_length,
|
||||
)
|
||||
with self._forward_context(
|
||||
block_tables=block_tables,
|
||||
|
@ -411,6 +414,9 @@ class VlmCausalLM(FlashCausalLM):
|
|||
block_tables=block_tables,
|
||||
input_lengths=batch.input_lengths,
|
||||
cache_lengths=batch.cache_lengths,
|
||||
input_lengths_tensor=batch.input_lengths_tensor,
|
||||
cache_lengths_tensor=batch.cache_lengths_tensor,
|
||||
max_current_length=batch.max_current_length,
|
||||
)
|
||||
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue