Tmp dump.
parent 2f0fde1055
commit c821a0ff76
@@ -380,6 +380,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
             tracing::error!("Entry response channel error.");
             metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
         }).unwrap_or(true);
+        tracing::debug!("Stopped iteration {stopped:?}");
         if stopped {
             entries.remove(&id).expect("ID not found in entries. This is a bug.");
         }
@@ -419,6 +420,7 @@ fn send_responses(
     // Create last Token
     let tokens_ = generation.tokens.expect("Non empty tokens in generation");
     let n = tokens_.ids.len();
+    tracing::info!("Received {n:?} tokens");
     metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
     let mut iterator = tokens_
         .ids
@@ -435,6 +437,7 @@ fn send_responses(
             logprob,
             special,
         };
+        tracing::info!("Sent token {token:?}");
         let top_tokens = if let Some(top_tokens_) = generation.top_tokens.get(i) {
             top_tokens_
                 .ids
@@ -120,7 +120,7 @@ impl Client {
         // Create requests
         while rest_tokens > 0 {
             let curr_tokens = min(max_tokens_per_request, rest_tokens);
-            let truncate = min(max_input_length, rest_tokens);
+            let truncate = max_input_length;
             let prefix_len = max_input_length.saturating_sub(max_prefill_tokens);
             let suffix_len = 0;

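Note: the hunk above makes every synthetic warmup request truncate at max_input_length instead of min(max_input_length, rest_tokens). The following Python sketch is illustrative only (the real code is Rust) and assumes rest_tokens is decremented by curr_tokens on each iteration, which is outside this hunk:

    def warmup_request_sizes(total_tokens, max_tokens_per_request, max_input_length, max_prefill_tokens):
        # Mirrors the warmup loop above; yields one tuple per synthetic request.
        rest_tokens = total_tokens
        sizes = []
        while rest_tokens > 0:
            curr_tokens = min(max_tokens_per_request, rest_tokens)
            truncate = max_input_length  # new behaviour: no longer capped by rest_tokens
            prefix_len = max(max_input_length - max_prefill_tokens, 0)  # saturating_sub
            suffix_len = 0
            sizes.append((curr_tokens, truncate, prefix_len, suffix_len))
            rest_tokens -= curr_tokens  # assumption: decrement not shown in the hunk
        return sizes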
@@ -281,7 +281,10 @@ impl State {
             if total_tokens > token_budget {
                 // Entry is over budget
                 // Add it back to the front
-                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                tracing::debug!(
+                    "Over budget: {prefill_tokens} + {decode_tokens} + {} > {token_budget}",
+                    self.speculate
+                );
                 self.entries.push_front((id, entry));
                 break 'entry_loop;
             }
@@ -301,7 +304,10 @@ impl State {
             if (prefill_tokens + decode_tokens + self.speculate) > token_budget {
                 // Entry is over budget
                 // Add it back to the front
-                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                tracing::debug!(
+                    "Over budget: {prefill_tokens} + {decode_tokens} + {} > {token_budget}",
+                    self.speculate
+                );
                 self.entries.push_front((id, entry));
                 break;
             }
@@ -390,7 +396,9 @@ impl State {
                     block_allocation.prefix_len,
                 ),
             };
-            let suffix_len = (slots.len() as u32).saturating_sub(prefix_len);
+            let suffix_len = (slots.len() as u32)
+                .saturating_sub(prefix_len)
+                .saturating_sub(prefill_token_budget);

             entry.block_allocation = block_allocation;

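Note: the new suffix_len saturates at zero twice, so a request only carries a suffix when its allocated slots exceed both the reused prefix and the prefill token budget. A minimal Python re-expression of that arithmetic (illustrative only; the real code is Rust):

    def compute_suffix_len(num_slots, prefix_len, prefill_token_budget):
        # (slots.len() as u32).saturating_sub(prefix_len).saturating_sub(prefill_token_budget)
        suffix_len = max(num_slots - prefix_len, 0)
        return max(suffix_len - prefill_token_budget, 0)

    # e.g. 4096 slots, 1024-token reused prefix, 2048-token prefill budget -> 1024 tokens deferred
    assert compute_suffix_len(4096, 1024, 2048) == 1024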
@@ -156,6 +156,7 @@ class FlashCausalLMBatch(Batch):

     # Prefixes
     prefix_ids: List[List[int]]
+    suffix_ids: List[List[int]]

     # All tokens
     all_input_ids: List[List[int]]
@@ -230,6 +231,7 @@ class FlashCausalLMBatch(Batch):
         read_offsets = []
         all_input_ids = []
         prefix_ids = []
+        suffix_ids = []
         requests_idx_mapping = {}

         all_prefill_logprobs = True
@@ -269,20 +271,29 @@ class FlashCausalLMBatch(Batch):
             orig_input_length = len(tokenized_input)

             prefix_len = r.prefix_len
-            import ipdb
+            suffix_len = r.suffix_len

-            ipdb.set_trace()
             assert (
                 prefix_len <= orig_input_length
             ), f"Prefix {prefix_len} vs input {orig_input_length}"
+            assert (
+                suffix_len <= orig_input_length
+            ), f"suffix {suffix_len} vs input {orig_input_length}"
             if prefix_len == orig_input_length:
                 assert prefix_len > 0
                 prefix_len -= 1
+            assert (
+                prefix_len + suffix_len < orig_input_length
+            ), f"Prefix+suffix are invalid {prefix_len} + {suffix_len} < {orig_input_length}"

             # Commented as it's costly.
             # log_master(logger.debug, "Tokenized input ids {tokenized_input}")
             prefix_ids.append(tokenized_input[:prefix_len])
             tokenized_input = tokenized_input[prefix_len:]
+            if suffix_len > 0:
+                suffix_ids.append(tokenized_input[-suffix_len:])
+                tokenized_input = tokenized_input[:-suffix_len]
+            else:
+                suffix_ids.append([])

             input_length = len(tokenized_input)
             input_lengths.append(input_length)
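Note: the from_pb change above splits each tokenized request into three parts: a cached prefix, the tokens prefilled now, and a suffix deferred to later steps. A self-contained Python sketch of that split, using the same slicing as the hunk (the helper name and example values are illustrative, not from the commit):

    def split_request(tokenized_input, prefix_len, suffix_len):
        orig_input_length = len(tokenized_input)
        assert prefix_len <= orig_input_length
        assert suffix_len <= orig_input_length
        if prefix_len == orig_input_length:
            assert prefix_len > 0
            prefix_len -= 1
        assert prefix_len + suffix_len < orig_input_length
        prefix_ids = tokenized_input[:prefix_len]      # already cached, not recomputed
        remaining = tokenized_input[prefix_len:]
        if suffix_len > 0:
            suffix_ids = remaining[-suffix_len:]       # deferred to later prefill chunks
            remaining = remaining[:-suffix_len]        # prefilled in this step
        else:
            suffix_ids = []
        return prefix_ids, remaining, suffix_ids

    # e.g. 10 tokens with a 3-token cached prefix and a 4-token deferred suffix
    prefix, now, later = split_request(list(range(10)), prefix_len=3, suffix_len=4)
    assert (prefix, now, later) == ([0, 1, 2], [3, 4, 5], [6, 7, 8, 9])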
@@ -294,7 +305,7 @@ class FlashCausalLMBatch(Batch):

             # Position ids
             request_position_ids = torch.arange(
-                prefix_len, orig_input_length, dtype=torch.int32
+                prefix_len, orig_input_length - suffix_len, dtype=torch.int32
             )
             position_ids.append(request_position_ids)

@@ -390,7 +401,7 @@ class FlashCausalLMBatch(Batch):
             # Update
             cumulative_length += input_length
             cumulative_slot_tokens += slot_tokens
-            max_seqlen = max(max_seqlen, input_length)
+            max_seqlen = max(max_seqlen, orig_input_length)
             max_blocks = max(max_blocks, len(request_blocks))
             max_length = max(
                 max_length, input_length + max_new_tokens + speculative_length
@@ -501,6 +512,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             prefix_ids=prefix_ids,
+            suffix_ids=suffix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -558,6 +570,7 @@ class FlashCausalLMBatch(Batch):
         block_tables = []
         all_input_ids = []
         prefix_ids = []
+        suffix_ids = []

         input_lengths = []
         prefix_lens = []
@@ -587,6 +600,7 @@ class FlashCausalLMBatch(Batch):

             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])
+            suffix_ids.append(self.suffix_ids[idx])

             input_lengths.append(request_input_length)
             prefix_lens.append(prefix_len)
@@ -678,6 +692,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             prefix_ids=prefix_ids,
+            suffix_ids=suffix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -759,6 +774,7 @@ class FlashCausalLMBatch(Batch):
         prefix_lens = []
         all_input_ids = []
         prefix_ids = []
+        suffix_ids = []

         input_lengths = []
         prefix_offsets = []
@@ -828,6 +844,7 @@ class FlashCausalLMBatch(Batch):
             prefix_lens.extend(batch.prefix_lens)
             all_input_ids.extend(batch.all_input_ids)
             prefix_ids.extend(batch.prefix_ids)
+            suffix_ids.extend(batch.suffix_ids)

             input_lengths.extend(batch.input_lengths)
             prefix_offsets.extend(batch.prefix_offsets)
@@ -889,6 +906,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             prefix_ids=prefix_ids,
+            suffix_ids=suffix_ids,
             next_token_chooser=next_token_chooser,
             stopping_criterias=stopping_criterias,
             top_n_tokens=top_n_tokens,
@@ -1453,6 +1471,8 @@ class FlashCausalLM(Model):

             input_ids = new_input_ids
             position_ids = new_position_ids
+            if any([len(suffix) for suffix in batch.suffix_ids]):
+                raise RuntimeError("Suffix not supported with medusa")
         else:
             input_ids = batch.input_ids
             position_ids = batch.position_ids
@@ -1747,6 +1767,7 @@ class FlashCausalLM(Model):
             batch.stopping_criterias,
             batch.all_input_ids,
             batch.prefix_ids,
+            batch.suffix_ids,
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
@@ -1765,6 +1786,7 @@ class FlashCausalLM(Model):
             stopping_criteria,
             all_input_ids,
             prefix_ids,
+            suffix_ids,
             do_sample,
             seed,
             top_n_tokens,
@@ -1776,39 +1798,54 @@ class FlashCausalLM(Model):
             next_token_texts = []
             left = 0

-            if n_accepted_ids > 1:
-                log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
-
-            current_stopped = False
-            for j in range(index, index + n_accepted_ids):
-                # Generated token
-                next_token_id = next_token_ids[j]
-                all_input_ids.append(next_token_id)
-                next_token_text, prefix_offset, read_offset = self.decode_token(
-                    all_input_ids,
-                    prefix_offset,
-                    read_offset,
-                )
-                next_token_texts.append(next_token_text)
-
-                stop, reason = stopping_criteria(
-                    next_token_id,
-                    next_token_text,
-                )
-
-                if stop:
-                    left = index + n_accepted_ids - j - 1
-                    current_stopped = True
-                    break
-                else:
-                    current_stopped = False
-            stopped = stopped and current_stopped
-
-            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
-            _next_token_logprobs = next_token_logprobs[
-                index : index + n_accepted_ids - left
-            ]
-            index += n_accepted_ids
+            if not suffix_ids:
+                if n_accepted_ids > 1:
+                    log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}")
+
+                current_stopped = False
+                for j in range(index, index + n_accepted_ids):
+                    # Generated token
+                    next_token_id = next_token_ids[j]
+                    all_input_ids.append(next_token_id)
+                    next_token_text, prefix_offset, read_offset = self.decode_token(
+                        all_input_ids,
+                        prefix_offset,
+                        read_offset,
+                    )
+                    next_token_texts.append(next_token_text)
+
+                    stop, reason = stopping_criteria(
+                        next_token_id,
+                        next_token_text,
+                    )
+
+                    if stop:
+                        left = index + n_accepted_ids - j - 1
+                        current_stopped = True
+                        break
+                    else:
+                        current_stopped = False
+                stopped = stopped and current_stopped
+
+                _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
+                _next_token_logprobs = next_token_logprobs[
+                    index : index + n_accepted_ids - left
+                ]
+                index += n_accepted_ids
+                req_prefill = prefill
+            else:
+                _input_ids = suffix_ids[: self.max_prefill_tokens]
+                n_accepted_ids = len(_input_ids)
+                # TODO fix logprobs.
+                _next_token_logprobs = [float("nan")] * n_accepted_ids
+                index += n_accepted_ids
+                suffix_ids = suffix_ids[n_accepted_ids:]
+                _next_token_ids = []
+                _next_token_logprobs = []
+                stop = False
+                reason = None
+                req_prefill = True

             # Shard generations
             # All generations will be appended in the rust sharded client
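Note: in the generate_token change above, a request that still has suffix_ids skips sampling for this step and instead consumes up to self.max_prefill_tokens of its suffix, emitting nothing and keeping req_prefill = True. A rough Python sketch of how the suffix drains across iterations (simplified; the real branch also tracks logprobs, offsets, and the rest of the batch state):

    def drain_suffix(suffix_ids, max_prefill_tokens):
        # One scheduling step for a request whose prompt is still being chunk-prefilled.
        chunk = suffix_ids[:max_prefill_tokens]
        n_accepted_ids = len(chunk)
        remaining_suffix = suffix_ids[n_accepted_ids:]
        next_token_ids = []  # nothing is emitted while the suffix is being prefilled
        return remaining_suffix, n_accepted_ids, next_token_ids

    suffix = list(range(5))
    suffix, n, out = drain_suffix(suffix, max_prefill_tokens=2)
    assert suffix == [2, 3, 4] and n == 2 and out == []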
@@ -1834,7 +1871,7 @@ class FlashCausalLM(Model):
                 generated_text = None

             # Prefill
-            if prefill and request.prefill_logprobs:
+            if req_prefill and request.prefill_logprobs:
                 out_start_index = batch.prefill_cu_outlens[i]
                 out_end_index = batch.prefill_cu_outlens[i + 1]

@@ -1896,6 +1933,7 @@ class FlashCausalLM(Model):
                 top_tokens,
             )

+            logger.info(f"Generation {generation}")
             generations.append(generation)

             # accept each new token for this specific request since we may
@@ -1909,6 +1947,7 @@ class FlashCausalLM(Model):
             batch.input_lengths[i] = input_length + n_accepted_ids
             if batch.input_lengths[i] > batch.max_seqlen:
                 batch.max_seqlen = batch.input_lengths[i]
+            batch.suffix_ids[i] = suffix_ids
             batch.prefix_offsets[i] = prefix_offset
             batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
@@ -90,6 +90,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
+        self.model.max_prefill_tokens = request.max_prefill_tokens
         if self.quantize in {"exl2", "gptq"}:
             try:
                 # When using GPTQ, Exllama kernels need some global kernels