feat(router): add api-inference headers (#91)

This commit is contained in:
OlivierDehaene 2023-03-02 11:41:51 +01:00 committed by GitHub
parent 4e685d907e
commit f874c47831
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 26 additions and 4 deletions

View File

@ -8,7 +8,7 @@ use crate::{
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse;
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
@ -32,7 +32,7 @@ async fn compat_generate(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
req: Json<CompatGenerateRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let mut req = req.0;
// default return_full_text given the pipeline_tag
@ -116,6 +116,7 @@ async fn generate(
let span = tracing::Span::current();
let start_time = Instant::now();
let compute_characters = req.0.inputs.chars().count();
let mut add_prompt = None;
if req.0.parameters.return_full_text.unwrap_or(false) {
add_prompt = Some(req.0.inputs.clone());
@ -147,6 +148,15 @@ async fn generate(
// Headers
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
headers.insert(
"x-compute-time",
total_time.as_millis().to_string().parse().unwrap(),
);
headers.insert(
"x-compute-characters",
compute_characters.to_string().parse().unwrap(),
);
headers.insert(
"x-total-time",
total_time.as_millis().to_string().parse().unwrap(),
@ -239,10 +249,22 @@ async fn generate(
async fn generate_stream(
infer: Extension<Infer>,
req: Json<GenerateRequest>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
) -> (
HeaderMap,
Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
let span = tracing::Span::current();
let start_time = Instant::now();
let compute_characters = req.0.inputs.chars().count();
let mut headers = HeaderMap::new();
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
headers.insert(
"x-compute-characters",
compute_characters.to_string().parse().unwrap(),
);
let stream = async_stream::stream! {
// Inference
let mut end_reached = false;
@ -360,7 +382,7 @@ async fn generate_stream(
}
};
Sse::new(stream).keep_alive(KeepAlive::default())
(headers, Sse::new(stream).keep_alive(KeepAlive::default()))
}
/// Prometheus metrics scrape endpoint