From f874c47831e12a1d447da12cb9e699cb2fb001da Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 2 Mar 2023 11:41:51 +0100 Subject: [PATCH] feat(router): add api-inference headers (#91) --- router/src/server.rs | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index 75f84ba..5d4140e 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -8,7 +8,7 @@ use crate::{ use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; -use axum::response::IntoResponse; +use axum::response::{IntoResponse, Response}; use axum::routing::{get, post}; use axum::{http, Json, Router}; use axum_tracing_opentelemetry::opentelemetry_tracing_layer; @@ -32,7 +32,7 @@ async fn compat_generate( default_return_full_text: Extension, infer: Extension, req: Json, -) -> Result)> { +) -> Result)> { let mut req = req.0; // default return_full_text given the pipeline_tag @@ -116,6 +116,7 @@ async fn generate( let span = tracing::Span::current(); let start_time = Instant::now(); + let compute_characters = req.0.inputs.chars().count(); let mut add_prompt = None; if req.0.parameters.return_full_text.unwrap_or(false) { add_prompt = Some(req.0.inputs.clone()); @@ -147,6 +148,15 @@ async fn generate( // Headers let mut headers = HeaderMap::new(); + headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); + headers.insert( + "x-compute-time", + total_time.as_millis().to_string().parse().unwrap(), + ); + headers.insert( + "x-compute-characters", + compute_characters.to_string().parse().unwrap(), + ); headers.insert( "x-total-time", total_time.as_millis().to_string().parse().unwrap(), @@ -239,10 +249,22 @@ async fn generate( async fn generate_stream( infer: Extension, req: Json, -) -> Sse>> { +) -> ( + HeaderMap, + Sse>>, +) { let span = tracing::Span::current(); let start_time = Instant::now(); + let compute_characters = req.0.inputs.chars().count(); + + let mut headers = HeaderMap::new(); + headers.insert("x-compute-type", "gpu+optimized".parse().unwrap()); + headers.insert( + "x-compute-characters", + compute_characters.to_string().parse().unwrap(), + ); + let stream = async_stream::stream! { // Inference let mut end_reached = false; @@ -360,7 +382,7 @@ async fn generate_stream( } }; - Sse::new(stream).keep_alive(KeepAlive::default()) + (headers, Sse::new(stream).keep_alive(KeepAlive::default())) } /// Prometheus metrics scrape endpoint