diff --git a/router/src/infer.rs b/router/src/infer.rs
index bf5920da..b4094c1b 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -90,6 +90,7 @@ impl Infer {
     ) -> Result<
         (
             OwnedSemaphorePermit,
+            u32,
             UnboundedReceiverStream<Result<InferStreamResponse, InferError>>,
         ),
         InferError,
@@ -114,6 +115,7 @@ impl Infer {
 
         // MPSC channel to communicate with the background batching task
         let (response_tx, response_rx) = mpsc::unbounded_channel();
+        let input_length = valid_request.input_length;
 
         // Append the request to the queue
         self.queue.append(Entry {
@@ -130,7 +132,11 @@ impl Infer {
         self.shared.batching_task.notify_one();
 
         // Return stream
-        Ok((permit, UnboundedReceiverStream::new(response_rx)))
+        Ok((
+            permit,
+            input_length,
+            UnboundedReceiverStream::new(response_rx),
+        ))
     }
 
     /// Add a new request to the queue and return a InferResponse
@@ -142,7 +148,7 @@ impl Infer {
         let use_top_tokens = request.parameters.top_n_tokens.is_some_and(|x| x > 0);
 
         // Create stream and keep semaphore permit as long as generate lives
-        let (_permit, mut stream) = self.generate_stream(request).await?;
+        let (_permit, _input_length, mut stream) = self.generate_stream(request).await?;
 
         // Return values
         let mut result_prefill = Vec::new();
@@ -196,6 +202,7 @@ impl Infer {
         {
             Ok(InferResponse {
                 prefill: result_prefill,
+                _input_length,
                 tokens: result_tokens,
                 generated_text,
                 queued,
@@ -636,6 +643,10 @@ pub(crate) enum InferStreamResponse {
 
 #[derive(Debug)]
 pub(crate) struct InferResponse {
+    /// input_length is the input as perceived by the rust tokenizer in the
+    /// validation pathway. It is redundant with prefill.len() but prefill
+    /// has data only if the user asked for it. This will always be filled.
+    pub(crate) _input_length: u32,
     pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
     pub(crate) generated_text: GeneratedText,
diff --git a/router/src/server.rs b/router/src/server.rs
index fe1b8309..3db5c7cd 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -170,6 +170,7 @@ async fn generate(
     };
 
     // Token details
+    let input_length = response._input_length;
     let details = match details {
         true => {
             // convert best_of_responses
@@ -257,6 +258,11 @@ async fn generate(
         "x-time-per-token",
         time_per_token.as_millis().to_string().parse().unwrap(),
     );
+    headers.insert("x-prompt-tokens", input_length.into());
+    headers.insert(
+        "x-generated-tokens",
+        response.generated_text.generated_tokens.into(),
+    );
 
     // Metrics
     metrics::increment_counter!("tgi_request_success");
@@ -378,7 +384,7 @@ async fn generate_stream(
     } else {
         match infer.generate_stream(req).instrument(info_span!(parent: &span, "async_stream")).await {
             // Keep permit as long as generate_stream lives
-            Ok((_permit, mut response_stream)) => {
+            Ok((_permit, _input_length, mut response_stream)) => {
                 // Server-Sent Event stream
                 while let Some(response) = response_stream.next().await {
                     match response {
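
For reference, below is a minimal client-side sketch (not part of the diff) of how the two new `x-prompt-tokens` and `x-generated-tokens` response headers could be consumed. It assumes a TGI server listening at `http://localhost:3000` (a hypothetical address) and uses the `reqwest` crate (with the `blocking` and `json` features) plus `serde_json`; the request payload follows the standard `/generate` API shape.

```rust
use reqwest::blocking::Client;
use serde_json::json;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // POST a generate request; the payload shape follows TGI's /generate API.
    let resp = Client::new()
        .post("http://localhost:3000/generate")
        .json(&json!({
            "inputs": "What is Deep Learning?",
            "parameters": { "max_new_tokens": 20 }
        }))
        .send()?;

    // Both headers carry plain ASCII integers, so `to_str` + `parse` suffices.
    let read = |name: &str| -> Option<u32> {
        resp.headers().get(name)?.to_str().ok()?.parse().ok()
    };
    let prompt_tokens = read("x-prompt-tokens");
    let generated_tokens = read("x-generated-tokens");

    println!("prompt tokens: {prompt_tokens:?}, generated tokens: {generated_tokens:?}");
    Ok(())
}
```

Because `_input_length` is always populated in `InferResponse` (per the doc comment above, it is filled by the Rust tokenizer in the validation pathway), `x-prompt-tokens` is available even when the client does not request `details`, which is the point of surfacing it as a header instead of relying on `prefill.len()`.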