fix rust and python unit-tests

This commit is contained in:
OlivierDehaene 2024-06-11 17:11:16 +02:00
parent 73c3903214
commit 37266e2dbb
12 changed files with 288 additions and 112 deletions

View File

@ -16,4 +16,3 @@ jobs:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main

View File

@ -1,7 +1,8 @@
use std::sync::{Arc, Mutex};
use std::fmt::Formatter;
use std::sync::{Arc, Mutex, TryLockError};
use thiserror::Error;
#[derive(Debug, Clone)]
#[derive(Clone)]
pub(crate) struct BlockAllocation {
allocated_blocks: Vec<u32>,
allocated_slots: Vec<u32>,
@ -53,7 +54,19 @@ impl Drop for BlockAllocation {
}
}
#[derive(Debug, Clone)]
impl std::fmt::Debug for BlockAllocation {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("BlockAllocation")
.field("allocated_blocks", &self.allocated_blocks.len())
.field("allocated_slots", &self.allocated_slots.len())
.field("required_blocks", &self.required_blocks)
.field("required_slots", &self.required_slots)
.field("block_allocator", &self.block_allocator)
.finish()
}
}
#[derive(Clone)]
pub(crate) struct BlockAllocator {
free_blocks: Arc<Mutex<Vec<u32>>>,
block_size: u32,
@ -129,8 +142,7 @@ impl BlockAllocator {
Err(AllocationError::NotEnoughPages)
} else {
let n_free_blocks = free_blocks.len();
let allocated_blocks =
free_blocks.split_off(n_free_blocks - clipped_required_blocks);
let allocated_blocks = free_blocks.split_off(n_free_blocks - clipped_required_blocks);
let allocated_blocks = if repeats != 1 {
let mut allocated_blocks = allocated_blocks.repeat(repeats);
@ -140,9 +152,8 @@ impl BlockAllocator {
allocated_blocks
};
let mut allocated_slots = Vec::with_capacity(
allocated_blocks.len() * self.block_size as usize * repeats,
);
let mut allocated_slots =
Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats);
let required_slots = (prompt_tokens + decode_tokens) as usize;
@ -166,7 +177,30 @@ impl BlockAllocator {
}
pub(crate) fn free(&self, blocks: Vec<u32>) {
self.free_blocks.lock().expect("Lock could not be acquired. This is a bug.").extend(blocks)
self.free_blocks
.lock()
.expect("Lock could not be acquired. This is a bug.")
.extend(blocks)
}
}
impl std::fmt::Debug for BlockAllocator {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let mut d = f.debug_struct("BlockAllocator");
d.field("block_size", &self.block_size)
.field("window_size", &self.window_size);
match self.free_blocks.try_lock() {
Ok(guard) => {
d.field("free_blocks", &(*guard).len());
}
Err(TryLockError::Poisoned(err)) => {
d.field("free_blocks", &(**err.get_ref()).len());
}
Err(TryLockError::WouldBlock) => {
d.field("free_blocks", &format_args!("<locked>"));
}
};
d.finish()
}
}

View File

@ -275,7 +275,9 @@ impl State {
if prefill_tokens > prefill_token_budget {
// Entry is over budget
// Add it back to the front
tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
tracing::debug!(
"Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget}"
);
self.entries.push_front((id, entry));
break;
}
@ -456,7 +458,7 @@ mod tests {
let entry = Entry {
request: ValidGenerateRequest {
inputs: vec![],
input_length: 0,
input_length: 1,
truncate: 0,
decoder_input_details: false,
parameters: ValidParameters {
@ -567,7 +569,7 @@ mod tests {
#[tokio::test]
async fn test_next_batch_token_budget() {
let mut state = State::new(false, 1, None, 0, 2);
let mut state = State::new(false, 1, None, 0, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
state.append(entry1);
@ -689,7 +691,7 @@ mod tests {
#[tokio::test]
async fn test_queue_next_batch_token_speculate() {
let queue = Queue::new(false, 1, None, 2, 16);
let queue = Queue::new(true, 1, None, 2, 16);
let (entry1, _guard1) = default_entry();
let (entry2, _guard2) = default_entry();
queue.append(entry1);

View File

@ -256,11 +256,7 @@ async fn prefill(
.expect("ID not found in entries. This is a bug.");
// Send intermediate responses
if let Err(_) = send_stream_responses(stream_responses, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
}) {
if send_stream_responses(stream_responses, entry).is_err() {
// Sending failed, remove entry
entries
.remove(&id)
@ -405,7 +401,7 @@ async fn filter_batch(
.filter_batch(
id,
updated_requests,
terminated_entries.keys().map(|v| *v).collect(),
terminated_entries.keys().copied().collect(),
)
.await
.unwrap()
@ -460,11 +456,14 @@ fn send_terminated_generations(
};
// Send responses
if let Err(_) = entry.response_tx.send(Ok(response)).map_err(|err| {
let send_result = entry.response_tx.send(Ok(response)).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
}) {
});
if send_result.is_err() {
// The channel is dropped, skip the rest of the messages
continue 'terminated_generations;
}
}
@ -504,11 +503,7 @@ fn filter_send_ended_generations(
// If the generation has ended for this request, we send the responses to the channel and
// remove the entry to drop it and free its blocks
if finished {
let _ = send_stream_responses(stream_responses, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
});
let _ = send_stream_responses(stream_responses, entry);
// Remove from entries and filter
entries.remove(&id).expect("ID not found in entries. This is a bug.");
return None;
@ -525,7 +520,11 @@ fn send_stream_responses(
entry: &Entry,
) -> Result<(), Box<SendError<Result<InferStreamResponse, InferError>>>> {
for response in stream_responses {
entry.response_tx.send(Ok(response))?;
entry.response_tx.send(Ok(response)).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
})?;
}
Ok(())
}
@ -541,7 +540,7 @@ fn filter_send_update_allocations(
) -> (bool, IntMap<u64, Entry>) {
let mut updated = false;
let ids: Vec<u64> = entries.keys().map(|v| *v).collect();
let ids: Vec<u64> = entries.keys().copied().collect();
let mut terminated_entries =
IntMap::with_capacity_and_hasher(entries.len(), BuildNoHashHasher::default());
@ -581,11 +580,7 @@ fn filter_send_update_allocations(
.expect("ID not found in stream_responses. This is a bug.");
// Send intermediate responses
if let Err(_) = send_stream_responses(stream_response, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
err
}) {
if send_stream_responses(stream_response, entry).is_err() {
// Sending failed, remove entry
entries
.remove(id)

View File

@ -197,8 +197,10 @@ def test_causal_lm_generate_token_completion_multi(
# Copy stopping_criterias before filtering
stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
next_batch = next_batch.filter(
[generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
next_batch, _ = next_batch.filter(
default_bloom,
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
[],
)
for _ in range(
@ -307,15 +309,13 @@ def test_batch_concatenate(
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter(
next_batch, _ = next_batch.filter(
default_bloom,
[
generate_pb2.UpdatedRequest(
id=next_batch.requests[0].id, blocks=[], slots=[]
),
generate_pb2.UpdatedRequest(
id=next_batch.requests[1].id, blocks=[], slots=[]
),
]
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
],
[],
)
for _ in range(
@ -339,8 +339,10 @@ def test_batch_concatenate(
== default_bloom_batch.stopping_criterias[0].max_new_tokens
)
next_batch = next_batch.filter(
[generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])]
next_batch, _ = next_batch.filter(
default_bloom,
[generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])],
[],
)
for _ in range(

View File

@ -198,8 +198,10 @@ def test_causal_lm_generate_token_completion_multi(
default_multi_requests_causal_lm_batch.stopping_criterias.copy()
)
next_batch = next_batch.filter(
[generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
next_batch, _ = next_batch.filter(
default_causal_lm,
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
[],
)
for _ in range(
@ -307,15 +309,13 @@ def test_batch_concatenate(
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter(
next_batch, _ = next_batch.filter(
default_causal_lm,
[
generate_pb2.UpdatedRequest(
id=next_batch.requests[0].id, blocks=[], slots=[]
),
generate_pb2.UpdatedRequest(
id=next_batch.requests[1].id, blocks=[], slots=[]
),
]
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
],
[],
)
for _ in range(
@ -337,15 +337,12 @@ def test_batch_concatenate(
== default_causal_lm_batch.stopping_criterias[0].max_new_tokens
)
next_batch = next_batch.filter(
next_batch, _ = next_batch.filter(
default_causal_lm,
[
generate_pb2.UpdatedRequest(
id=next_batch.requests[0].id, blocks=[], slots=[]
),
generate_pb2.UpdatedRequest(
id=next_batch.requests[1].id, blocks=[], slots=[]
),
]
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
],
[],
)
for _ in range(

View File

@ -206,8 +206,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
)
assert generations[1].generated_text.generated_tokens == 5
next_batch = next_batch.filter(
[generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
next_batch, _ = next_batch.filter(
default_seq2seq_lm,
[generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
[],
)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
@ -341,15 +343,13 @@ def test_batch_concatenate(
)
assert generations[2].generated_text.generated_tokens == 5
next_batch = next_batch.filter(
next_batch, _ = next_batch.filter(
default_seq2seq_lm,
[
generate_pb2.UpdatedRequest(
id=next_batch.requests[0].id, blocks=[], slots=[]
),
generate_pb2.UpdatedRequest(
id=next_batch.requests[1].id, blocks=[], slots=[]
),
]
generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
],
[],
)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
@ -360,8 +360,10 @@ def test_batch_concatenate(
assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
assert generations[0].generated_text.generated_tokens == 7
next_batch = next_batch.filter(
[generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])]
next_batch, _ = next_batch.filter(
default_seq2seq_lm,
[generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])],
[],
)
generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)

View File

@ -159,14 +159,48 @@ class CausalLMBatch(Batch):
@tracer.start_as_current_span("filter")
def filter(
self, updated_requests: List[generate_pb2.KeptRequest]
) -> Optional["CausalLMBatch"]:
request_ids = [r.id for r in updated_requests]
self,
model: "CausalLM",
kept_requests: List[generate_pb2.KeptRequest],
terminated_request_ids: List[int],
) -> Tuple[Optional["CausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
terminated_generations = []
for request_id in terminated_request_ids:
idx = self.requests_idx_mapping[request_id]
all_input_ids = self.all_input_ids[idx]
stopping_criteria = self.stopping_criterias[idx]
next_token_chooser = self.next_token_choosers[idx]
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
# Decode generated tokens
output_text, _, _ = model.decode_token(
all_input_ids[:, 0],
prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
skip_special_tokens=True,
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
terminated_generations.append(
generate_pb2.TerminatedGeneration(
id=request_id,
generated_text=generate_pb2.GeneratedText(
text=output_text,
generated_tokens=stopping_criteria.current_tokens,
finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
seed=seed,
),
)
)
if not kept_requests:
return None, terminated_generations
request_ids = [r.id for r in kept_requests]
if len(request_ids) == len(self):
return self
return self, terminated_generations
keep_indices = []
@ -262,7 +296,7 @@ class CausalLMBatch(Batch):
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
return self
return self, terminated_generations
@classmethod
@tracer.start_as_current_span("concatenate")

View File

@ -215,15 +215,51 @@ class IdeficsCausalLMBatch(Batch):
@tracer.start_as_current_span("filter")
def filter(
self, updated_requests: List[generate_pb2.KeptRequest]
) -> Optional["IdeficsCausalLMBatch"]:
request_ids = [r.id for r in updated_requests]
self,
model: "IdeficsCausalLM",
kept_requests: List[generate_pb2.KeptRequest],
terminated_request_ids: List[int],
) -> Tuple[
Optional["IdeficsCausalLMBatch"], List[generate_pb2.TerminatedGeneration]
]:
terminated_generations = []
for request_id in terminated_request_ids:
idx = self.requests_idx_mapping[request_id]
all_input_ids = self.all_input_ids[idx]
stopping_criteria = self.stopping_criterias[idx]
next_token_chooser = self.next_token_choosers[idx]
# It deletes requests from the batch. For instance when client lost connection
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
# Decode generated tokens
output_text, _, _ = model.decode_token(
all_input_ids[:, 0],
prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
skip_special_tokens=True,
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
terminated_generations.append(
generate_pb2.TerminatedGeneration(
id=request_id,
generated_text=generate_pb2.GeneratedText(
text=output_text,
generated_tokens=stopping_criteria.current_tokens,
finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
seed=seed,
),
)
)
if not kept_requests:
return None, terminated_generations
request_ids = [r.id for r in kept_requests]
if len(request_ids) == len(self):
return self
return self, terminated_generations
keep_indices = []
@ -330,7 +366,7 @@ class IdeficsCausalLMBatch(Batch):
self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens
return self
return self, terminated_generations
@classmethod
@tracer.start_as_current_span("concatenate")

View File

@ -196,14 +196,48 @@ class MambaBatch(Batch):
)
def filter(
self, updated_requests: List[generate_pb2.KeptRequest]
) -> Optional["MambaBatch"]:
request_ids = [r.id for r in updated_requests]
self,
model: "Mamba",
kept_requests: List[generate_pb2.KeptRequest],
terminated_request_ids: List[int],
) -> Tuple[Optional["MambaBatch"], List[generate_pb2.TerminatedGeneration]]:
terminated_generations = []
for request_id in terminated_request_ids:
idx = self.requests_idx_mapping[request_id]
all_input_ids = self.all_input_ids[idx]
stopping_criteria = self.stopping_criterias[idx]
next_token_chooser = self.next_token_choosers[idx]
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
# Decode generated tokens
output_text, _, _ = model.decode_token(
all_input_ids[:, 0],
prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
skip_special_tokens=True,
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
terminated_generations.append(
generate_pb2.TerminatedGeneration(
id=request_id,
generated_text=generate_pb2.GeneratedText(
text=output_text,
generated_tokens=stopping_criteria.current_tokens,
finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
seed=seed,
),
)
)
if not kept_requests:
return None, terminated_generations
request_ids = [r.id for r in kept_requests]
if len(request_ids) == len(self):
return self
return self, terminated_generations
keep_indices = []
@ -278,7 +312,7 @@ class MambaBatch(Batch):
:, indices
]
self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
return self
return self, terminated_generations
@classmethod
def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":

View File

@ -167,14 +167,49 @@ class Seq2SeqLMBatch(Batch):
@tracer.start_as_current_span("filter")
def filter(
self, updated_requests: List[generate_pb2.KeptRequest]
) -> Optional["Seq2SeqLMBatch"]:
request_ids = [r.id for r in updated_requests]
self,
model: "Seq2SeqLM",
kept_requests: List[generate_pb2.KeptRequest],
terminated_request_ids: List[int],
) -> Tuple[Optional["Seq2SeqLMBatch"], List[generate_pb2.TerminatedGeneration]]:
terminated_generations = []
for request_id in terminated_request_ids:
idx = self.requests_idx_mapping[request_id]
all_decoder_input_ids = self.all_decoder_input_ids[idx]
decoder_input_length = self.decoder_input_lengths[idx]
stopping_criteria = self.stopping_criterias[idx]
next_token_chooser = self.next_token_choosers[idx]
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
# Decode generated tokens
output_text, _, _ = model.decode_token(
all_decoder_input_ids,
prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
read_offset=len(all_decoder_input_ids) - decoder_input_length,
skip_special_tokens=True,
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
terminated_generations.append(
generate_pb2.TerminatedGeneration(
id=request_id,
generated_text=generate_pb2.GeneratedText(
text=output_text,
generated_tokens=stopping_criteria.current_tokens,
finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
seed=seed,
),
)
)
if not kept_requests:
return None, terminated_generations
request_ids = [r.id for r in kept_requests]
if len(request_ids) == len(self):
return self
return self, terminated_generations
keep_indices = []
@ -281,7 +316,7 @@ class Seq2SeqLMBatch(Batch):
self.padding_right_offset = padding_right_offset
self.max_tokens = max_tokens
return self
return self, terminated_generations
@classmethod
@tracer.start_as_current_span("concatenate")

View File

@ -123,13 +123,19 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
@tracer.start_as_current_span("filter")
def filter(
self, updated_requests: List[generate_pb2.KeptRequest]
) -> Optional["VlmCausalLMBatch"]:
batch = super().filter(updated_requests)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch
self,
model: "VlmCausalLM",
kept_requests: List[generate_pb2.KeptRequest],
terminated_request_ids: List[int],
) -> Tuple[Optional["VlmCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
batch, terminated_generations = super().filter(
model, kept_requests, terminated_request_ids
)
if batch is not None:
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
return batch, terminated_generations
@classmethod
def batch_tokenized_inputs(