Tmp dump.

This commit is contained in:
Nicolas Patry 2024-09-03 12:30:12 +02:00
parent 2f0fde1055
commit c821a0ff76
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
5 changed files with 90 additions and 39 deletions

View File

@ -380,6 +380,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
tracing::error!("Entry response channel error.");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
}).unwrap_or(true);
tracing::debug!("Stopped iteration {stopped:?}");
if stopped {
entries.remove(&id).expect("ID not found in entries. This is a bug.");
}
@ -419,6 +420,7 @@ fn send_responses(
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
tracing::info!("Received {n:?} tokens");
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_
.ids
@ -435,6 +437,7 @@ fn send_responses(
logprob,
special,
};
tracing::info!("Sent token {token:?}");
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
top_tokens_
.ids

View File

@ -120,7 +120,7 @@ impl Client {
// Create requests
while rest_tokens > 0 {
let curr_tokens = min(max_tokens_per_request, rest_tokens);
let truncate = min(max_input_length, rest_tokens);
let truncate = max_input_length;
let prefix_len = max_input_length.saturating_sub(max_prefill_tokens);
let suffix_len = 0;

View File

@ -281,7 +281,10 @@ impl State {
if total_tokens > token_budget {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
tracing::debug!(
"Over budget: {prefill_tokens} + {decode_tokens} + {} > {token_budget}",
self.speculate
);
self.entries.push_front((id, entry));
break 'entry_loop;
}
@ -301,7 +304,10 @@ impl State {
if (prefill_tokens + decode_tokens + self.speculate) > token_budget {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
tracing::debug!(
"Over budget: {prefill_tokens} + {decode_tokens} + {} > {token_budget}",
self.speculate
);
self.entries.push_front((id, entry));
break;
}
@ -390,7 +396,9 @@ impl State {
block_allocation.prefix_len,
),
};
let suffix_len = (slots.len() as u32).saturating_sub(prefix_len);
let suffix_len = (slots.len() as u32)
.saturating_sub(prefix_len)
.saturating_sub(prefill_token_budget);
entry.block_allocation = block_allocation;

View File

@ -156,6 +156,7 @@ class FlashCausalLMBatch(Batch):
# Prefixes
prefix_ids: List[List[int]]
suffix_ids: List[List[int]]
# All tokens
all_input_ids: List[List[int]]
@ -230,6 +231,7 @@ class FlashCausalLMBatch(Batch):
read_offsets = []
all_input_ids = []
prefix_ids = []
suffix_ids = []
requests_idx_mapping = {}
all_prefill_logprobs = True
@ -269,20 +271,29 @@ class FlashCausalLMBatch(Batch):
orig_input_length = len(tokenized_input)
prefix_len = r.prefix_len
import ipdb
ipdb.set_trace()
suffix_len = r.suffix_len
assert (
prefix_len <= orig_input_length
), f"Prefix {prefix_len} vs input {orig_input_length}"
assert (
suffix_len <= orig_input_length
), f"suffix {suffix_len} vs input {orig_input_length}"
if prefix_len == orig_input_length:
assert prefix_len > 0
prefix_len -= 1
assert (
prefix_len + suffix_len < orig_input_length
), f"Prefix+suffix are invalid {prefix_len} + {suffix_len} < {orig_input_length}"
# Commented as it's costly.
# log_master(logger.debug, "Tokenized input ids {tokenized_input}")
prefix_ids.append(tokenized_input[:prefix_len])
tokenized_input = tokenized_input[prefix_len:]
if suffix_len > 0:
suffix_ids.append(tokenized_input[-suffix_len:])
tokenized_input = tokenized_input[:-suffix_len]
else:
suffix_ids.append([])
input_length = len(tokenized_input)
input_lengths.append(input_length)
@ -294,7 +305,7 @@ class FlashCausalLMBatch(Batch):
# Position ids
request_position_ids = torch.arange(
prefix_len, orig_input_length, dtype=torch.int32
prefix_len, orig_input_length - suffix_len, dtype=torch.int32
)
position_ids.append(request_position_ids)
@ -390,7 +401,7 @@ class FlashCausalLMBatch(Batch):
# Update
cumulative_length += input_length
cumulative_slot_tokens += slot_tokens
max_seqlen = max(max_seqlen, input_length)
max_seqlen = max(max_seqlen, orig_input_length)
max_blocks = max(max_blocks, len(request_blocks))
max_length = max(
max_length, input_length + max_new_tokens + speculative_length
@ -501,6 +512,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
suffix_ids=suffix_ids,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
@ -558,6 +570,7 @@ class FlashCausalLMBatch(Batch):
block_tables = []
all_input_ids = []
prefix_ids = []
suffix_ids = []
input_lengths = []
prefix_lens = []
@ -587,6 +600,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids.append(self.all_input_ids[idx])
prefix_ids.append(self.prefix_ids[idx])
suffix_ids.append(self.suffix_ids[idx])
input_lengths.append(request_input_length)
prefix_lens.append(prefix_len)
@ -678,6 +692,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
suffix_ids=suffix_ids,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
@ -759,6 +774,7 @@ class FlashCausalLMBatch(Batch):
prefix_lens = []
all_input_ids = []
prefix_ids = []
suffix_ids = []
input_lengths = []
prefix_offsets = []
@ -828,6 +844,7 @@ class FlashCausalLMBatch(Batch):
prefix_lens.extend(batch.prefix_lens)
all_input_ids.extend(batch.all_input_ids)
prefix_ids.extend(batch.prefix_ids)
suffix_ids.extend(batch.suffix_ids)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
@ -889,6 +906,7 @@ class FlashCausalLMBatch(Batch):
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
suffix_ids=suffix_ids,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
@ -1453,6 +1471,8 @@ class FlashCausalLM(Model):
input_ids = new_input_ids
position_ids = new_position_ids
if any([len(suffix) for suffix in batch.suffix_ids]):
raise RuntimeError("Suffix not supported with medusa")
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
@ -1747,6 +1767,7 @@ class FlashCausalLM(Model):
batch.stopping_criterias,
batch.all_input_ids,
batch.prefix_ids,
batch.suffix_ids,
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
batch.top_n_tokens,
@ -1765,6 +1786,7 @@ class FlashCausalLM(Model):
stopping_criteria,
all_input_ids,
prefix_ids,
suffix_ids,
do_sample,
seed,
top_n_tokens,
@ -1776,39 +1798,54 @@ class FlashCausalLM(Model):
next_token_texts = []
left = 0
if n_accepted_ids > 1:
log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
if not suffix_ids:
if n_accepted_ids > 1:
log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
current_stopped = False
for j in range(index, index + n_accepted_ids):
# Generated token
next_token_id = next_token_ids[j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
next_token_texts.append(next_token_text)
current_stopped = False
for j in range(index, index + n_accepted_ids):
# Generated token
next_token_id = next_token_ids[j]
all_input_ids.append(next_token_id)
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
prefix_offset,
read_offset,
)
next_token_texts.append(next_token_text)
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
if stop:
left = index + n_accepted_ids - j - 1
current_stopped = True
break
else:
current_stopped = False
stopped = stopped and current_stopped
if stop:
left = index + n_accepted_ids - j - 1
current_stopped = True
break
else:
current_stopped = False
_next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
_next_token_logprobs = next_token_logprobs[
index : index + n_accepted_ids - left
]
index += n_accepted_ids
stopped = stopped and current_stopped
_next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
_next_token_logprobs = next_token_logprobs[
index : index + n_accepted_ids - left
]
index += n_accepted_ids
req_prefill = prefill
else:
_input_ids = suffix_ids[: self.max_prefill_tokens]
n_accepted_ids = len(_input_ids)
# TODO fix logprobs.
_next_token_logprobs = [float("nan")] * n_accepted_ids
index += n_accepted_ids
suffix_ids = suffix_ids[n_accepted_ids:]
_next_token_ids = []
_next_token_logprobs = []
stop = False
reason = None
req_prefill = True
# Shard generations
# All generations will be appended in the rust sharded client
@ -1834,7 +1871,7 @@ class FlashCausalLM(Model):
generated_text = None
# Prefill
if prefill and request.prefill_logprobs:
if req_prefill and request.prefill_logprobs:
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
@ -1896,6 +1933,7 @@ class FlashCausalLM(Model):
top_tokens,
)
logger.info(f"Generation {generation}")
generations.append(generation)
# accept each new token for this specific request since we may
@ -1909,6 +1947,7 @@ class FlashCausalLM(Model):
batch.input_lengths[i] = input_length + n_accepted_ids
if batch.input_lengths[i] > batch.max_seqlen:
batch.max_seqlen = batch.input_lengths[i]
batch.suffix_ids[i] = suffix_ids
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids

View File

@ -90,6 +90,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
self.model.max_prefill_tokens = request.max_prefill_tokens
if self.quantize in {"exl2", "gptq"}:
try:
# When using GPTQ, Exllama kernels need some global kernels