diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index b931671c..87e904f4 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -71,6 +71,14 @@ from text_generation_server.utils.import_utils import ( synchronize, get_free_memory, ) +from text_generation_server.models.metadata_kernels import ( + has_triton, + copy_next_input_ids_inplace, + block_tables_to_ragged, + block_tables_to_padded, + prepare_position_slot_ids, + slots_filtering, +) tracer = trace.get_tracer(__name__) @@ -147,8 +155,10 @@ class FlashCausalLMBatch(Batch): # tensor of size [b, max_total_seqlen // block_size] holding the paged attention block tables for all sequences block_tables_tensor: torch.Tensor # tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences - # Will be set by `generate_token` and reset after each prefill forward before staying set in decode - slots: Optional[torch.Tensor] + slots: torch.Tensor + # list of length b + 1 containing the cumulative sequence slot lengths of the sequences in the batch + # used for filtering + cu_slots: torch.Tensor max_input_length: int max_current_length: int @@ -159,7 +169,7 @@ class FlashCausalLMBatch(Batch): prefilling_mask: List[bool] # Prefill metadata tensors to efficiently compute logprobs - # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill + # tensor of length b + 1 containing the cumulative sequence lengths of the sequences in the batch, only used in prefill cu_seqlen_prefill: Optional[torch.Tensor] # Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers # as we only keep SLIDING_WINDOW values instead of the whole tensor @@ -257,6 +267,8 @@ class FlashCausalLMBatch(Batch): all_input_ids = [] all_postfix_ids = [] requests_idx_mapping = {} + slots = [] + cu_slots = [0] next_token_chooser_parameters = [] stopping_criterias = [] @@ -268,7 +280,9 @@ class FlashCausalLMBatch(Batch): max_length = 0 max_blocks = 0 + cu_blocks = [0] block_tables = [] + block_tables_ragged = [] # Parse batch for i, (r, tokenized_input) in enumerate( @@ -341,10 +355,21 @@ class FlashCausalLMBatch(Batch): request_blocks = [ b for b in range(num_blocks, num_blocks + needed_blocks) ] + request_slots = [ + s + for b in request_blocks + for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) + ] else: request_blocks = r.blocks + request_slots = r.slots block_tables.append(request_blocks) + block_tables_ragged.extend(request_blocks) + cu_blocks.append(len(block_tables_ragged)) + + slots.extend(request_slots) + cu_slots.append(len(slots)) cache_lengths.append(cache_length) num_blocks += len(request_blocks) @@ -378,16 +403,34 @@ class FlashCausalLMBatch(Batch): top_n_tokens, device=device, dtype=torch.int64 ) - block_tables_tensor = torch.zeros( - (len(block_tables), max_blocks), dtype=torch.int32, device="cpu" + block_tables_ragged = torch.tensor( + block_tables_ragged, device=device, dtype=torch.int32 ) - for i, request_blocks in enumerate(block_tables): - block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks) - block_tables_tensor = block_tables_tensor.to(device) + cu_blocks = torch.tensor(cu_blocks, device=device, dtype=torch.int64) + block_tables_tensor = torch.empty( + (len(block_tables), max_blocks), + device=device, + dtype=torch.int32, + ) + + # If the device supports Triton, we can use a fused kernel + if has_triton(): + block_tables_to_padded( + max_blocks, cu_blocks, block_tables_tensor, block_tables_ragged + ) + else: + for i, request_blocks in enumerate(block_tables): + block_tables_tensor[i, : len(request_blocks)] = torch.tensor( + request_blocks + ) + prompt_lengths_tensor = torch.tensor( prompt_lengths, dtype=torch.int32, device=device ) + slots = torch.tensor(slots, dtype=torch.int64, device=device) + cu_slots = torch.tensor(cu_slots, dtype=torch.int64) + return cls( batch_id=pb.id, requests=pb.requests, @@ -420,7 +463,8 @@ class FlashCausalLMBatch(Batch): cu_seqlen_prefill=None, prefill_cache_indices=None, slot_indices=None, - slots=None, + slots=slots, + cu_slots=cu_slots, prefill_head_indices=None, prefill_next_token_indices=None, prefill_cu_outlens=None, @@ -457,10 +501,11 @@ class FlashCausalLMBatch(Batch): # Used to index into tensors indices = [] - # slots to keep after filtering - slot_filtering_indices = torch.zeros( - self.slots.shape[0], dtype=torch.bool, device=device - ) + if not has_triton(): + # slots to keep after filtering + slot_filtering_indices = torch.zeros( + self.slots.shape[0], dtype=torch.bool, device=device + ) # Create on CPU to only move to GPU once instead of at every copy slot_indices = torch.empty(len(request_ids), dtype=torch.int64) @@ -477,6 +522,7 @@ class FlashCausalLMBatch(Batch): cache_lengths = [] prefix_offsets = [] read_offsets = [] + cu_slots = [0] prefilling_mask = [] prefill_logprob_tokens = [] @@ -487,8 +533,8 @@ class FlashCausalLMBatch(Batch): num_blocks = 0 max_blocks = 0 - # Cumulative length - cumulative_max_length = 0 + max_slots = 0 + cumulative_slot_tokens = 0 for i, request_id in enumerate(request_ids): idx = self.requests_idx_mapping[request_id] @@ -531,29 +577,27 @@ class FlashCausalLMBatch(Batch): num_blocks += len(request_block_table) block_tables.append(request_block_table) + start_slot = self.cu_slots[idx] + end_slot = self.cu_slots[idx + 1] + slot_length = end_slot - start_slot + + if not has_triton(): + # Set slice + slot_filtering_indices[start_slot:end_slot] = True + + cu_slots.append(cumulative_slot_tokens + slot_length) + # Input ids if the request was part of a prefilling batch # If the batch was decoding we can index into the tensor directly later if self.prefilling: input_ids.append(self.input_ids[idx]) else: # Copy to tensor (CPU) - slot_indices[i] = cumulative_max_length - - remaining_tokens = ( - stopping_criteria.max_new_tokens - stopping_criteria.current_tokens - ) - - # Set slice - slot_filtering_indices[ - self.slot_indices[idx] : self.slot_indices[idx] - + request_input_length - + remaining_tokens - - 1 - ] = True - - cumulative_max_length += request_input_length + remaining_tokens - 1 + slot_indices[i] = cumulative_slot_tokens + request_cache_length + cumulative_slot_tokens += slot_length max_blocks = max(max_blocks, len(request_block_table)) + max_slots = max(max_slots, slot_length) all_input_ids_tensor = self.all_input_ids_tensor[indices] block_tables_tensor = self.block_tables_tensor[indices] @@ -564,11 +608,22 @@ class FlashCausalLMBatch(Batch): ) prompt_lengths_tensor = self.prompt_lengths_tensor[indices] + cu_slots = torch.tensor(cu_slots, dtype=torch.int64) + + if not has_triton(): + slots = self.slots[slot_filtering_indices] + else: + slots = self.slots.new_empty(cumulative_slot_tokens) + gpu_cu_slots = cu_slots.to(device) + slots_indexing_start = self.cu_slots.to(device)[indices] + slots_filtering( + max_slots, self.slots, slots, gpu_cu_slots, slots_indexing_start + ) + if self.prefilling: # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` position_ids = None slot_indices = None - slots = None cache_lengths_tensor = None input_lengths_tensor = None adapter_meta = None @@ -578,7 +633,6 @@ class FlashCausalLMBatch(Batch): position_ids = self.position_ids[indices] adapter_indices = self.adapter_meta.adapter_indices[indices] input_lengths_tensor = self.input_lengths_tensor[indices] - slots = self.slots[slot_filtering_indices] cache_lengths_tensor = self.cache_lengths_tensor[indices] # Move to GPU now that we have the whole tensor @@ -607,6 +661,7 @@ class FlashCausalLMBatch(Batch): block_tables=block_tables, block_tables_tensor=block_tables_tensor, slots=slots, + cu_slots=cu_slots, max_input_length=max_input_length, max_current_length=max_current_length, prefilling=self.prefilling, @@ -653,9 +708,7 @@ class FlashCausalLMBatch(Batch): for b in batches: total_batch_size += len(b) max_blocks = max(max_blocks, b.max_blocks) - # If `b` is prefilling and was just filtered, `b.slots` is None - # `total_slots` is not used if any of the batches is prefilling - total_slots += len(b.slots) if not b.prefilling else 0 + total_slots += len(b.slots) num_blocks += b.num_blocks speculative_length = ( b.speculative_ids.shape[1] if b.speculative_ids is not None else 0 @@ -675,11 +728,12 @@ class FlashCausalLMBatch(Batch): ) prefilling = prefilling or b.prefilling + slots = batches[0].slots.new_empty(total_slots) + cu_slots = torch.zeros(total_batch_size + 1, dtype=torch.int64) if prefilling: input_ids = [] # These values will be set by `FlashCausalLMBatch.prepare_for_prefill` position_ids = None - slots = None slot_indices = None cache_lengths_tensor = None input_lengths_tensor = None @@ -688,7 +742,6 @@ class FlashCausalLMBatch(Batch): else: input_ids = batches[0].input_ids.new_empty(total_batch_size) position_ids = batches[0].position_ids.new_empty(total_batch_size) - slots = batches[0].slots.new_empty(total_slots) slot_indices = batches[0].slot_indices.new_empty(total_batch_size) input_lengths_tensor = batches[0].input_lengths_tensor.new_empty( total_batch_size @@ -764,13 +817,16 @@ class FlashCausalLMBatch(Batch): ] = batch.block_tables_tensor[:, :max_blocks] prompt_lengths_tensor[start_index:end_index] = batch.prompt_lengths_tensor - if not prefilling: - slots_start_index = cumulative_slots - slots_end_index = cumulative_slots + len(batch.slots) + slots_start_index = cumulative_slots + slots_end_index = cumulative_slots + len(batch.slots) + slots[slots_start_index:slots_end_index] = batch.slots + cu_slots[start_index + 1 : end_index + 1] = ( + batch.cu_slots[1:] + cumulative_slots + ) + if not prefilling: input_ids[start_index:end_index] = batch.input_ids position_ids[start_index:end_index] = batch.position_ids - slots[slots_start_index:slots_end_index] = batch.slots slot_indices[start_index:end_index] = ( batch.slot_indices + cumulative_slots ) @@ -792,9 +848,6 @@ class FlashCausalLMBatch(Batch): batch.adapter_meta.adapter_segments, batch.adapter_meta.segment_indices, ) - - # Update - cumulative_slots += len(batch.slots) else: if isinstance(batch.input_ids, torch.Tensor): batch.input_ids = batch.input_ids.view(-1, 1).tolist() @@ -819,6 +872,7 @@ class FlashCausalLMBatch(Batch): top_n_tokens.extend(batch.top_n_tokens) # Update + cumulative_slots += len(batch.slots) cumulative_batch_size += len(batch) next_token_chooser = HeterogeneousNextTokenChooser.from_pb( @@ -858,6 +912,7 @@ class FlashCausalLMBatch(Batch): cache_lengths=cache_lengths, cache_lengths_tensor=cache_lengths_tensor, slots=slots, + cu_slots=cu_slots, max_input_length=max_input_length, max_current_length=max_current_length, prefilling=prefilling, @@ -890,15 +945,50 @@ class FlashCausalLMBatch(Batch): # it simplifies everything assert self.speculative_ids is None + device = self.block_tables_tensor.device + + if isinstance(self.input_ids, list): + if len(self) > 1: + input_ids = np.concatenate(self.input_ids, dtype=np.int64) + else: + input_ids = self.input_ids[0] + self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + + self.input_lengths_tensor = torch.tensor( + self.input_lengths, dtype=torch.int32, device=device + ) + self.cu_seqlen_prefill = torch.nn.functional.pad( + torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0) + ).to(torch.int32) + self.cache_lengths_tensor = torch.tensor( + self.cache_lengths, dtype=torch.int32, device=device + ) + + # If the device supports Triton, we can use a fused kernel + if has_triton(): + self.position_ids = torch.empty( + len(self.input_ids), dtype=torch.int32, device=device + ) + self.slot_indices = torch.empty( + len(self.input_ids), dtype=torch.int64, device=device + ) + cu_slots_gpu = self.cu_slots.to(device) + + prepare_position_slot_ids( + self.max_input_length, + self.cache_lengths_tensor, + self.cu_seqlen_prefill, + cu_slots_gpu, + self.position_ids, + self.slot_indices, + ) + sliding_window = get_sliding_windows() position_ids = [] - cu_seqlen_prefill = [0] slot_indices = [] prefill_cache_indices = [] all_prefill_logprobs = True no_prefill_logprobs = True - prefill_head_indices = [] - prefill_next_token_indices = [] prefill_cu_outlens = [0] # Cumulative length @@ -906,7 +996,6 @@ class FlashCausalLMBatch(Batch): cumulative_slot_tokens = 0 prefill_out_cumulative_length = 0 - slots = [] adapter_indices_list = [] adapter_set = set() @@ -928,30 +1017,33 @@ class FlashCausalLMBatch(Batch): ) ): next_chunk_length = input_length - # Position ids - request_position_ids = torch.arange( - cache_length, cache_length + input_length, dtype=torch.int32 - ) - position_ids.append(request_position_ids) - # Add cumulative lengths of all previous inputs - cu_seqlen_prefill.append(cumulative_length + input_length) + if not has_triton(): + # Position ids + request_position_ids = torch.arange( + cache_length, cache_length + input_length, dtype=torch.int32 + ) + position_ids.append(request_position_ids) - if not r.slots: - request_slots = [ - s - for b in blocks - for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) - ] - else: - request_slots = r.slots + if not r.slots: + request_slots = [ + s + for b in blocks + for s in range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE) + ] + else: + request_slots = r.slots - request_slots = request_slots[cache_length:] - request_slot_indices = torch.arange( - cumulative_slot_tokens, - cumulative_slot_tokens + input_length, - dtype=torch.int64, - ) + request_slot_indices = torch.arange( + cache_length + cumulative_slot_tokens, + cache_length + cumulative_slot_tokens + input_length, + dtype=torch.int64, + ) + + slot_indices.append(request_slot_indices) + + # Update + cumulative_slot_tokens += len(request_slots) # Create tensor to slice into the kv tensor in prefill if sliding_window is not None: @@ -968,83 +1060,102 @@ class FlashCausalLMBatch(Batch): no_prefill_logprobs = no_prefill_logprobs and not prefill_logprobs if prefill_logprobs: - prefill_head_indices.append( - torch.arange( - cumulative_length, - cumulative_length + input_length, - dtype=torch.int64, - ) - ) - prefill_next_token_indices.append( - prefill_out_cumulative_length + input_length - 1 - ) prefill_cu_outlens.append(prefill_out_cumulative_length + input_length) prefill_out_cumulative_length += input_length else: - prefill_head_indices.append( - torch.tensor( - [cumulative_length + input_length - 1], - dtype=torch.int64, - ) - ) - prefill_next_token_indices.append(prefill_out_cumulative_length) prefill_cu_outlens.append(prefill_out_cumulative_length + 1) prefill_out_cumulative_length += 1 - slots.extend(request_slots) - slot_indices.append(request_slot_indices) - if sliding_window is not None: prefill_cache_indices.append(request_prefill_cache_indices) ADAPTER_TO_INDEX = get_adapter_to_index() - adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) - adapter_indices_list.append(torch.full((next_chunk_length,), adapter_index)) - adapter_set.add(adapter_index) + if ADAPTER_TO_INDEX: + adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0) + adapter_indices_list.append( + torch.full((next_chunk_length,), adapter_index) + ) + adapter_set.add(adapter_index) # Update cumulative_length += next_chunk_length - cumulative_slot_tokens += len(request_slots) - device = self.block_tables_tensor.device + if not all_prefill_logprobs and not no_prefill_logprobs: + prefill_head_indices = [] + prefill_next_token_indices = [] - if isinstance(self.input_ids, list): - if len(self) > 1: - input_ids = np.concatenate(self.input_ids, dtype=np.int64) - else: - input_ids = self.input_ids[0] - self.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + # Cumulative length + cumulative_length = 0 + prefill_out_cumulative_length = 0 + + for i, ( + r, + input_length, + request_prefilling, + ) in enumerate( + zip( + self.requests, + self.input_lengths, + self.prefilling_mask, + ) + ): + # Prefill logprobs is ignored if the request is done prefilling + prefill_logprobs = r.prefill_logprobs and request_prefilling + + if prefill_logprobs: + prefill_head_indices.append( + torch.arange( + cumulative_length, + cumulative_length + input_length, + dtype=torch.int64, + ) + ) + prefill_next_token_indices.append( + prefill_out_cumulative_length + input_length - 1 + ) + prefill_out_cumulative_length += input_length + else: + prefill_head_indices.append( + torch.tensor( + [cumulative_length + input_length - 1], + dtype=torch.int64, + ) + ) + prefill_next_token_indices.append(prefill_out_cumulative_length) + prefill_out_cumulative_length += 1 + + # Update + cumulative_length += input_length if len(self) > 1: - position_ids = torch.cat(position_ids) - slot_indices = torch.cat(slot_indices) + if position_ids: + position_ids = torch.cat(position_ids) + if slot_indices: + slot_indices = torch.cat(slot_indices) if sliding_window is not None: prefill_cache_indices = torch.cat(prefill_cache_indices) else: - position_ids = position_ids[0] - slot_indices = slot_indices[0] + if position_ids: + position_ids = position_ids[0] + if slot_indices: + slot_indices = slot_indices[0] if sliding_window is not None: prefill_cache_indices = prefill_cache_indices[0] + if not has_triton(): + self.position_ids = position_ids.to(device) + self.slot_indices = slot_indices.to(device) + self.prefill_cu_outlens = prefill_cu_outlens - cu_seqlen_prefill = torch.tensor( - cu_seqlen_prefill, device=device, dtype=torch.int32 - ) - self.cu_seqlen_prefill = cu_seqlen_prefill - self.position_ids = position_ids.to(device) - self.slot_indices = slot_indices.to(device) self.prefill_cache_indices = ( prefill_cache_indices.to(device) if sliding_window is not None else None ) - self.input_lengths_tensor = torch.tensor( - self.input_lengths, dtype=torch.int32, device=device - ) if all_prefill_logprobs: prefill_head_indices = None - prefill_next_token_indices = cu_seqlen_prefill[1:] - 1 + prefill_next_token_indices = self.cu_seqlen_prefill[1:] - 1 elif no_prefill_logprobs: - prefill_head_indices = cu_seqlen_prefill[1:] - 1 + prefill_head_indices = self.cu_seqlen_prefill[1:] - 1 prefill_next_token_indices = None else: prefill_head_indices = torch.cat(prefill_head_indices).to(device) @@ -1054,17 +1165,21 @@ class FlashCausalLMBatch(Batch): self.prefill_head_indices = prefill_head_indices self.prefill_next_token_indices = prefill_next_token_indices - self.slots = torch.tensor(slots, dtype=torch.int64, device=device) - self.cache_lengths_tensor = torch.tensor( - self.cache_lengths, dtype=torch.int32, device=device - ) - adapter_indices = torch.cat(adapter_indices_list).to( - dtype=torch.int64, device=device - ) - adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + + if adapter_set: + adapter_indices = torch.cat(adapter_indices_list).to( + dtype=torch.int64, device=device + ) + adapter_segments, adapter_segment_indices = find_segments(adapter_indices) + else: + adapter_indices = torch.zeros_like(self.input_ids) + adapter_segments = [0, len(adapter_indices)] + adapter_segment_indices = [len(adapter_indices) - 1] + adapter_segments = torch.tensor( adapter_segments, dtype=torch.int32, device=device ) + self.adapter_meta = AdapterBatchMetadata( adapter_indices=adapter_indices, adapter_set=adapter_set, @@ -1288,6 +1403,9 @@ class FlashCausalLM(Model): block_tables=block_tables, input_lengths=input_lengths, cache_lengths=cache_lengths, + input_lengths_tensor=input_lengths_tensor, + cache_lengths_tensor=cache_lengths_tensor, + max_current_length=max_s, ) from text_generation_server.layers.attention.flashinfer import ( create_decode_state_cuda_graphs, @@ -1621,6 +1739,9 @@ class FlashCausalLM(Model): block_tables=block_tables, input_lengths=batch.input_lengths, cache_lengths=batch.cache_lengths, + input_lengths_tensor=batch.input_lengths_tensor, + cache_lengths_tensor=batch.cache_lengths_tensor, + max_current_length=batch.max_current_length, ) with self._forward_context( block_tables=block_tables, @@ -1661,6 +1782,9 @@ class FlashCausalLM(Model): block_tables=block_tables, input_lengths=batch.input_lengths, cache_lengths=batch.cache_lengths, + input_lengths_tensor=batch.input_lengths_tensor, + cache_lengths_tensor=batch.cache_lengths_tensor, + max_current_length=batch.max_current_length, ) # assert block_tables.shape[0] >= slots.shape[0] cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables @@ -1756,7 +1880,6 @@ class FlashCausalLM(Model): else: prefill_logprobs = None next_token_logits = out - next_adapter_indices = batch.adapter_meta.adapter_indices finished_prefilling = True next_chunk_lengths = [] @@ -1827,13 +1950,12 @@ class FlashCausalLM(Model): # Since we are done prefilling, all the tensors that were concatenating values for all the requests # instantly become of shape [BATCH_SIZE] if prefill and finished_prefilling: - next_position_ids = batch.position_ids.new_empty(len(batch)) - batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1] - next_adapter_indices = batch.adapter_meta.adapter_indices.new_empty( - len(batch) - ) - elif not prefill: - next_position_ids = batch.position_ids + indices = batch.cu_seqlen_prefill[1:] - 1 + batch.position_ids = batch.position_ids[indices] + batch.slot_indices = batch.slot_indices[indices] + batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[ + indices + ] # Zipped iterator iterator = zip( @@ -1852,8 +1974,10 @@ class FlashCausalLM(Model): # It is faster if we delay this sync for the maximum amount of time # For each member of the batch - index = 0 # Cumulative length + cu_accepted_ids = torch.nn.functional.pad( + torch.cumsum(accepted_ids, dim=0), (1, 0) + ) cumulative_length = 0 for i, ( request, @@ -1865,21 +1989,6 @@ class FlashCausalLM(Model): request_was_prefilling, request_is_prefilling, ) in enumerate(iterator): - if prefill and finished_prefilling: - # Indexing metadata - _start_index = cumulative_length - end_index = cumulative_length + input_length - - # Initialize position_ids - # In decode, we do not need this as we can just increment position ids - next_position_ids[i] = batch.position_ids[end_index - 1] - - # Initialize adapter indices - # In decode, we only have one token per row in the batch, so grab last index - next_adapter_indices[i] = batch.adapter_meta.adapter_indices[ - end_index - 1 - ] - # Used to gather prefill logprobs # Copy batch.all_input_ids_tensor to prefill_token_indices if request.prefill_logprobs and request_was_prefilling: @@ -1898,25 +2007,39 @@ class FlashCausalLM(Model): # Set prefill_tokens_indices to the correct slice prefill_tokens_indices = ids - if not request_is_prefilling: + # If the device does not support triton, we copy one by one + if not request_is_prefilling and not has_triton(): # Only save tokens if we are done prefilling for this request - for j in range(n_accepted_ids): - batch.all_input_ids_tensor[i, cache_length + input_length + j] = ( - next_input_ids[index + j] - ) - index += n_accepted_ids + batch.all_input_ids_tensor[ + i, + batch.cache_lengths_tensor[i] + + batch.input_lengths[i] : batch.cache_lengths_tensor[i] + + batch.input_lengths[i] + + accepted_ids[i], + ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]] cumulative_length += input_length + # If the device support triton, we can use a fused kernel + if has_triton(): + copy_next_input_ids_inplace( + speculate + 1, + batch.all_input_ids_tensor, + batch.cache_lengths_tensor, + batch.input_lengths_tensor, + batch.prompt_lengths_tensor, + next_input_ids, + cu_accepted_ids, + ) + # Update values # These values can be updated without a GPU -> CPU sync if not prefill or (prefill and finished_prefilling): - batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1] + batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1] batch.speculative_ids = speculative_ids - batch.position_ids = next_position_ids + accepted_ids - batch.cache_lengths_tensor += batch.input_lengths_tensor - batch.input_lengths_tensor = accepted_ids.to(dtype=torch.int32) + batch.position_ids += accepted_ids + batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1 + batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor) batch.slot_indices += accepted_ids - batch.adapter_meta.adapter_indices = next_adapter_indices if prefill and prefill_logprobs: # Get prefill logprobs with inplace softmax (avoid copying the `out` tensor (max_batch_prefill_tokens * vocab_size)) @@ -2093,8 +2216,10 @@ class FlashCausalLM(Model): # processing stopped = False new_input_length = next_chunk_lengths[i] + new_cache_length = cache_length + input_length else: - new_input_length = n_accepted_ids + new_input_length = 1 + new_cache_length = cache_length + input_length + n_accepted_ids - 1 # Append next token to all tokens next_token_texts = [] left = 0 @@ -2206,12 +2331,10 @@ class FlashCausalLM(Model): # Update values index += n_accepted_ids - current_cache_length = cache_length + input_length - batch.cache_lengths[i] = current_cache_length - current_input_length = new_input_length - batch.max_input_length = max(batch.max_input_length, current_input_length) - batch.input_lengths[i] = current_input_length - current_length = current_cache_length + current_input_length + batch.cache_lengths[i] = new_cache_length + batch.max_input_length = max(batch.max_input_length, new_input_length) + batch.input_lengths[i] = new_input_length + current_length = new_cache_length + new_input_length batch.max_current_length = max(batch.max_current_length, current_length) batch.prefix_offsets[i] = prefix_offset @@ -2258,11 +2381,6 @@ class FlashCausalLM(Model): state=( state if state is not None else self.prefill_with_paged_kv_state ), - # block_tables=block_tables_to_ragged( - # block_tables=block_tables, - # input_lengths=input_lengths, - # cache_lengths=cache_lengths, - # ), block_tables=block_tables, cu_seqlens=cu_seqlen_prefill, input_lengths=input_lengths_tensor + cache_lengths_tensor, @@ -2287,23 +2405,3 @@ class FlashCausalLM(Model): dtype=self.dtype, window_left=self.sliding_window, ) - - -def block_tables_to_ragged( - *, block_tables: torch.Tensor, input_lengths: List[int], cache_lengths: List[int] -) -> torch.Tensor: - """Convert block table to ragged format compatible with FlashInfer.""" - assert len(input_lengths) == len(cache_lengths) - - total_len = sum(input_lengths) + sum(cache_lengths) - block_tables_ragged = torch.empty( - total_len, dtype=torch.int32, device=block_tables.device - ) - - offset = 0 - for i, (input_length, cache_length) in enumerate(zip(input_lengths, cache_lengths)): - seq_len = cache_length + input_length - block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] - offset += seq_len - - return block_tables_ragged diff --git a/server/text_generation_server/models/metadata_kernels.py b/server/text_generation_server/models/metadata_kernels.py new file mode 100644 index 00000000..b3e2160d --- /dev/null +++ b/server/text_generation_server/models/metadata_kernels.py @@ -0,0 +1,347 @@ +import torch +import triton + +import triton.language as tl + +from loguru import logger +from typing import List, Optional +from torch.utils._triton import has_triton as has_triton_torch + +from text_generation_server.utils.import_utils import ( + SYSTEM, +) +from text_generation_server.utils.log import log_master + +_HAS_TRITON: Optional[bool] = None + + +def has_triton(): + global _HAS_TRITON + if _HAS_TRITON is None: + # FIXME: it seems that has_triton_torch is bugged on RocM + # For now, only accept cuda + _HAS_TRITON = has_triton_torch() if SYSTEM == "cuda" else False + if _HAS_TRITON: + log_master(logger.info, "Using optimized Triton indexing kernels.") + + return _HAS_TRITON + + +def block_tables_to_padded( + max_blocks: int, + cu_seqlen: torch.Tensor, + block_tables: torch.Tensor, + block_tables_ragged: torch.Tensor, +): + def grid(meta): + return ( + triton.cdiv(max_blocks, meta["BLOCK_SIZE"]), + len(block_tables), + ) + + triton_block_tables_to_padded[grid]( + cu_seqlen, + block_tables, + block_tables_ragged, + block_tables.shape[1], + BLOCK_SIZE=256, + ) + + +def block_tables_to_ragged( + *, + block_tables: torch.Tensor, + input_lengths: List[int], + cache_lengths: List[int], + input_lengths_tensor: torch.Tensor, + cache_lengths_tensor: torch.Tensor, + max_current_length: int +) -> torch.Tensor: + """Convert block table to ragged format compatible with FlashInfer.""" + assert len(input_lengths) == len(cache_lengths) + + total_len = sum(input_lengths) + sum(cache_lengths) + block_tables_ragged = torch.empty( + total_len, dtype=torch.int32, device=block_tables.device + ) + + if has_triton(): + cu_seqlen = torch.nn.functional.pad( + torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0) + ) + + def grid(meta): + return ( + triton.cdiv(max_current_length, meta["BLOCK_SIZE"]), + len(cache_lengths), + ) + + triton_block_tables_to_ragged[grid]( + cu_seqlen, + block_tables, + block_tables_ragged, + block_tables.shape[1], + BLOCK_SIZE=256, + ) + else: + offset = 0 + for i, (input_length, cache_length) in enumerate( + zip(input_lengths, cache_lengths) + ): + seq_len = cache_length + input_length + block_tables_ragged[offset : offset + seq_len] = block_tables[i][:seq_len] + offset += seq_len + + return block_tables_ragged + + +def copy_next_input_ids_inplace( + max_next_input_ids: int, + all_input_ids: torch.Tensor, + cache_lengths: torch.Tensor, + input_lengths: torch.Tensor, + prompt_lengths: torch.Tensor, + next_input_ids: torch.Tensor, + cu_accepted_ids: torch.Tensor, +): + def grid(meta): + return ( + triton.cdiv(max_next_input_ids, meta["BLOCK_SIZE"]), + len(all_input_ids), + ) + + triton_copy_next_input_ids_inplace[grid]( + all_input_ids, + cache_lengths, + input_lengths, + prompt_lengths, + next_input_ids, + cu_accepted_ids, + all_input_ids.shape[1], + BLOCK_SIZE=16, + ) + + +def prepare_position_slot_ids( + max_input_length: int, + cache_lengths: torch.Tensor, + cu_seqlen: torch.Tensor, + cu_slots: torch.Tensor, + position_ids: torch.Tensor, + slot_indices: torch.Tensor, +): + def grid(meta): + return ( + triton.cdiv(max_input_length, meta["BLOCK_SIZE"]), + len(cache_lengths), + ) + + triton_prepare_position_slot_ids[grid]( + cache_lengths, cu_seqlen, cu_slots, position_ids, slot_indices, BLOCK_SIZE=256 + ) + + +def slots_filtering( + max_slots: int, + slots: torch.Tensor, + filtered_slots: torch.Tensor, + cu_slots: torch.Tensor, + slots_start: torch.Tensor, +): + def grid(meta): + return ( + triton.cdiv(max_slots, meta["BLOCK_SIZE"]), + len(slots_start), + ) + + triton_slots_filtering[grid]( + slots, filtered_slots, slots_start, cu_slots, BLOCK_SIZE=256 + ) + + +@triton.jit +def triton_slots_filtering( + # Inputs + slots_ptr, + filtered_slots_ptr, + slots_start_ptr, + cu_slots_ptr, + # Const values + BLOCK_SIZE: "tl.constexpr", +): + # Position in block_tables_ragged.numel() / BLOCK_SIZE + pid = tl.program_id(axis=0) + # Position in batch + bid = tl.program_id(axis=1) + + block_start = pid * BLOCK_SIZE + block_arange = block_start + tl.arange(0, BLOCK_SIZE) + + filter_start = tl.load(slots_start_ptr + bid) + + slot_start = tl.load(cu_slots_ptr + bid) + slot_end = tl.load(cu_slots_ptr + bid + 1) + + mask = (slot_start + block_arange) < slot_end + + slots = tl.load(slots_ptr + filter_start + block_arange, mask=mask) + tl.store(filtered_slots_ptr + slot_start + block_arange, slots, mask=mask) + + +@triton.jit +def triton_block_tables_to_padded( + # Inputs + cu_seqlen_ptr, + # Outputs + block_tables_ptr, + block_tables_ragged_ptr, + # Stride + stride_block_tables, + # Const values + BLOCK_SIZE: "tl.constexpr", +): + # Position in block_tables_ragged.numel() / BLOCK_SIZE + pid = tl.program_id(axis=0) + # Position in batch + bid = tl.program_id(axis=1) + + block_start = pid * BLOCK_SIZE + block_arange = block_start + tl.arange(0, BLOCK_SIZE) + + seq_start = tl.load(cu_seqlen_ptr + bid) + seq_end = tl.load(cu_seqlen_ptr + bid + 1) + + mask = (seq_start + block_arange) < seq_end + + blocks = tl.load(block_tables_ragged_ptr + seq_start + block_arange, mask=mask) + tl.store( + block_tables_ptr + bid * stride_block_tables + block_arange, blocks, mask=mask + ) + + +@triton.jit +def triton_block_tables_to_ragged( + # Inputs + cu_seqlen_ptr, + # Outputs + block_tables_ptr, + block_tables_ragged_ptr, + # Stride + stride_block_tables, + # Const values + BLOCK_SIZE: "tl.constexpr", +): + # Position in block_tables_ragged.numel() / BLOCK_SIZE + pid = tl.program_id(axis=0) + # Position in batch + bid = tl.program_id(axis=1) + + block_start = pid * BLOCK_SIZE + block_arange = block_start + tl.arange(0, BLOCK_SIZE) + + seq_start = tl.load(cu_seqlen_ptr + bid) + seq_end = tl.load(cu_seqlen_ptr + bid + 1) + + mask = (seq_start + block_arange) < seq_end + + blocks = tl.load( + block_tables_ptr + bid * stride_block_tables + block_arange, mask=mask + ) + tl.store(block_tables_ragged_ptr + seq_start + block_arange, blocks, mask=mask) + + +@triton.jit +def triton_copy_next_input_ids_inplace( + # Inputs + all_input_ids_ptr, + cache_lengths_ptr, + input_lengths_ptr, + prompt_lengths_ptr, + next_input_ids_ptr, + cu_accepted_ids_ptr, + # Stride + stride_all_input_ids, + # Const values + BLOCK_SIZE: "tl.constexpr", +): + # Position in max_accepted_ids / BLOCK_SIZE + pid = tl.program_id(axis=0) + # Position in batch + bid = tl.program_id(axis=1) + + block_start = pid * BLOCK_SIZE + block_arange = block_start + tl.arange(0, BLOCK_SIZE) + + # Used for correctly indexing in all_input_ids + cache_length = tl.load(cache_lengths_ptr + bid) + input_length = tl.load(input_lengths_ptr + bid) + prompt_length = tl.load(prompt_lengths_ptr + bid) + + # Start/End of next_input_ids for this request + next_input_ids_start = tl.load(cu_accepted_ids_ptr + bid) + next_input_ids_end = tl.load(cu_accepted_ids_ptr + bid + 1) + + # Mask values out of range + mask = (next_input_ids_start + block_arange) < next_input_ids_end + + # Mask values for request still prefilling + decode_mask = (cache_length + input_length + block_arange) >= prompt_length + + mask = mask & decode_mask + + # Load this request next input ids + next_input_ids = tl.load( + next_input_ids_ptr + next_input_ids_start + block_arange, mask=mask + ) + + # Store in all_input_ids, since it is a 2D tensor, apply stride * bid + tl.store( + all_input_ids_ptr + + stride_all_input_ids * bid + + cache_length + + input_length + + block_arange, + next_input_ids, + mask=mask, + ) + + +@triton.jit +def triton_prepare_position_slot_ids( + # Inputs + cache_lengths_ptr, + cu_seqlen_ptr, + cu_slots_ptr, + # Outputs + position_ids_ptr, + slot_indices_ptr, + # Const values + BLOCK_SIZE: "tl.constexpr", +): + # Position in max_input_length / BLOCK_SIZE + pid = tl.program_id(axis=0) + # Position in batch + bid = tl.program_id(axis=1) + + block_start = pid * BLOCK_SIZE + block_arange = block_start + tl.arange(0, BLOCK_SIZE) + + cache_length = tl.load(cache_lengths_ptr + bid) + + seq_start = tl.load(cu_seqlen_ptr + bid) + seq_end = tl.load(cu_seqlen_ptr + bid + 1) + + slot_start = tl.load(cu_slots_ptr + bid) + + mask = (seq_start + block_arange) < seq_end + + tl.store( + position_ids_ptr + seq_start + block_arange, + cache_length + block_arange, + mask=mask, + ) + tl.store( + slot_indices_ptr + seq_start + block_arange, + slot_start + cache_length + block_arange, + mask=mask, + ) diff --git a/server/text_generation_server/models/mllama_causal_lm.py b/server/text_generation_server/models/mllama_causal_lm.py index 6399f92c..28e7489e 100644 --- a/server/text_generation_server/models/mllama_causal_lm.py +++ b/server/text_generation_server/models/mllama_causal_lm.py @@ -14,11 +14,9 @@ from transformers import ( from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM from text_generation_server.pb import generate_pb2 -from text_generation_server.models.flash_causal_lm import ( - block_tables_to_ragged, -) from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION from text_generation_server.layers.attention import Seqlen +from text_generation_server.models.metadata_kernels import block_tables_to_ragged tracer = trace.get_tracer(__name__) @@ -283,6 +281,9 @@ class MllamaCausalLM(VlmCausalLM): block_tables=block_tables, input_lengths=batch.input_lengths, cache_lengths=batch.cache_lengths, + input_lengths_tensor=batch.input_lengths_tensor, + cache_lengths_tensor=batch.cache_lengths_tensor, + max_current_length=batch.max_current_length, ) with self._forward_context( block_tables=block_tables, @@ -338,6 +339,9 @@ class MllamaCausalLM(VlmCausalLM): block_tables=block_tables, input_lengths=batch.input_lengths, cache_lengths=batch.cache_lengths, + input_lengths_tensor=batch.input_lengths_tensor, + cache_lengths_tensor=batch.cache_lengths_tensor, + max_current_length=batch.max_current_length, ) cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables else: diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 150cf0d0..4bbddcfb 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -11,12 +11,12 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.flash_causal_lm import ( FlashCausalLMBatch, FlashCausalLM, - block_tables_to_ragged, ) from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION from text_generation_server.utils.log import log_master from transformers import AutoProcessor from text_generation_server.layers.attention import Seqlen +from text_generation_server.models.metadata_kernels import block_tables_to_ragged tracer = trace.get_tracer(__name__) @@ -363,6 +363,9 @@ class VlmCausalLM(FlashCausalLM): block_tables=block_tables, input_lengths=batch.input_lengths, cache_lengths=batch.cache_lengths, + input_lengths_tensor=batch.input_lengths_tensor, + cache_lengths_tensor=batch.cache_lengths_tensor, + max_current_length=batch.max_current_length, ) with self._forward_context( block_tables=block_tables, @@ -411,6 +414,9 @@ class VlmCausalLM(FlashCausalLM): block_tables=block_tables, input_lengths=batch.input_lengths, cache_lengths=batch.cache_lengths, + input_lengths_tensor=batch.input_lengths_tensor, + cache_lengths_tensor=batch.cache_lengths_tensor, + max_current_length=batch.max_current_length, ) cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables else: