From 37266e2dbb8b5f52cff80cf837a5e479c51172fe Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 11 Jun 2024 17:11:16 +0200
Subject: [PATCH] fix rust and python unit-tests

---
 .github/workflows/trufflehog.yml              |  1 -
 router/src/infer/v3/block_allocator.rs        | 52 +++++++++++++++----
 router/src/infer/v3/queue.rs                  | 10 ++--
 router/src/infer/v3/scheduler.rs              | 35 ++++++------
 server/tests/models/test_bloom.py             | 26 +++++-----
 server/tests/models/test_causal_lm.py         | 33 ++++++------
 server/tests/models/test_seq2seq_lm.py        | 26 +++++-----
 .../models/causal_lm.py                       | 48 ++++++++++++++---
 .../models/idefics_causal_lm.py               | 52 ++++++++++++++++---
 server/text_generation_server/models/mamba.py | 48 ++++++++++++++---
 .../models/seq2seq_lm.py                      | 49 ++++++++++++++---
 .../models/vlm_causal_lm.py                   | 20 ++++---
 12 files changed, 288 insertions(+), 112 deletions(-)

diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml
index 8bc60eff..b406d43b 100644
--- a/.github/workflows/trufflehog.yml
+++ b/.github/workflows/trufflehog.yml
@@ -16,4 +16,3 @@ jobs:
           fetch-depth: 0
       - name: Secret Scanning
         uses: trufflesecurity/trufflehog@main
-
diff --git a/router/src/infer/v3/block_allocator.rs b/router/src/infer/v3/block_allocator.rs
index a084a505..563f173f 100644
--- a/router/src/infer/v3/block_allocator.rs
+++ b/router/src/infer/v3/block_allocator.rs
@@ -1,7 +1,8 @@
-use std::sync::{Arc, Mutex};
+use std::fmt::Formatter;
+use std::sync::{Arc, Mutex, TryLockError};
 use thiserror::Error;
 
-#[derive(Debug, Clone)]
+#[derive(Clone)]
 pub(crate) struct BlockAllocation {
     allocated_blocks: Vec<u32>,
     allocated_slots: Vec<u32>,
@@ -53,7 +54,19 @@ impl Drop for BlockAllocation {
     }
 }
 
-#[derive(Debug, Clone)]
+impl std::fmt::Debug for BlockAllocation {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("BlockAllocation")
+            .field("allocated_blocks", &self.allocated_blocks.len())
+            .field("allocated_slots", &self.allocated_slots.len())
+            .field("required_blocks", &self.required_blocks)
+            .field("required_slots", &self.required_slots)
+            .field("block_allocator", &self.block_allocator)
+            .finish()
+    }
+}
+
+#[derive(Clone)]
 pub(crate) struct BlockAllocator {
     free_blocks: Arc<Mutex<Vec<u32>>>,
     block_size: u32,
@@ -129,8 +142,7 @@ impl BlockAllocator {
             Err(AllocationError::NotEnoughPages)
         } else {
             let n_free_blocks = free_blocks.len();
-            let allocated_blocks =
-                free_blocks.split_off(n_free_blocks - clipped_required_blocks);
+            let allocated_blocks = free_blocks.split_off(n_free_blocks - clipped_required_blocks);
 
             let allocated_blocks = if repeats != 1 {
                 let mut allocated_blocks = allocated_blocks.repeat(repeats);
@@ -140,9 +152,8 @@ impl BlockAllocator {
                 allocated_blocks
             };
 
-            let mut allocated_slots = Vec::with_capacity(
-                allocated_blocks.len() * self.block_size as usize * repeats,
-            );
+            let mut allocated_slots =
+                Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats);
 
             let required_slots = (prompt_tokens + decode_tokens) as usize;
 
@@ -166,7 +177,30 @@ impl BlockAllocator {
     }
 
     pub(crate) fn free(&self, blocks: Vec<u32>) {
-        self.free_blocks.lock().expect("Lock could not be acquired. This is a bug.").extend(blocks)
+        self.free_blocks
+            .lock()
+            .expect("Lock could not be acquired. This is a bug.")
+            .extend(blocks)
+    }
+}
+
+impl std::fmt::Debug for BlockAllocator {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        let mut d = f.debug_struct("BlockAllocator");
+        d.field("block_size", &self.block_size)
+            .field("window_size", &self.window_size);
+        match self.free_blocks.try_lock() {
+            Ok(guard) => {
+                d.field("free_blocks", &(*guard).len());
+            }
+            Err(TryLockError::Poisoned(err)) => {
+                d.field("free_blocks", &(**err.get_ref()).len());
+            }
+            Err(TryLockError::WouldBlock) => {
+                d.field("free_blocks", &format_args!("<locked>"));
+            }
+        };
+        d.finish()
     }
 }
This is a bug.") + .extend(blocks) + } +} + +impl std::fmt::Debug for BlockAllocator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut d = f.debug_struct("BlockAllocator"); + d.field("block_size", &self.block_size) + .field("window_size", &self.window_size); + match self.free_blocks.try_lock() { + Ok(guard) => { + d.field("free_blocks", &(*guard).len()); + } + Err(TryLockError::Poisoned(err)) => { + d.field("free_blocks", &(**err.get_ref()).len()); + } + Err(TryLockError::WouldBlock) => { + d.field("free_blocks", &format_args!("")); + } + }; + d.finish() } } diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index db09f9b4..d8085800 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -275,7 +275,9 @@ impl State { if prefill_tokens > prefill_token_budget { // Entry is over budget // Add it back to the front - tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate); + tracing::debug!( + "Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget}" + ); self.entries.push_front((id, entry)); break; } @@ -456,7 +458,7 @@ mod tests { let entry = Entry { request: ValidGenerateRequest { inputs: vec![], - input_length: 0, + input_length: 1, truncate: 0, decoder_input_details: false, parameters: ValidParameters { @@ -567,7 +569,7 @@ mod tests { #[tokio::test] async fn test_next_batch_token_budget() { - let mut state = State::new(false, 1, None, 0, 2); + let mut state = State::new(false, 1, None, 0, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -689,7 +691,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_speculate() { - let queue = Queue::new(false, 1, None, 2, 16); + let queue = Queue::new(true, 1, None, 2, 16); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index ee93c20a..c03328b2 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -256,11 +256,7 @@ async fn prefill( .expect("ID not found in entries. 
This is a bug."); // Send intermediate responses - if let Err(_) = send_stream_responses(stream_responses, entry).map_err(|err| { - tracing::error!("Entry response channel error."); - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); - err - }) { + if send_stream_responses(stream_responses, entry).is_err() { // Sending failed, remove entry entries .remove(&id) @@ -405,7 +401,7 @@ async fn filter_batch( .filter_batch( id, updated_requests, - terminated_entries.keys().map(|v| *v).collect(), + terminated_entries.keys().copied().collect(), ) .await .unwrap() @@ -460,11 +456,14 @@ fn send_terminated_generations( }; // Send responses - if let Err(_) = entry.response_tx.send(Ok(response)).map_err(|err| { + let send_result = entry.response_tx.send(Ok(response)).map_err(|err| { tracing::error!("Entry response channel error."); metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); err - }) { + }); + + if send_result.is_err() { + // The channel is dropped, skip the rest of the messages continue 'terminated_generations; } } @@ -504,11 +503,7 @@ fn filter_send_ended_generations( // If the generation has ended for this request, we send the responses to the channel and // remove the entry to drop it and free its blocks if finished { - let _ = send_stream_responses(stream_responses, entry).map_err(|err| { - tracing::error!("Entry response channel error."); - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); - err - }); + let _ = send_stream_responses(stream_responses, entry); // Remove from entries and filter entries.remove(&id).expect("ID not found in entries. This is a bug."); return None; @@ -525,7 +520,11 @@ fn send_stream_responses( entry: &Entry, ) -> Result<(), Box>>> { for response in stream_responses { - entry.response_tx.send(Ok(response))?; + entry.response_tx.send(Ok(response)).map_err(|err| { + tracing::error!("Entry response channel error."); + metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); + err + })?; } Ok(()) } @@ -541,7 +540,7 @@ fn filter_send_update_allocations( ) -> (bool, IntMap) { let mut updated = false; - let ids: Vec = entries.keys().map(|v| *v).collect(); + let ids: Vec = entries.keys().copied().collect(); let mut terminated_entries = IntMap::with_capacity_and_hasher(entries.len(), BuildNoHashHasher::default()); @@ -581,11 +580,7 @@ fn filter_send_update_allocations( .expect("ID not found in stream_responses. 
This is a bug."); // Send intermediate responses - if let Err(_) = send_stream_responses(stream_response, entry).map_err(|err| { - tracing::error!("Entry response channel error."); - metrics::increment_counter!("tgi_request_failure", "err" => "dropped"); - err - }) { + if send_stream_responses(stream_response, entry).is_err() { // Sending failed, remove entry entries .remove(id) diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 0daa5f41..78bde639 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -197,8 +197,10 @@ def test_causal_lm_generate_token_completion_multi( # Copy stopping_criterias before filtering stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy() - next_batch = next_batch.filter( - [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])] + next_batch, _ = next_batch.filter( + default_bloom, + [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])], + [], ) for _ in range( @@ -307,15 +309,13 @@ def test_batch_concatenate( == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) - next_batch = next_batch.filter( + next_batch, _ = next_batch.filter( + default_bloom, [ - generate_pb2.UpdatedRequest( - id=next_batch.requests[0].id, blocks=[], slots=[] - ), - generate_pb2.UpdatedRequest( - id=next_batch.requests[1].id, blocks=[], slots=[] - ), - ] + generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]), + generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]), + ], + [], ) for _ in range( @@ -339,8 +339,10 @@ def test_batch_concatenate( == default_bloom_batch.stopping_criterias[0].max_new_tokens ) - next_batch = next_batch.filter( - [generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])] + next_batch, _ = next_batch.filter( + default_bloom, + [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])], + [], ) for _ in range( diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 547da81f..6716606c 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -198,8 +198,10 @@ def test_causal_lm_generate_token_completion_multi( default_multi_requests_causal_lm_batch.stopping_criterias.copy() ) - next_batch = next_batch.filter( - [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])] + next_batch, _ = next_batch.filter( + default_causal_lm, + [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])], + [], ) for _ in range( @@ -307,15 +309,13 @@ def test_batch_concatenate( == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) - next_batch = next_batch.filter( + next_batch, _ = next_batch.filter( + default_causal_lm, [ - generate_pb2.UpdatedRequest( - id=next_batch.requests[0].id, blocks=[], slots=[] - ), - generate_pb2.UpdatedRequest( - id=next_batch.requests[1].id, blocks=[], slots=[] - ), - ] + generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]), + generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]), + ], + [], ) for _ in range( @@ -337,15 +337,12 @@ def test_batch_concatenate( == default_causal_lm_batch.stopping_criterias[0].max_new_tokens ) - next_batch = next_batch.filter( + next_batch, _ = next_batch.filter( + default_causal_lm, [ - generate_pb2.UpdatedRequest( - id=next_batch.requests[0].id, blocks=[], slots=[] - ), - 
diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py
index 547da81f..6716606c 100644
--- a/server/tests/models/test_causal_lm.py
+++ b/server/tests/models/test_causal_lm.py
@@ -198,8 +198,10 @@ def test_causal_lm_generate_token_completion_multi(
         default_multi_requests_causal_lm_batch.stopping_criterias.copy()
     )
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_causal_lm,
+        [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
+        [],
     )
 
     for _ in range(
@@ -307,15 +309,13 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter(
+    next_batch, _ = next_batch.filter(
+        default_causal_lm,
         [
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[0].id, blocks=[], slots=[]
-            ),
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[1].id, blocks=[], slots=[]
-            ),
-        ]
+            generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
+            generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
+        ],
+        [],
     )
 
     for _ in range(
@@ -337,15 +337,12 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter(
+    next_batch, _ = next_batch.filter(
+        default_causal_lm,
         [
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[0].id, blocks=[], slots=[]
-            ),
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[1].id, blocks=[], slots=[]
-            ),
-        ]
+            generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
+        ],
+        [],
     )
 
     for _ in range(
diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py
index 17b5fa50..f1d2bb75 100644
--- a/server/tests/models/test_seq2seq_lm.py
+++ b/server/tests/models/test_seq2seq_lm.py
@@ -206,8 +206,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_seq2seq_lm,
+        [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
+        [],
     )
 
     generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
@@ -341,15 +343,13 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter(
+    next_batch, _ = next_batch.filter(
+        default_seq2seq_lm,
         [
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[0].id, blocks=[], slots=[]
-            ),
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[1].id, blocks=[], slots=[]
-            ),
-        ]
+            generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
+            generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
+        ],
+        [],
     )
 
     generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
@@ -360,8 +360,10 @@ def test_batch_concatenate(
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_seq2seq_lm,
+        [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])],
+        [],
     )
 
     generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
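
Note on the updated tests above: filter() now takes the model, the list of kept requests, and the list of terminated request ids, and returns a (batch, terminated_generations) tuple instead of just the batch. A minimal sketch of that call shape, reusing the fixture names from the tests above (illustrative only, not part of the patch):

    # Minimal sketch of the new filter() call shape, assuming the test fixtures above.
    next_batch, terminated_generations = next_batch.filter(
        default_causal_lm,  # the model is now passed in so filter() can decode final texts
        [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
        [],  # no terminated request ids in these tests
    )
    # With an empty terminated list, no TerminatedGeneration messages are produced.
    assert terminated_generations == []
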
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 40a4f100..f3b94e8c 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -159,14 +159,48 @@ class CausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["CausalLMBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "CausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["CausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids[:, 0],
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
@@ -262,7 +296,7 @@ class CausalLMBatch(Batch):
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
 
-        return self
+        return self, terminated_generations
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
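
With the change above, CausalLMBatch.filter() drains terminated requests before trimming the batch: for each terminated id it decodes the text generated so far via model.decode_token() and wraps it in a generate_pb2.TerminatedGeneration carrying FINISH_REASON_TERMINATED. A hedged sketch of how a caller might consume the returned pair (the caller and the log_terminated helper are hypothetical, not part of this change):

    # Hypothetical consumer of the new return value; field names come from the diff above.
    batch, terminated = batch.filter(model, kept_requests, terminated_request_ids)
    for gen in terminated:
        # gen.generated_text carries text, generated_tokens, finish_reason and seed
        log_terminated(gen.id, gen.generated_text.text, gen.generated_text.generated_tokens)
    if batch is None:
        # Nothing was kept: every request was either terminated or not retained.
        pass
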
diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py
index 495a47e5..f92378cb 100644
--- a/server/text_generation_server/models/idefics_causal_lm.py
+++ b/server/text_generation_server/models/idefics_causal_lm.py
@@ -215,15 +215,51 @@ class IdeficsCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["IdeficsCausalLMBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "IdeficsCausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[
+        Optional["IdeficsCausalLMBatch"], List[generate_pb2.TerminatedGeneration]
+    ]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
-        # It deletes requests from the batch. For instance when client lost connection
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids[:, 0],
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
@@ -330,7 +366,7 @@ class IdeficsCausalLMBatch(Batch):
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
 
-        return self
+        return self, terminated_generations
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 0340ca55..64cb739e 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -196,14 +196,48 @@ class MambaBatch(Batch):
     )
 
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["MambaBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "Mamba",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["MambaBatch"], List[generate_pb2.TerminatedGeneration]]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids[:, 0],
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
@@ -278,7 +312,7 @@ class MambaBatch(Batch):
             :, indices
         ]
         self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
-        return self
+        return self, terminated_generations
 
     @classmethod
     def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 77407118..3cf874fa 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -167,14 +167,49 @@ class Seq2SeqLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["Seq2SeqLMBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "Seq2SeqLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["Seq2SeqLMBatch"], List[generate_pb2.TerminatedGeneration]]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_decoder_input_ids = self.all_decoder_input_ids[idx]
+            decoder_input_length = self.decoder_input_lengths[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_decoder_input_ids,
+                prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
+                read_offset=len(all_decoder_input_ids) - decoder_input_length,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
@@ -281,7 +316,7 @@ class Seq2SeqLMBatch(Batch):
         self.padding_right_offset = padding_right_offset
         self.max_tokens = max_tokens
 
-        return self
+        return self, terminated_generations
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index e3d0bee8..cee8b45f 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -123,13 +123,19 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["VlmCausalLMBatch"]:
-        batch = super().filter(updated_requests)
-        batch.pixel_values = None
-        batch.pixel_attention_mask = None
-        batch.image_sizes = None
-        return batch
+        self,
+        model: "VlmCausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["VlmCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
+        batch, terminated_generations = super().filter(
+            model, kept_requests, terminated_request_ids
+        )
+        if batch is not None:
+            batch.pixel_values = None
+            batch.pixel_attention_mask = None
+            batch.image_sizes = None
+        return batch, terminated_generations
 
     @classmethod
     def batch_tokenized_inputs(