Tmp dump.

Nicolas Patry 2024-09-03 12:30:12 +02:00
parent 2f0fde1055
commit c821a0ff76
5 changed files with 90 additions and 39 deletions

View File

@@ -380,6 +380,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
             tracing::error!("Entry response channel error.");
             metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
         }).unwrap_or(true);
+        tracing::debug!("Stopped iteration {stopped:?}");
         if stopped {
             entries.remove(&id).expect("ID not found in entries. This is a bug.");
         }
@@ -419,6 +420,7 @@ fn send_responses(
     // Create last Token
     let tokens_ = generation.tokens.expect("Non empty tokens in generation");
     let n = tokens_.ids.len();
+    tracing::info!("Received {n:?} tokens");
     metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
     let mut iterator = tokens_
         .ids
@@ -435,6 +437,7 @@ fn send_responses(
             logprob,
             special,
         };
+        tracing::info!("Sent token {token:?}");
         let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
             top_tokens_
                 .ids

View File

@@ -120,7 +120,7 @@ impl Client {
         // Create requests
         while rest_tokens > 0 {
             let curr_tokens = min(max_tokens_per_request, rest_tokens);
-            let truncate = min(max_input_length, rest_tokens);
+            let truncate = max_input_length;
             let prefix_len = max_input_length.saturating_sub(max_prefill_tokens);
             let suffix_len = 0;
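
To make the changed arithmetic concrete, here is a small Python sketch of how one warmup chunk's parameters come out under the new line, using made-up numbers; saturating_sub is modelled as max(0, ...). This is only an illustration of the two expressions above, not the actual client code.

    # Illustration of the per-chunk warmup arithmetic above (made-up numbers, not the real client).
    max_tokens_per_request = 4096
    max_input_length = 1024
    max_prefill_tokens = 512
    rest_tokens = 700

    curr_tokens = min(max_tokens_per_request, rest_tokens)      # 700
    truncate = max_input_length                                  # 1024; the old code gave min(1024, 700) = 700
    prefix_len = max(0, max_input_length - max_prefill_tokens)   # 512, mimics saturating_sub
    suffix_len = 0
    print(curr_tokens, truncate, prefix_len, suffix_len)         # 700 1024 512 0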

View File

@@ -281,7 +281,10 @@ impl State {
             if total_tokens > token_budget {
                 // Entry is over budget
                 // Add it back to the front
-                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                tracing::debug!(
+                    "Over budget: {prefill_tokens} + {decode_tokens} + {} > {token_budget}",
+                    self.speculate
+                );
                 self.entries.push_front((id, entry));
                 break 'entry_loop;
             }
@@ -301,7 +304,10 @@ impl State {
             if (prefill_tokens + decode_tokens + self.speculate) > token_budget {
                 // Entry is over budget
                 // Add it back to the front
-                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                tracing::debug!(
+                    "Over budget: {prefill_tokens} + {decode_tokens} + {} > {token_budget}",
+                    self.speculate
+                );
                 self.entries.push_front((id, entry));
                 break;
             }
@@ -390,7 +396,9 @@ impl State {
                         block_allocation.prefix_len,
                     ),
                 };
-                let suffix_len = (slots.len() as u32).saturating_sub(prefix_len);
+                let suffix_len = (slots.len() as u32)
+                    .saturating_sub(prefix_len)
+                    .saturating_sub(prefill_token_budget);
                 entry.block_allocation = block_allocation;
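
As a sanity check on the new suffix_len expression, a short Python sketch with made-up numbers; saturating_sub is modelled as max(0, ...). The idea it illustrates: the suffix is whatever part of the allocated slots is covered neither by the cached prefix nor by this step's prefill budget. Illustration only, not the router code.

    # Illustration of the suffix_len computation above (made-up numbers).
    slots_len = 3000              # len(slots) from the block allocation
    prefix_len = 1024             # tokens already covered by the prefix cache
    prefill_token_budget = 1536   # tokens that fit in this step's prefill

    suffix_len = max(0, slots_len - prefix_len)             # 1976 tokens left after the prefix
    suffix_len = max(0, suffix_len - prefill_token_budget)  # 440 tokens deferred as a suffix
    print(suffix_len)                                       # 440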

View File

@@ -156,6 +156,7 @@ class FlashCausalLMBatch(Batch):
     # Prefixes
     prefix_ids: List[List[int]]
+    suffix_ids: List[List[int]]
 
     # All tokens
     all_input_ids: List[List[int]]
@@ -230,6 +231,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
         all_input_ids = []
         prefix_ids = []
+        suffix_ids = []
         requests_idx_mapping = {}
 
         all_prefill_logprobs = True
@@ -269,20 +271,29 @@ class FlashCausalLMBatch(Batch):
             orig_input_length = len(tokenized_input)
 
             prefix_len = r.prefix_len
-            import ipdb
-            ipdb.set_trace()
+            suffix_len = r.suffix_len
             assert (
                 prefix_len <= orig_input_length
             ), f"Prefix {prefix_len} vs input {orig_input_length}"
+            assert (
+                suffix_len <= orig_input_length
+            ), f"suffix {suffix_len} vs input {orig_input_length}"
             if prefix_len == orig_input_length:
                 assert prefix_len > 0
                 prefix_len -= 1
+            assert (
+                prefix_len + suffix_len < orig_input_length
+            ), f"Prefix+suffix are invalid {prefix_len} + {suffix_len} < {orig_input_length}"
 
             # Commented as it's costly.
             # log_master(logger.debug, "Tokenized input ids {tokenized_input}")
             prefix_ids.append(tokenized_input[:prefix_len])
             tokenized_input = tokenized_input[prefix_len:]
+            if suffix_len > 0:
+                suffix_ids.append(tokenized_input[-suffix_len:])
+                tokenized_input = tokenized_input[:-suffix_len]
+            else:
+                suffix_ids.append([])
 
             input_length = len(tokenized_input)
             input_lengths.append(input_length)
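
A minimal Python sketch of the prefix/suffix split performed above, on a toy token list; the variable names mirror the batch-building code, but this is an illustration rather than the server implementation.

    # Toy illustration of the prefix/suffix split above.
    tokenized_input = list(range(10))   # pretend token ids 0..9
    prefix_len, suffix_len = 3, 4

    assert prefix_len <= len(tokenized_input)
    assert suffix_len <= len(tokenized_input)
    assert prefix_len + suffix_len < len(tokenized_input)

    prefix = tokenized_input[:prefix_len]        # [0, 1, 2] -> already covered by the prefix cache
    tokenized_input = tokenized_input[prefix_len:]
    if suffix_len > 0:
        suffix = tokenized_input[-suffix_len:]   # [6, 7, 8, 9] -> deferred to later steps
        tokenized_input = tokenized_input[:-suffix_len]
    else:
        suffix = []
    print(prefix, tokenized_input, suffix)       # [0, 1, 2] [3, 4, 5] [6, 7, 8, 9]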
@@ -294,7 +305,7 @@ class FlashCausalLMBatch(Batch):
 
             # Position ids
             request_position_ids = torch.arange(
-                prefix_len, orig_input_length, dtype=torch.int32
+                prefix_len, orig_input_length - suffix_len, dtype=torch.int32
             )
             position_ids.append(request_position_ids)
@@ -390,7 +401,7 @@ class FlashCausalLMBatch(Batch):
             # Update
             cumulative_length += input_length
             cumulative_slot_tokens += slot_tokens
-            max_seqlen = max(max_seqlen, input_length)
+            max_seqlen = max(max_seqlen, orig_input_length)
             max_blocks = max(max_blocks, len(request_blocks))
             max_length = max(
                 max_length, input_length + max_new_tokens + speculative_length
@@ -501,6 +512,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             prefix_ids=prefix_ids,
+            suffix_ids=suffix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -558,6 +570,7 @@ class FlashCausalLMBatch(Batch):
         block_tables = []
         all_input_ids = []
         prefix_ids = []
+        suffix_ids = []
 
         input_lengths = []
         prefix_lens = []
@@ -587,6 +600,7 @@ class FlashCausalLMBatch(Batch):
 
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])
+            suffix_ids.append(self.suffix_ids[idx])
 
             input_lengths.append(request_input_length)
             prefix_lens.append(prefix_len)
@@ -678,6 +692,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             prefix_ids=prefix_ids,
+            suffix_ids=suffix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -759,6 +774,7 @@ class FlashCausalLMBatch(Batch):
         prefix_lens = []
         all_input_ids = []
         prefix_ids = []
+        suffix_ids = []
 
         input_lengths = []
         prefix_offsets = []
@@ -828,6 +844,7 @@ class FlashCausalLMBatch(Batch):
         prefix_lens.extend(batch.prefix_lens)
         all_input_ids.extend(batch.all_input_ids)
         prefix_ids.extend(batch.prefix_ids)
+        suffix_ids.extend(batch.suffix_ids)
 
         input_lengths.extend(batch.input_lengths)
         prefix_offsets.extend(batch.prefix_offsets)
@@ -889,6 +906,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             prefix_ids=prefix_ids,
+            suffix_ids=suffix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -1453,6 +1471,8 @@ class FlashCausalLM(Model):
 
             input_ids = new_input_ids
             position_ids = new_position_ids
+            if any([len(suffix) for suffix in batch.suffix_ids]):
+                raise RuntimeError("Suffix not supported with medusa")
         else:
             input_ids = batch.input_ids
             position_ids = batch.position_ids
@@ -1747,6 +1767,7 @@ class FlashCausalLM(Model):
             batch.stopping_criterias,
             batch.all_input_ids,
             batch.prefix_ids,
+            batch.suffix_ids,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
@@ -1765,6 +1786,7 @@ class FlashCausalLM(Model):
             stopping_criteria,
             all_input_ids,
             prefix_ids,
+            suffix_ids,
             do_sample,
             seed,
             top_n_tokens,
@@ -1776,39 +1798,54 @@ class FlashCausalLM(Model):
             next_token_texts = []
             left = 0
 
-            if n_accepted_ids > 1:
-                log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
-
-            current_stopped = False
-            for j in range(index, index + n_accepted_ids):
-                # Generated token
-                next_token_id = next_token_ids[j]
-                all_input_ids.append(next_token_id)
-                next_token_text, prefix_offset, read_offset = self.decode_token(
-                    all_input_ids,
-                    prefix_offset,
-                    read_offset,
-                )
-                next_token_texts.append(next_token_text)
-                stop, reason = stopping_criteria(
-                    next_token_id,
-                    next_token_text,
-                )
-                if stop:
-                    left = index + n_accepted_ids - j - 1
-                    current_stopped = True
-                    break
-                else:
-                    current_stopped = False
-            stopped = stopped and current_stopped
-
-            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
-            _next_token_logprobs = next_token_logprobs[
-                index : index + n_accepted_ids - left
-            ]
-            index += n_accepted_ids
+            if not suffix_ids:
+                if n_accepted_ids > 1:
+                    log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
+
+                current_stopped = False
+                for j in range(index, index + n_accepted_ids):
+                    # Generated token
+                    next_token_id = next_token_ids[j]
+                    all_input_ids.append(next_token_id)
+                    next_token_text, prefix_offset, read_offset = self.decode_token(
+                        all_input_ids,
+                        prefix_offset,
+                        read_offset,
+                    )
+                    next_token_texts.append(next_token_text)
+                    stop, reason = stopping_criteria(
+                        next_token_id,
+                        next_token_text,
+                    )
+                    if stop:
+                        left = index + n_accepted_ids - j - 1
+                        current_stopped = True
+                        break
+                    else:
+                        current_stopped = False
+
+                stopped = stopped and current_stopped
+
+                _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
+                _next_token_logprobs = next_token_logprobs[
+                    index : index + n_accepted_ids - left
+                ]
+                index += n_accepted_ids
+                req_prefill = prefill
+            else:
+                _input_ids = suffix_ids[: self.max_prefill_tokens]
+                n_accepted_ids = len(_input_ids)
+                # TODO fix logprobs.
+                _next_token_logprobs = [float("nan")] * n_accepted_ids
+                index += n_accepted_ids
+                suffix_ids = suffix_ids[n_accepted_ids:]
+                _next_token_ids = []
+                _next_token_logprobs = []
+                stop = False
+                reason = None
+                req_prefill = True
 
             # Shard generations
             # All generations will be appended in the rust sharded client
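
The new else branch above appears to drain a request's deferred suffix in chunks of at most self.max_prefill_tokens per scheduling step, keeping the request in its prefill phase until the suffix is empty. Below is a self-contained Python sketch of that draining behaviour; the function name and chunk cap are illustrative assumptions, not part of the model server.

    # Standalone illustration of draining a deferred suffix in budget-sized chunks.
    def drain_suffix(suffix_ids: list, max_prefill_tokens: int):
        # Yield (chunk, still_prefilling) once per step until the suffix is consumed.
        while suffix_ids:
            chunk = suffix_ids[:max_prefill_tokens]
            suffix_ids = suffix_ids[len(chunk):]
            yield chunk, True          # request is still prefilling
        yield [], False                # suffix exhausted: normal decoding resumes

    for chunk, prefilling in drain_suffix(list(range(10)), max_prefill_tokens=4):
        print(len(chunk), prefilling)  # 4 True / 4 True / 2 True / 0 False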
@@ -1834,7 +1871,7 @@ class FlashCausalLM(Model):
             generated_text = None
 
             # Prefill
-            if prefill and request.prefill_logprobs:
+            if req_prefill and request.prefill_logprobs:
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]
@@ -1896,6 +1933,7 @@ class FlashCausalLM(Model):
                 top_tokens,
             )
 
+            logger.info(f"Generation {generation}")
             generations.append(generation)
 
             # accept each new token for this specific request since we may
@@ -1909,6 +1947,7 @@ class FlashCausalLM(Model):
             batch.input_lengths[i] = input_length + n_accepted_ids
             if batch.input_lengths[i] > batch.max_seqlen:
                 batch.max_seqlen = batch.input_lengths[i]
+            batch.suffix_ids[i] = suffix_ids
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids

View File

@@ -90,6 +90,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
 
     async def Warmup(self, request, context):
+        self.model.max_prefill_tokens = request.max_prefill_tokens
         if self.quantize in {"exl2", "gptq"}:
             try:
                 # When using GPTQ, Exllama kernels need some global kernels