diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs
index f8a10ca2..ebb12768 100644
--- a/backends/v3/src/backend.rs
+++ b/backends/v3/src/backend.rs
@@ -380,6 +380,7 @@ fn filter_send_generations(generations: Vec, entries: &mut IntMap
             "dropped").increment(1);
         }).unwrap_or(true);
+        tracing::debug!("Stopped iteration {stopped:?}");
         if stopped {
             entries.remove(&id).expect("ID not found in entries. This is a bug.");
         }
@@ -419,6 +420,7 @@ fn send_responses(
     // Create last Token
     let tokens_ = generation.tokens.expect("Non empty tokens in generation");
     let n = tokens_.ids.len();
+    tracing::info!("Received {n:?} tokens");
     metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
     let mut iterator = tokens_
         .ids
@@ -435,6 +437,7 @@ fn send_responses(
            logprob,
            special,
        };
+        tracing::info!("Sent token {token:?}");
        let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
            top_tokens_
                .ids
diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs
index f0f21854..470527ba 100644
--- a/backends/v3/src/client/grpc_client.rs
+++ b/backends/v3/src/client/grpc_client.rs
@@ -120,7 +120,7 @@ impl Client {
        // Create requests
        while rest_tokens > 0 {
            let curr_tokens = min(max_tokens_per_request, rest_tokens);
-            let truncate = min(max_input_length, rest_tokens);
+            let truncate = max_input_length;
            let prefix_len = max_input_length.saturating_sub(max_prefill_tokens);
            let suffix_len = 0;
diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs
index ff9879f3..f44aaffa 100644
--- a/backends/v3/src/queue.rs
+++ b/backends/v3/src/queue.rs
@@ -281,7 +281,10 @@ impl State {
                if total_tokens > token_budget {
                    // Entry is over budget
                    // Add it back to the front
-                    tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                    tracing::debug!(
+                        "Over budget: {prefill_tokens} + {decode_tokens} + {} > {token_budget}",
+                        self.speculate
+                    );
                    self.entries.push_front((id, entry));
                    break 'entry_loop;
                }
@@ -301,7 +304,10 @@ impl State {
                if (prefill_tokens + decode_tokens + self.speculate) > token_budget {
                    // Entry is over budget
                    // Add it back to the front
-                    tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                    tracing::debug!(
+                        "Over budget: {prefill_tokens} + {decode_tokens} + {} > {token_budget}",
+                        self.speculate
+                    );
                    self.entries.push_front((id, entry));
                    break;
                }
@@ -390,7 +396,9 @@ impl State {
                    block_allocation.prefix_len,
                ),
            };
-            let suffix_len = (slots.len() as u32).saturating_sub(prefix_len);
+            let suffix_len = (slots.len() as u32)
+                .saturating_sub(prefix_len)
+                .saturating_sub(prefill_token_budget);

            entry.block_allocation = block_allocation;
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 3844a990..80206032 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -156,6 +156,7 @@ class FlashCausalLMBatch(Batch):

     # Prefixes
     prefix_ids: List[List[int]]
+    suffix_ids: List[List[int]]

     # All tokens
     all_input_ids: List[List[int]]
@@ -230,6 +231,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
         all_input_ids = []
         prefix_ids = []
+        suffix_ids = []
         requests_idx_mapping = {}

         all_prefill_logprobs = True
@@ -269,20 +271,29 @@ class FlashCausalLMBatch(Batch):
             orig_input_length = len(tokenized_input)

             prefix_len = r.prefix_len
-            import ipdb
-
-            ipdb.set_trace()
+            suffix_len = r.suffix_len
             assert (
                 prefix_len <= orig_input_length
             ), f"Prefix {prefix_len} vs input {orig_input_length}"
+            assert (
+                suffix_len <= orig_input_length
+            ), f"suffix {suffix_len} vs input {orig_input_length}"
             if prefix_len == orig_input_length:
                 assert prefix_len > 0
                 prefix_len -= 1
+            assert (
+                prefix_len + suffix_len < orig_input_length
+            ), f"Prefix+suffix are invalid {prefix_len} + {suffix_len} < {orig_input_length}"

             # Commented as it's costly.
             # log_master(logger.debug, "Tokenized input ids {tokenized_input}")
             prefix_ids.append(tokenized_input[:prefix_len])
             tokenized_input = tokenized_input[prefix_len:]
+            if suffix_len > 0:
+                suffix_ids.append(tokenized_input[-suffix_len:])
+                tokenized_input = tokenized_input[:-suffix_len]
+            else:
+                suffix_ids.append([])

             input_length = len(tokenized_input)
             input_lengths.append(input_length)
@@ -294,7 +305,7 @@ class FlashCausalLMBatch(Batch):

             # Position ids
             request_position_ids = torch.arange(
-                prefix_len, orig_input_length, dtype=torch.int32
+                prefix_len, orig_input_length - suffix_len, dtype=torch.int32
             )
             position_ids.append(request_position_ids)
@@ -390,7 +401,7 @@ class FlashCausalLMBatch(Batch):
             # Update
             cumulative_length += input_length
             cumulative_slot_tokens += slot_tokens
-            max_seqlen = max(max_seqlen, input_length)
+            max_seqlen = max(max_seqlen, orig_input_length)
             max_blocks = max(max_blocks, len(request_blocks))
             max_length = max(
                 max_length, input_length + max_new_tokens + speculative_length
@@ -501,6 +512,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             prefix_ids=prefix_ids,
+            suffix_ids=suffix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -558,6 +570,7 @@ class FlashCausalLMBatch(Batch):
         block_tables = []
         all_input_ids = []
         prefix_ids = []
+        suffix_ids = []

         input_lengths = []
         prefix_lens = []
@@ -587,6 +600,7 @@ class FlashCausalLMBatch(Batch):

             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])
+            suffix_ids.append(self.suffix_ids[idx])

             input_lengths.append(request_input_length)
             prefix_lens.append(prefix_len)
@@ -678,6 +692,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             prefix_ids=prefix_ids,
+            suffix_ids=suffix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -759,6 +774,7 @@ class FlashCausalLMBatch(Batch):
         prefix_lens = []
         all_input_ids = []
         prefix_ids = []
+        suffix_ids = []

         input_lengths = []
         prefix_offsets = []
@@ -828,6 +844,7 @@ class FlashCausalLMBatch(Batch):
             prefix_lens.extend(batch.prefix_lens)
             all_input_ids.extend(batch.all_input_ids)
             prefix_ids.extend(batch.prefix_ids)
+            suffix_ids.extend(batch.suffix_ids)

             input_lengths.extend(batch.input_lengths)
             prefix_offsets.extend(batch.prefix_offsets)
@@ -889,6 +906,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             prefix_ids=prefix_ids,
+            suffix_ids=suffix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -1453,6 +1471,8 @@ class FlashCausalLM(Model):

             input_ids = new_input_ids
             position_ids = new_position_ids
+            if any([len(suffix) for suffix in batch.suffix_ids]):
+                raise RuntimeError("Suffix not supported with medusa")
         else:
             input_ids = batch.input_ids
             position_ids = batch.position_ids
@@ -1747,6 +1767,7 @@ class FlashCausalLM(Model):
             batch.stopping_criterias,
             batch.all_input_ids,
             batch.prefix_ids,
+            batch.suffix_ids,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
@@ -1765,6 +1786,7 @@ class FlashCausalLM(Model):
             stopping_criteria,
             all_input_ids,
             prefix_ids,
+            suffix_ids,
             do_sample,
             seed,
             top_n_tokens,
@@ -1776,39 +1798,54 @@ class FlashCausalLM(Model):
             next_token_texts = []
             left = 0

-            if n_accepted_ids > 1:
-                log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
+            if not suffix_ids:
+                if n_accepted_ids > 1:
+                    log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")

-            current_stopped = False
-            for j in range(index, index + n_accepted_ids):
-                # Generated token
-                next_token_id = next_token_ids[j]
-                all_input_ids.append(next_token_id)
-                next_token_text, prefix_offset, read_offset = self.decode_token(
-                    all_input_ids,
-                    prefix_offset,
-                    read_offset,
-                )
-                next_token_texts.append(next_token_text)
+                current_stopped = False
+                for j in range(index, index + n_accepted_ids):
+                    # Generated token
+                    next_token_id = next_token_ids[j]
+                    all_input_ids.append(next_token_id)
+                    next_token_text, prefix_offset, read_offset = self.decode_token(
+                        all_input_ids,
+                        prefix_offset,
+                        read_offset,
+                    )
+                    next_token_texts.append(next_token_text)

-                stop, reason = stopping_criteria(
-                    next_token_id,
-                    next_token_text,
-                )
+                    stop, reason = stopping_criteria(
+                        next_token_id,
+                        next_token_text,
+                    )

-                if stop:
-                    left = index + n_accepted_ids - j - 1
-                    current_stopped = True
-                    break
-                else:
-                    current_stopped = False
-            stopped = stopped and current_stopped
+                    if stop:
+                        left = index + n_accepted_ids - j - 1
+                        current_stopped = True
+                        break
+                    else:
+                        current_stopped = False

-            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
-            _next_token_logprobs = next_token_logprobs[
-                index : index + n_accepted_ids - left
-            ]
-            index += n_accepted_ids
+                stopped = stopped and current_stopped
+
+                _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
+                _next_token_logprobs = next_token_logprobs[
+                    index : index + n_accepted_ids - left
+                ]
+                index += n_accepted_ids
+                req_prefill = prefill
+            else:
+                _input_ids = suffix_ids[: self.max_prefill_tokens]
+                n_accepted_ids = len(_input_ids)
+                # TODO fix logprobs.
+                _next_token_logprobs = [float("nan")] * n_accepted_ids
+                index += n_accepted_ids
+                suffix_ids = suffix_ids[n_accepted_ids:]
+                _next_token_ids = []
+                _next_token_logprobs = []
+                stop = False
+                reason = None
+                req_prefill = True

             # Shard generations
             # All generations will be appended in the rust sharded client
@@ -1834,7 +1871,7 @@ class FlashCausalLM(Model):
                     generated_text = None

                 # Prefill
-                if prefill and request.prefill_logprobs:
+                if req_prefill and request.prefill_logprobs:
                     out_start_index = batch.prefill_cu_outlens[i]
                     out_end_index = batch.prefill_cu_outlens[i + 1]
@@ -1896,6 +1933,7 @@ class FlashCausalLM(Model):
                     top_tokens,
                 )

+                logger.info(f"Generation {generation}")
                 generations.append(generation)

             # accept each new token for this specific request since we may
@@ -1909,6 +1947,7 @@ class FlashCausalLM(Model):
             batch.input_lengths[i] = input_length + n_accepted_ids
             if batch.input_lengths[i] > batch.max_seqlen:
                 batch.max_seqlen = batch.input_lengths[i]
+            batch.suffix_ids[i] = suffix_ids
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 22871ec5..f4bf181c 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -90,6 +90,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
+        self.model.max_prefill_tokens = request.max_prefill_tokens
         if self.quantize in {"exl2", "gptq"}:
             try:
                 # When using GPTQ, Exllama kernels need some global kernels
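Below is a standalone Python sketch, not part of the diff, of the behaviour the patch introduces: in FlashCausalLMBatch.from_tokenized each request's tokenized input is split into cached prefix ids, the ids prefilled now, and suffix ids, and in generate_token a pending suffix is drained in chunks of at most max_prefill_tokens before normal decoding resumes. The helpers split_input and drain_suffix are illustrative names only and do not exist in the codebase.

# Sketch of the prefix/suffix split and chunked suffix prefill added by this diff.
# Hypothetical helpers; assumptions are noted in the lead-in above.
from typing import List, Tuple


def split_input(
    tokenized_input: List[int], prefix_len: int, suffix_len: int
) -> Tuple[List[int], List[int], List[int]]:
    """Split a tokenized request into (prefix_ids, input_ids, suffix_ids)."""
    orig_input_length = len(tokenized_input)
    assert prefix_len <= orig_input_length
    assert suffix_len <= orig_input_length
    if prefix_len == orig_input_length:
        # Always keep at least one token to prefill, as the diff does.
        assert prefix_len > 0
        prefix_len -= 1
    assert prefix_len + suffix_len < orig_input_length

    prefix_ids = tokenized_input[:prefix_len]
    rest = tokenized_input[prefix_len:]
    if suffix_len > 0:
        suffix_ids = rest[-suffix_len:]
        rest = rest[:-suffix_len]
    else:
        suffix_ids = []
    return prefix_ids, rest, suffix_ids


def drain_suffix(suffix_ids: List[int], max_prefill_tokens: int) -> List[List[int]]:
    """Consume a pending suffix in prefill chunks of at most max_prefill_tokens."""
    chunks = []
    while suffix_ids:
        chunk = suffix_ids[:max_prefill_tokens]
        suffix_ids = suffix_ids[len(chunk):]
        chunks.append(chunk)
    return chunks


if __name__ == "__main__":
    tokens = list(range(10))
    prefix, window, suffix = split_input(tokens, prefix_len=2, suffix_len=5)
    print(prefix, window, suffix)                       # [0, 1] [2, 3, 4] [5, 6, 7, 8, 9]
    print(drain_suffix(suffix, max_prefill_tokens=2))   # [[5, 6], [7, 8], [9]]

In this sketch only the middle window is prefilled on the first pass (which is why the diff builds position ids over prefix_len..orig_input_length - suffix_len), and each later step takes one suffix chunk, mirroring suffix_ids[: self.max_prefill_tokens] in the suffix branch of generate_token.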