feat(router): add api-inference headers (#91)
This commit is contained in:
parent
4e685d907e
commit
f874c47831
|
@ -8,7 +8,7 @@ use crate::{
|
|||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::routing::{get, post};
|
||||
use axum::{http, Json, Router};
|
||||
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||
|
@ -32,7 +32,7 @@ async fn compat_generate(
|
|||
default_return_full_text: Extension<bool>,
|
||||
infer: Extension<Infer>,
|
||||
req: Json<CompatGenerateRequest>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let mut req = req.0;
|
||||
|
||||
// default return_full_text given the pipeline_tag
|
||||
|
@ -116,6 +116,7 @@ async fn generate(
|
|||
let span = tracing::Span::current();
|
||||
let start_time = Instant::now();
|
||||
|
||||
let compute_characters = req.0.inputs.chars().count();
|
||||
let mut add_prompt = None;
|
||||
if req.0.parameters.return_full_text.unwrap_or(false) {
|
||||
add_prompt = Some(req.0.inputs.clone());
|
||||
|
@ -147,6 +148,15 @@ async fn generate(
|
|||
|
||||
// Headers
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
|
||||
headers.insert(
|
||||
"x-compute-time",
|
||||
total_time.as_millis().to_string().parse().unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"x-compute-characters",
|
||||
compute_characters.to_string().parse().unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"x-total-time",
|
||||
total_time.as_millis().to_string().parse().unwrap(),
|
||||
|
@ -239,10 +249,22 @@ async fn generate(
|
|||
async fn generate_stream(
|
||||
infer: Extension<Infer>,
|
||||
req: Json<GenerateRequest>,
|
||||
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
|
||||
) -> (
|
||||
HeaderMap,
|
||||
Sse<impl Stream<Item = Result<Event, Infallible>>>,
|
||||
) {
|
||||
let span = tracing::Span::current();
|
||||
let start_time = Instant::now();
|
||||
|
||||
let compute_characters = req.0.inputs.chars().count();
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
|
||||
headers.insert(
|
||||
"x-compute-characters",
|
||||
compute_characters.to_string().parse().unwrap(),
|
||||
);
|
||||
|
||||
let stream = async_stream::stream! {
|
||||
// Inference
|
||||
let mut end_reached = false;
|
||||
|
@ -360,7 +382,7 @@ async fn generate_stream(
|
|||
}
|
||||
};
|
||||
|
||||
Sse::new(stream).keep_alive(KeepAlive::default())
|
||||
(headers, Sse::new(stream).keep_alive(KeepAlive::default()))
|
||||
}
|
||||
|
||||
/// Prometheus metrics scrape endpoint
|
||||
|
|
Loading…
Reference in New Issue