feat(router): drop permit after batching

This commit is contained in:
OlivierDehaene 2024-07-23 14:45:30 +02:00
parent e7e3aa6cac
commit 344427b6ab
No known key found for this signature in database
GPG Key ID: BB104D67809DA93C
6 changed files with 18 additions and 15 deletions

View File

@ -155,7 +155,7 @@ impl Infer {
let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
// Create stream; the semaphore permit is now stored on the queue entry and dropped once the entry is batched
let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
let (_input_length, mut stream) = self.generate_stream(request).await?;
// Return values
let mut result_prefill = Vec::new();
@ -462,7 +462,6 @@ impl ToolGrammar {
/// Type alias for generation responses
pub(crate) type GenerateStreamResponse = (
OwnedSemaphorePermit,
u32, // input_length
UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
);

View File

@ -9,7 +9,7 @@ use text_generation_client::v2::{
Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
};
use text_generation_client::ChunksToString;
use tokio::sync::{mpsc, oneshot};
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};
@ -18,6 +18,8 @@ use tracing::{info_span, instrument, Span};
pub(crate) struct Entry {
/// Request
pub request: ValidGenerateRequest,
/// Permit
pub permit: Option<OwnedSemaphorePermit>,
/// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
/// Span that will live as long as entry
@ -269,6 +271,9 @@ impl State {
break;
}
// Drop permit
entry.permit = None;
tracing::debug!("Accepting entry");
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");

View File

@ -84,6 +84,7 @@ impl Scheduler for SchedulerV2 {
self.queue.append(Entry {
request,
response_tx,
permit: Some(permit),
span: Span::current(),
temp_span: None,
queue_time: Instant::now(),
@ -95,11 +96,7 @@ impl Scheduler for SchedulerV2 {
self.batching_task_notifier.notify_one();
// Return stream
Ok((
permit,
input_length,
UnboundedReceiverStream::new(response_rx),
))
Ok((input_length, UnboundedReceiverStream::new(response_rx)))
}
}

View File

@ -12,7 +12,7 @@ use text_generation_client::v3::{
};
use text_generation_client::ChunksToString;
use text_generation_client::Input;
use tokio::sync::{mpsc, oneshot};
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit};
use tokio::time::Instant;
use tracing::{info_span, instrument, Instrument, Span};
@ -21,6 +21,8 @@ use tracing::{info_span, instrument, Instrument, Span};
pub(crate) struct Entry {
/// Request
pub request: ValidGenerateRequest,
/// Permit
pub permit: Option<OwnedSemaphorePermit>,
/// Response sender to communicate between the Infer struct and the batching_task
pub response_tx: mpsc::UnboundedSender<Result<InferStreamResponse, InferError>>,
/// Span that will live as long as entry
@ -315,6 +317,9 @@ impl State {
}
};
// Drop permit
entry.permit = None;
tracing::debug!("Accepting entry");
// Create a new span to link the batch back to this entry
let entry_batch_span = info_span!(parent: &entry.span, "infer");

View File

@ -89,6 +89,7 @@ impl Scheduler for SchedulerV3 {
self.queue.append(Entry {
request,
response_tx,
permit: Some(permit),
span: Span::current(),
temp_span: None,
queue_time: Instant::now(),
@ -101,11 +102,7 @@ impl Scheduler for SchedulerV3 {
self.batching_task_notifier.notify_one();
// Return stream
Ok((
permit,
input_length,
UnboundedReceiverStream::new(response_rx),
))
Ok((input_length, UnboundedReceiverStream::new(response_rx)))
}
}

View File

@ -429,7 +429,7 @@ async fn generate_stream_internal(
} else {
match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
// The permit is no longer returned here: it is stored on the queue entry and dropped once the entry is batched
Ok((_permit, _input_length, mut response_stream)) => {
Ok((_input_length, mut response_stream)) => {
let mut index = 0;
// Server-Sent Event stream
while let Some(response) = response_stream.next().await {