fix rust and python unit-tests
parent 73c3903214
commit 37266e2dbb
@@ -16,4 +16,3 @@ jobs:
           fetch-depth: 0
       - name: Secret Scanning
         uses: trufflesecurity/trufflehog@main
@@ -1,7 +1,8 @@
-use std::sync::{Arc, Mutex};
+use std::fmt::Formatter;
+use std::sync::{Arc, Mutex, TryLockError};
 use thiserror::Error;
 
-#[derive(Debug, Clone)]
+#[derive(Clone)]
 pub(crate) struct BlockAllocation {
     allocated_blocks: Vec<u32>,
     allocated_slots: Vec<u32>,
@@ -53,7 +54,19 @@ impl Drop for BlockAllocation {
     }
 }
 
-#[derive(Debug, Clone)]
+impl std::fmt::Debug for BlockAllocation {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("BlockAllocation")
+            .field("allocated_blocks", &self.allocated_blocks.len())
+            .field("allocated_slots", &self.allocated_slots.len())
+            .field("required_blocks", &self.required_blocks)
+            .field("required_slots", &self.required_slots)
+            .field("block_allocator", &self.block_allocator)
+            .finish()
+    }
+}
+
+#[derive(Clone)]
 pub(crate) struct BlockAllocator {
     free_blocks: Arc<Mutex<Vec<u32>>>,
     block_size: u32,
@@ -129,8 +142,7 @@ impl BlockAllocator {
             Err(AllocationError::NotEnoughPages)
         } else {
             let n_free_blocks = free_blocks.len();
-            let allocated_blocks =
-                free_blocks.split_off(n_free_blocks - clipped_required_blocks);
+            let allocated_blocks = free_blocks.split_off(n_free_blocks - clipped_required_blocks);
 
             let allocated_blocks = if repeats != 1 {
                 let mut allocated_blocks = allocated_blocks.repeat(repeats);
@@ -140,9 +152,8 @@ impl BlockAllocator {
                 allocated_blocks
             };
 
-            let mut allocated_slots = Vec::with_capacity(
-                allocated_blocks.len() * self.block_size as usize * repeats,
-            );
+            let mut allocated_slots =
+                Vec::with_capacity(allocated_blocks.len() * self.block_size as usize * repeats);
 
             let required_slots = (prompt_tokens + decode_tokens) as usize;
 
@@ -166,7 +177,30 @@ impl BlockAllocator {
     }
 
     pub(crate) fn free(&self, blocks: Vec<u32>) {
-        self.free_blocks.lock().expect("Lock could not be acquired. This is a bug.").extend(blocks)
+        self.free_blocks
+            .lock()
+            .expect("Lock could not be acquired. This is a bug.")
+            .extend(blocks)
     }
 }
+
+impl std::fmt::Debug for BlockAllocator {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        let mut d = f.debug_struct("BlockAllocator");
+        d.field("block_size", &self.block_size)
+            .field("window_size", &self.window_size);
+        match self.free_blocks.try_lock() {
+            Ok(guard) => {
+                d.field("free_blocks", &(*guard).len());
+            }
+            Err(TryLockError::Poisoned(err)) => {
+                d.field("free_blocks", &(**err.get_ref()).len());
+            }
+            Err(TryLockError::WouldBlock) => {
+                d.field("free_blocks", &format_args!("<locked>"));
+            }
+        };
+        d.finish()
+    }
+}
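For reference, a minimal standalone sketch of the try_lock-based Debug pattern introduced above. The type and field names here (Pool, free) are illustrative, not taken from this commit; the point is that the Debug impl reports the length of the mutex-guarded Vec when the lock is available and falls back to a placeholder instead of blocking or panicking:

use std::fmt;
use std::sync::{Arc, Mutex, TryLockError};

#[derive(Clone)]
struct Pool {
    // Shared free list guarded by a mutex, similar in shape to BlockAllocator.
    free: Arc<Mutex<Vec<u32>>>,
}

impl fmt::Debug for Pool {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut d = f.debug_struct("Pool");
        match self.free.try_lock() {
            // Lock is free: report how many items are inside.
            Ok(guard) => {
                d.field("free", &guard.len());
            }
            // A thread panicked while holding the lock: the data is still readable.
            Err(TryLockError::Poisoned(err)) => {
                d.field("free", &err.get_ref().len());
            }
            // Another thread holds the lock: avoid blocking inside Debug.
            Err(TryLockError::WouldBlock) => {
                d.field("free", &format_args!("<locked>"));
            }
        }
        d.finish()
    }
}

fn main() {
    let pool = Pool { free: Arc::new(Mutex::new(vec![1, 2, 3])) };
    // Prints: Pool { free: 3 }
    println!("{:?}", pool);
}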
@@ -275,7 +275,9 @@ impl State {
             if prefill_tokens > prefill_token_budget {
                 // Entry is over budget
                 // Add it back to the front
-                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
+                tracing::debug!(
+                    "Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget}"
+                );
                 self.entries.push_front((id, entry));
                 break;
             }
@@ -456,7 +458,7 @@ mod tests {
         let entry = Entry {
             request: ValidGenerateRequest {
                 inputs: vec![],
-                input_length: 0,
+                input_length: 1,
                 truncate: 0,
                 decoder_input_details: false,
                 parameters: ValidParameters {
@@ -567,7 +569,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_next_batch_token_budget() {
-        let mut state = State::new(false, 1, None, 0, 2);
+        let mut state = State::new(false, 1, None, 0, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        state.append(entry1);
@@ -689,7 +691,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_queue_next_batch_token_speculate() {
-        let queue = Queue::new(false, 1, None, 2, 16);
+        let queue = Queue::new(true, 1, None, 2, 16);
        let (entry1, _guard1) = default_entry();
        let (entry2, _guard2) = default_entry();
        queue.append(entry1);
@@ -256,11 +256,7 @@ async fn prefill(
             .expect("ID not found in entries. This is a bug.");
 
         // Send intermediate responses
-        if let Err(_) = send_stream_responses(stream_responses, entry).map_err(|err| {
-            tracing::error!("Entry response channel error.");
-            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
-            err
-        }) {
+        if send_stream_responses(stream_responses, entry).is_err() {
             // Sending failed, remove entry
             entries
                 .remove(&id)
@@ -405,7 +401,7 @@ async fn filter_batch(
         .filter_batch(
             id,
             updated_requests,
-            terminated_entries.keys().map(|v| *v).collect(),
+            terminated_entries.keys().copied().collect(),
         )
         .await
         .unwrap()
@@ -460,11 +456,14 @@ fn send_terminated_generations(
         };
 
         // Send responses
-        if let Err(_) = entry.response_tx.send(Ok(response)).map_err(|err| {
+        let send_result = entry.response_tx.send(Ok(response)).map_err(|err| {
             tracing::error!("Entry response channel error.");
             metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
             err
-        }) {
+        });
+
+        if send_result.is_err() {
             // The channel is dropped, skip the rest of the messages
             continue 'terminated_generations;
         }
     }
@@ -504,11 +503,7 @@ fn filter_send_ended_generations(
         // If the generation has ended for this request, we send the responses to the channel and
         // remove the entry to drop it and free its blocks
         if finished {
-            let _ = send_stream_responses(stream_responses, entry).map_err(|err| {
-                tracing::error!("Entry response channel error.");
-                metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
-                err
-            });
+            let _ = send_stream_responses(stream_responses, entry);
             // Remove from entries and filter
             entries.remove(&id).expect("ID not found in entries. This is a bug.");
             return None;
@@ -525,7 +520,11 @@ fn send_stream_responses(
     entry: &Entry,
 ) -> Result<(), Box<SendError<Result<InferStreamResponse, InferError>>>> {
     for response in stream_responses {
-        entry.response_tx.send(Ok(response))?;
+        entry.response_tx.send(Ok(response)).map_err(|err| {
+            tracing::error!("Entry response channel error.");
+            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+            err
+        })?;
     }
     Ok(())
 }
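As a side note on the shape of this refactor: the sending helper now owns the logging and metrics through map_err, so call sites only check is_err() before cleaning up the entry. A rough, self-contained sketch of that pattern, using std::sync::mpsc and eprintln! as stand-ins for the project's response channel, tracing, and metrics calls (the names send_all and the cleanup message are illustrative only):

use std::sync::mpsc::{channel, SendError, Sender};

// The helper owns the error reporting; callers only care whether sending worked.
fn send_all(tx: &Sender<u32>, values: Vec<u32>) -> Result<(), SendError<u32>> {
    for value in values {
        tx.send(value).map_err(|err| {
            // Stand-in for tracing::error! + metrics::increment_counter!.
            eprintln!("response channel error: {err}");
            err
        })?;
    }
    Ok(())
}

fn main() {
    let (tx, rx) = channel();
    drop(rx); // Simulate a client that dropped its receiver.

    // Call site mirrors `if send_stream_responses(...).is_err() { ... }`.
    if send_all(&tx, vec![1, 2, 3]).is_err() {
        // Sending failed: the entry would be removed from `entries` here.
        eprintln!("receiver gone, cleaning up entry");
    }
}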
@@ -541,7 +540,7 @@ fn filter_send_update_allocations(
 ) -> (bool, IntMap<u64, Entry>) {
     let mut updated = false;
 
-    let ids: Vec<u64> = entries.keys().map(|v| *v).collect();
+    let ids: Vec<u64> = entries.keys().copied().collect();
     let mut terminated_entries =
         IntMap::with_capacity_and_hasher(entries.len(), BuildNoHashHasher::default());
 
@@ -581,11 +580,7 @@ fn filter_send_update_allocations(
             .expect("ID not found in stream_responses. This is a bug.");
 
         // Send intermediate responses
-        if let Err(_) = send_stream_responses(stream_response, entry).map_err(|err| {
-            tracing::error!("Entry response channel error.");
-            metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
-            err
-        }) {
+        if send_stream_responses(stream_response, entry).is_err() {
             // Sending failed, remove entry
             entries
                 .remove(id)
@@ -197,8 +197,10 @@ def test_causal_lm_generate_token_completion_multi(
     # Copy stopping_criterias before filtering
     stopping_criterias = default_multi_requests_bloom_batch.stopping_criterias.copy()
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_bloom,
+        [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
+        [],
     )
 
     for _ in range(
@@ -307,15 +309,13 @@ def test_batch_concatenate(
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter(
+    next_batch, _ = next_batch.filter(
+        default_bloom,
         [
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[0].id, blocks=[], slots=[]
-            ),
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[1].id, blocks=[], slots=[]
-            ),
-        ]
+            generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
+            generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
+        ],
+        [],
     )
 
     for _ in range(
@@ -339,8 +339,10 @@ def test_batch_concatenate(
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_bloom,
+        [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])],
+        [],
     )
 
     for _ in range(
@@ -198,8 +198,10 @@ def test_causal_lm_generate_token_completion_multi(
         default_multi_requests_causal_lm_batch.stopping_criterias.copy()
     )
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_causal_lm,
+        [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
+        [],
     )
 
     for _ in range(
@@ -307,15 +309,13 @@ def test_batch_concatenate(
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )
 
-    next_batch = next_batch.filter(
+    next_batch, _ = next_batch.filter(
+        default_causal_lm,
         [
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[0].id, blocks=[], slots=[]
-            ),
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[1].id, blocks=[], slots=[]
-            ),
-        ]
+            generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
+            generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
+        ],
+        [],
     )
 
     for _ in range(
@@ -337,15 +337,12 @@ def test_batch_concatenate(
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
 
-    next_batch = next_batch.filter(
+    next_batch, _ = next_batch.filter(
+        default_causal_lm,
         [
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[0].id, blocks=[], slots=[]
-            ),
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[1].id, blocks=[], slots=[]
-            ),
-        ]
+            generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
+        ],
+        [],
     )
 
     for _ in range(
@@ -206,8 +206,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
     )
     assert generations[1].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[0].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_seq2seq_lm,
+        [generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[])],
+        [],
     )
 
     generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
@@ -341,15 +343,13 @@ def test_batch_concatenate(
     )
     assert generations[2].generated_text.generated_tokens == 5
 
-    next_batch = next_batch.filter(
+    next_batch, _ = next_batch.filter(
+        default_seq2seq_lm,
         [
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[0].id, blocks=[], slots=[]
-            ),
-            generate_pb2.UpdatedRequest(
-                id=next_batch.requests[1].id, blocks=[], slots=[]
-            ),
-        ]
+            generate_pb2.KeptRequest(id=next_batch.requests[0].id, blocks=[], slots=[]),
+            generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[]),
+        ],
+        [],
     )
 
     generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
@@ -360,8 +360,10 @@ def test_batch_concatenate(
     assert generations[0].request_id == default_seq2seq_lm_batch.requests[0].id
     assert generations[0].generated_text.generated_tokens == 7
 
-    next_batch = next_batch.filter(
-        [generate_pb2.UpdatedRequest(id=next_batch.requests[1].id, blocks=[], slots=[])]
+    next_batch, _ = next_batch.filter(
+        default_seq2seq_lm,
+        [generate_pb2.KeptRequest(id=next_batch.requests[1].id, blocks=[], slots=[])],
+        [],
    )
 
     generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
@@ -159,14 +159,48 @@ class CausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["CausalLMBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "CausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["CausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids[:, 0],
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
 
@@ -262,7 +296,7 @@ class CausalLMBatch(Batch):
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
 
-        return self
+        return self, terminated_generations
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -215,15 +215,51 @@ class IdeficsCausalLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["IdeficsCausalLMBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "IdeficsCausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[
+        Optional["IdeficsCausalLMBatch"], List[generate_pb2.TerminatedGeneration]
+    ]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
         # It deletes requests from the batch. For instance when client lost connection
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids[:, 0],
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
 
@@ -330,7 +366,7 @@ class IdeficsCausalLMBatch(Batch):
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens
 
-        return self
+        return self, terminated_generations
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -196,14 +196,48 @@ class MambaBatch(Batch):
         )
 
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["MambaBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "Mamba",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["MambaBatch"], List[generate_pb2.TerminatedGeneration]]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_input_ids = self.all_input_ids[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_input_ids[:, 0],
+                prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
+                read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
 
@@ -278,7 +312,7 @@ class MambaBatch(Batch):
             :, indices
         ]
         self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
-        return self
+        return self, terminated_generations
 
     @classmethod
     def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch":
@@ -167,14 +167,49 @@ class Seq2SeqLMBatch(Batch):
 
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["Seq2SeqLMBatch"]:
-        request_ids = [r.id for r in updated_requests]
+        self,
+        model: "Seq2SeqLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["Seq2SeqLMBatch"], List[generate_pb2.TerminatedGeneration]]:
+        terminated_generations = []
+        for request_id in terminated_request_ids:
+            idx = self.requests_idx_mapping[request_id]
+            all_decoder_input_ids = self.all_decoder_input_ids[idx]
+            decoder_input_length = self.decoder_input_lengths[idx]
+            stopping_criteria = self.stopping_criterias[idx]
+            next_token_chooser = self.next_token_choosers[idx]
 
-        if len(request_ids) == 0:
-            raise ValueError("Batch must have at least one request")
+            # Decode generated tokens
+            output_text, _, _ = model.decode_token(
+                all_decoder_input_ids,
+                prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
+                read_offset=len(all_decoder_input_ids) - decoder_input_length,
+                skip_special_tokens=True,
+            )
+            # Get seed
+            if isinstance(next_token_chooser.choice, Sampling):
+                seed = next_token_chooser.choice.seed
+            else:
+                seed = None
+
+            terminated_generations.append(
+                generate_pb2.TerminatedGeneration(
+                    id=request_id,
+                    generated_text=generate_pb2.GeneratedText(
+                        text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        finish_reason=generate_pb2.FINISH_REASON_TERMINATED,
+                        seed=seed,
+                    ),
+                )
+            )
+        if not kept_requests:
+            return None, terminated_generations
+
+        request_ids = [r.id for r in kept_requests]
         if len(request_ids) == len(self):
-            return self
+            return self, terminated_generations
 
         keep_indices = []
 
@@ -281,7 +316,7 @@ class Seq2SeqLMBatch(Batch):
         self.padding_right_offset = padding_right_offset
         self.max_tokens = max_tokens
 
-        return self
+        return self, terminated_generations
 
     @classmethod
     @tracer.start_as_current_span("concatenate")
@@ -123,13 +123,19 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
 
     @tracer.start_as_current_span("filter")
     def filter(
-        self, updated_requests: List[generate_pb2.KeptRequest]
-    ) -> Optional["VlmCausalLMBatch"]:
-        batch = super().filter(updated_requests)
-        batch.pixel_values = None
-        batch.pixel_attention_mask = None
-        batch.image_sizes = None
-        return batch
+        self,
+        model: "VlmCausalLM",
+        kept_requests: List[generate_pb2.KeptRequest],
+        terminated_request_ids: List[int],
+    ) -> Tuple[Optional["VlmCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
+        batch, terminated_generations = super().filter(
+            model, kept_requests, terminated_request_ids
+        )
+        if batch is not None:
+            batch.pixel_values = None
+            batch.pixel_attention_mask = None
+            batch.image_sizes = None
+        return batch, terminated_generations
 
     @classmethod
     def batch_tokenized_inputs(