Tmp dump.
This commit is contained in:
parent
2f0fde1055
commit
c821a0ff76
|
@ -380,6 +380,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
|
|||
tracing::error!("Entry response channel error.");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
}).unwrap_or(true);
|
||||
tracing::debug!("Stopped iteration {stopped:?}");
|
||||
if stopped {
|
||||
entries.remove(&id).expect("ID not found in entries. This is a bug.");
|
||||
}
|
||||
|
@ -419,6 +420,7 @@ fn send_responses(
|
|||
// Create last Token
|
||||
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
||||
let n = tokens_.ids.len();
|
||||
tracing::info!("Received {n:?} tokens");
|
||||
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
|
||||
let mut iterator = tokens_
|
||||
.ids
|
||||
|
@ -435,6 +437,7 @@ fn send_responses(
|
|||
logprob,
|
||||
special,
|
||||
};
|
||||
tracing::info!("Sent token {token:?}");
|
||||
let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
|
||||
top_tokens_
|
||||
.ids
|
||||
|
|
|
@ -120,7 +120,7 @@ impl Client {
|
|||
// Create requests
|
||||
while rest_tokens > 0 {
|
||||
let curr_tokens = min(max_tokens_per_request, rest_tokens);
|
||||
let truncate = min(max_input_length, rest_tokens);
|
||||
let truncate = max_input_length;
|
||||
let prefix_len = max_input_length.saturating_sub(max_prefill_tokens);
|
||||
let suffix_len = 0;
|
||||
|
||||
|
|
|
@ -281,7 +281,10 @@ impl State {
|
|||
if total_tokens > token_budget {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||
tracing::debug!(
|
||||
"Over budget: {prefill_tokens} + {decode_tokens} + {} > {token_budget}",
|
||||
self.speculate
|
||||
);
|
||||
self.entries.push_front((id, entry));
|
||||
break 'entry_loop;
|
||||
}
|
||||
|
@ -301,7 +304,10 @@ impl State {
|
|||
if (prefill_tokens + decode_tokens + self.speculate) > token_budget {
|
||||
// Entry is over budget
|
||||
// Add it back to the front
|
||||
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
|
||||
tracing::debug!(
|
||||
"Over budget: {prefill_tokens} + {decode_tokens} + {} > {token_budget}",
|
||||
self.speculate
|
||||
);
|
||||
self.entries.push_front((id, entry));
|
||||
break;
|
||||
}
|
||||
|
@ -390,7 +396,9 @@ impl State {
|
|||
block_allocation.prefix_len,
|
||||
),
|
||||
};
|
||||
let suffix_len = (slots.len() as u32).saturating_sub(prefix_len);
|
||||
let suffix_len = (slots.len() as u32)
|
||||
.saturating_sub(prefix_len)
|
||||
.saturating_sub(prefill_token_budget);
|
||||
|
||||
entry.block_allocation = block_allocation;
|
||||
|
||||
|
|
|
@ -156,6 +156,7 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
# Prefixes
|
||||
prefix_ids: List[List[int]]
|
||||
suffix_ids: List[List[int]]
|
||||
|
||||
# All tokens
|
||||
all_input_ids: List[List[int]]
|
||||
|
@ -230,6 +231,7 @@ class FlashCausalLMBatch(Batch):
|
|||
read_offsets = []
|
||||
all_input_ids = []
|
||||
prefix_ids = []
|
||||
suffix_ids = []
|
||||
requests_idx_mapping = {}
|
||||
|
||||
all_prefill_logprobs = True
|
||||
|
@ -269,20 +271,29 @@ class FlashCausalLMBatch(Batch):
|
|||
orig_input_length = len(tokenized_input)
|
||||
|
||||
prefix_len = r.prefix_len
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
suffix_len = r.suffix_len
|
||||
assert (
|
||||
prefix_len <= orig_input_length
|
||||
), f"Prefix {prefix_len} vs input {orig_input_length}"
|
||||
assert (
|
||||
suffix_len <= orig_input_length
|
||||
), f"suffix {suffix_len} vs input {orig_input_length}"
|
||||
if prefix_len == orig_input_length:
|
||||
assert prefix_len > 0
|
||||
prefix_len -= 1
|
||||
assert (
|
||||
prefix_len + suffix_len < orig_input_length
|
||||
), f"Prefix+suffix are invalid {prefix_len} + {suffix_len} < {orig_input_length}"
|
||||
|
||||
# Commented as it's costly.
|
||||
# log_master(logger.debug, "Tokenized input ids {tokenized_input}")
|
||||
prefix_ids.append(tokenized_input[:prefix_len])
|
||||
tokenized_input = tokenized_input[prefix_len:]
|
||||
if suffix_len > 0:
|
||||
suffix_ids.append(tokenized_input[-suffix_len:])
|
||||
tokenized_input = tokenized_input[:-suffix_len]
|
||||
else:
|
||||
suffix_ids.append([])
|
||||
|
||||
input_length = len(tokenized_input)
|
||||
input_lengths.append(input_length)
|
||||
|
@ -294,7 +305,7 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
# Position ids
|
||||
request_position_ids = torch.arange(
|
||||
prefix_len, orig_input_length, dtype=torch.int32
|
||||
prefix_len, orig_input_length - suffix_len, dtype=torch.int32
|
||||
)
|
||||
position_ids.append(request_position_ids)
|
||||
|
||||
|
@ -390,7 +401,7 @@ class FlashCausalLMBatch(Batch):
|
|||
# Update
|
||||
cumulative_length += input_length
|
||||
cumulative_slot_tokens += slot_tokens
|
||||
max_seqlen = max(max_seqlen, input_length)
|
||||
max_seqlen = max(max_seqlen, orig_input_length)
|
||||
max_blocks = max(max_blocks, len(request_blocks))
|
||||
max_length = max(
|
||||
max_length, input_length + max_new_tokens + speculative_length
|
||||
|
@ -501,6 +512,7 @@ class FlashCausalLMBatch(Batch):
|
|||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
prefix_ids=prefix_ids,
|
||||
suffix_ids=suffix_ids,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
|
@ -558,6 +570,7 @@ class FlashCausalLMBatch(Batch):
|
|||
block_tables = []
|
||||
all_input_ids = []
|
||||
prefix_ids = []
|
||||
suffix_ids = []
|
||||
|
||||
input_lengths = []
|
||||
prefix_lens = []
|
||||
|
@ -587,6 +600,7 @@ class FlashCausalLMBatch(Batch):
|
|||
|
||||
all_input_ids.append(self.all_input_ids[idx])
|
||||
prefix_ids.append(self.prefix_ids[idx])
|
||||
suffix_ids.append(self.suffix_ids[idx])
|
||||
|
||||
input_lengths.append(request_input_length)
|
||||
prefix_lens.append(prefix_len)
|
||||
|
@ -678,6 +692,7 @@ class FlashCausalLMBatch(Batch):
|
|||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
prefix_ids=prefix_ids,
|
||||
suffix_ids=suffix_ids,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
|
@ -759,6 +774,7 @@ class FlashCausalLMBatch(Batch):
|
|||
prefix_lens = []
|
||||
all_input_ids = []
|
||||
prefix_ids = []
|
||||
suffix_ids = []
|
||||
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
|
@ -828,6 +844,7 @@ class FlashCausalLMBatch(Batch):
|
|||
prefix_lens.extend(batch.prefix_lens)
|
||||
all_input_ids.extend(batch.all_input_ids)
|
||||
prefix_ids.extend(batch.prefix_ids)
|
||||
suffix_ids.extend(batch.suffix_ids)
|
||||
|
||||
input_lengths.extend(batch.input_lengths)
|
||||
prefix_offsets.extend(batch.prefix_offsets)
|
||||
|
@ -889,6 +906,7 @@ class FlashCausalLMBatch(Batch):
|
|||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
prefix_ids=prefix_ids,
|
||||
suffix_ids=suffix_ids,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
|
@ -1453,6 +1471,8 @@ class FlashCausalLM(Model):
|
|||
|
||||
input_ids = new_input_ids
|
||||
position_ids = new_position_ids
|
||||
if any([len(suffix) for suffix in batch.suffix_ids]):
|
||||
raise RuntimeError("Suffix not supported with medusa")
|
||||
else:
|
||||
input_ids = batch.input_ids
|
||||
position_ids = batch.position_ids
|
||||
|
@ -1747,6 +1767,7 @@ class FlashCausalLM(Model):
|
|||
batch.stopping_criterias,
|
||||
batch.all_input_ids,
|
||||
batch.prefix_ids,
|
||||
batch.suffix_ids,
|
||||
batch.next_token_chooser.do_sample,
|
||||
batch.next_token_chooser.seeds,
|
||||
batch.top_n_tokens,
|
||||
|
@ -1765,6 +1786,7 @@ class FlashCausalLM(Model):
|
|||
stopping_criteria,
|
||||
all_input_ids,
|
||||
prefix_ids,
|
||||
suffix_ids,
|
||||
do_sample,
|
||||
seed,
|
||||
top_n_tokens,
|
||||
|
@ -1776,39 +1798,54 @@ class FlashCausalLM(Model):
|
|||
next_token_texts = []
|
||||
left = 0
|
||||
|
||||
if n_accepted_ids > 1:
|
||||
log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
|
||||
if not suffix_ids:
|
||||
if n_accepted_ids > 1:
|
||||
log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
|
||||
|
||||
current_stopped = False
|
||||
for j in range(index, index + n_accepted_ids):
|
||||
# Generated token
|
||||
next_token_id = next_token_ids[j]
|
||||
all_input_ids.append(next_token_id)
|
||||
next_token_text, prefix_offset, read_offset = self.decode_token(
|
||||
all_input_ids,
|
||||
prefix_offset,
|
||||
read_offset,
|
||||
)
|
||||
next_token_texts.append(next_token_text)
|
||||
current_stopped = False
|
||||
for j in range(index, index + n_accepted_ids):
|
||||
# Generated token
|
||||
next_token_id = next_token_ids[j]
|
||||
all_input_ids.append(next_token_id)
|
||||
next_token_text, prefix_offset, read_offset = self.decode_token(
|
||||
all_input_ids,
|
||||
prefix_offset,
|
||||
read_offset,
|
||||
)
|
||||
next_token_texts.append(next_token_text)
|
||||
|
||||
stop, reason = stopping_criteria(
|
||||
next_token_id,
|
||||
next_token_text,
|
||||
)
|
||||
stop, reason = stopping_criteria(
|
||||
next_token_id,
|
||||
next_token_text,
|
||||
)
|
||||
|
||||
if stop:
|
||||
left = index + n_accepted_ids - j - 1
|
||||
current_stopped = True
|
||||
break
|
||||
else:
|
||||
current_stopped = False
|
||||
stopped = stopped and current_stopped
|
||||
if stop:
|
||||
left = index + n_accepted_ids - j - 1
|
||||
current_stopped = True
|
||||
break
|
||||
else:
|
||||
current_stopped = False
|
||||
|
||||
_next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
|
||||
_next_token_logprobs = next_token_logprobs[
|
||||
index : index + n_accepted_ids - left
|
||||
]
|
||||
index += n_accepted_ids
|
||||
stopped = stopped and current_stopped
|
||||
|
||||
_next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
|
||||
_next_token_logprobs = next_token_logprobs[
|
||||
index : index + n_accepted_ids - left
|
||||
]
|
||||
index += n_accepted_ids
|
||||
req_prefill = prefill
|
||||
else:
|
||||
_input_ids = suffix_ids[: self.max_prefill_tokens]
|
||||
n_accepted_ids = len(_input_ids)
|
||||
# TODO fix logprobs.
|
||||
_next_token_logprobs = [float("nan")] * n_accepted_ids
|
||||
index += n_accepted_ids
|
||||
suffix_ids = suffix_ids[n_accepted_ids:]
|
||||
_next_token_ids = []
|
||||
_next_token_logprobs = []
|
||||
stop = False
|
||||
reason = None
|
||||
req_prefill = True
|
||||
|
||||
# Shard generations
|
||||
# All generations will be appended in the rust sharded client
|
||||
|
@ -1834,7 +1871,7 @@ class FlashCausalLM(Model):
|
|||
generated_text = None
|
||||
|
||||
# Prefill
|
||||
if prefill and request.prefill_logprobs:
|
||||
if req_prefill and request.prefill_logprobs:
|
||||
out_start_index = batch.prefill_cu_outlens[i]
|
||||
out_end_index = batch.prefill_cu_outlens[i + 1]
|
||||
|
||||
|
@ -1896,6 +1933,7 @@ class FlashCausalLM(Model):
|
|||
top_tokens,
|
||||
)
|
||||
|
||||
logger.info(f"Generation {generation}")
|
||||
generations.append(generation)
|
||||
|
||||
# accept each new token for this specific request since we may
|
||||
|
@ -1909,6 +1947,7 @@ class FlashCausalLM(Model):
|
|||
batch.input_lengths[i] = input_length + n_accepted_ids
|
||||
if batch.input_lengths[i] > batch.max_seqlen:
|
||||
batch.max_seqlen = batch.input_lengths[i]
|
||||
batch.suffix_ids[i] = suffix_ids
|
||||
batch.prefix_offsets[i] = prefix_offset
|
||||
batch.read_offsets[i] = read_offset
|
||||
batch.all_input_ids[i] = all_input_ids
|
||||
|
|
|
@ -90,6 +90,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
|
||||
|
||||
async def Warmup(self, request, context):
|
||||
self.model.max_prefill_tokens = request.max_prefill_tokens
|
||||
if self.quantize in {"exl2", "gptq"}:
|
||||
try:
|
||||
# When using GPTQ, Exllama kernels need some global kernels
|
||||
|
|
Loading…
Reference in New Issue