diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index db9070d4..675671b8 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -155,7 +155,7 @@ impl Infer { let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0); // Create stream and keep semaphore permit as long as generate lives - let (_permit, _input_length, mut stream) = self.generate_stream(request).await?; + let (_input_length, mut stream) = self.generate_stream(request).await?; // Return values let mut result_prefill = Vec::new(); @@ -462,7 +462,6 @@ impl ToolGrammar { /// Type alias for generation responses pub(crate) type GenerateStreamResponse = ( - OwnedSemaphorePermit, u32, // input_length UnboundedReceiverStream>, ); diff --git a/router/src/infer/v2/queue.rs b/router/src/infer/v2/queue.rs index 0b51645a..bf035c5f 100644 --- a/router/src/infer/v2/queue.rs +++ b/router/src/infer/v2/queue.rs @@ -9,7 +9,7 @@ use text_generation_client::v2::{ Batch, GrammarType, NextTokenChooserParameters, Request, StoppingCriteriaParameters, }; use text_generation_client::ChunksToString; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit}; use tokio::time::Instant; use tracing::{info_span, instrument, Span}; @@ -18,6 +18,8 @@ use tracing::{info_span, instrument, Span}; pub(crate) struct Entry { /// Request pub request: ValidGenerateRequest, + /// Permit + pub permit: Option, /// Response sender to communicate between the Infer struct and the batching_task pub response_tx: mpsc::UnboundedSender>, /// Span that will live as long as entry @@ -269,6 +271,9 @@ impl State { break; } + // Drop permit + entry.permit = None; + tracing::debug!("Accepting entry"); // Create a new span to link the batch back to this entry let entry_batch_span = info_span!(parent: &entry.span, "infer"); diff --git a/router/src/infer/v2/scheduler.rs b/router/src/infer/v2/scheduler.rs index 97379bc5..ec7e35e8 100644 --- a/router/src/infer/v2/scheduler.rs +++ b/router/src/infer/v2/scheduler.rs @@ -84,6 +84,7 @@ impl Scheduler for SchedulerV2 { self.queue.append(Entry { request, response_tx, + permit: Some(permit), span: Span::current(), temp_span: None, queue_time: Instant::now(), @@ -95,11 +96,7 @@ impl Scheduler for SchedulerV2 { self.batching_task_notifier.notify_one(); // Return stream - Ok(( - permit, - input_length, - UnboundedReceiverStream::new(response_rx), - )) + Ok((input_length, UnboundedReceiverStream::new(response_rx))) } } diff --git a/router/src/infer/v3/queue.rs b/router/src/infer/v3/queue.rs index 894d9cab..1805d9ae 100644 --- a/router/src/infer/v3/queue.rs +++ b/router/src/infer/v3/queue.rs @@ -12,7 +12,7 @@ use text_generation_client::v3::{ }; use text_generation_client::ChunksToString; use text_generation_client::Input; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit}; use tokio::time::Instant; use tracing::{info_span, instrument, Instrument, Span}; @@ -21,6 +21,8 @@ use tracing::{info_span, instrument, Instrument, Span}; pub(crate) struct Entry { /// Request pub request: ValidGenerateRequest, + /// Permit + pub permit: Option, /// Response sender to communicate between the Infer struct and the batching_task pub response_tx: mpsc::UnboundedSender>, /// Span that will live as long as entry @@ -315,6 +317,9 @@ impl State { } }; + // Drop permit + entry.permit = None; + tracing::debug!("Accepting entry"); // Create a new span to link the batch back to this entry let entry_batch_span = info_span!(parent: &entry.span, "infer"); diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs index 26cd9584..639c89d8 100644 --- a/router/src/infer/v3/scheduler.rs +++ b/router/src/infer/v3/scheduler.rs @@ -89,6 +89,7 @@ impl Scheduler for SchedulerV3 { self.queue.append(Entry { request, response_tx, + permit: Some(permit), span: Span::current(), temp_span: None, queue_time: Instant::now(), @@ -101,11 +102,7 @@ impl Scheduler for SchedulerV3 { self.batching_task_notifier.notify_one(); // Return stream - Ok(( - permit, - input_length, - UnboundedReceiverStream::new(response_rx), - )) + Ok((input_length, UnboundedReceiverStream::new(response_rx))) } } diff --git a/router/src/server.rs b/router/src/server.rs index c56c39a3..44e86a87 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -429,7 +429,7 @@ async fn generate_stream_internal( } else { match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await { // Keep permit as long as generate_stream lives - Ok((_permit, _input_length, mut response_stream)) => { + Ok((_input_length, mut response_stream)) => { let mut index = 0; // Server-Sent Event stream while let Some(response) = response_stream.next().await {