feat(router): add api-inference headers (#91)
This commit is contained in:
parent
4e685d907e
commit
f874c47831
|
@ -8,7 +8,7 @@ use crate::{
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, Method, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
use axum::response::IntoResponse;
|
use axum::response::{IntoResponse, Response};
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{http, Json, Router};
|
use axum::{http, Json, Router};
|
||||||
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||||
|
@ -32,7 +32,7 @@ async fn compat_generate(
|
||||||
default_return_full_text: Extension<bool>,
|
default_return_full_text: Extension<bool>,
|
||||||
infer: Extension<Infer>,
|
infer: Extension<Infer>,
|
||||||
req: Json<CompatGenerateRequest>,
|
req: Json<CompatGenerateRequest>,
|
||||||
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
|
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||||
let mut req = req.0;
|
let mut req = req.0;
|
||||||
|
|
||||||
// default return_full_text given the pipeline_tag
|
// default return_full_text given the pipeline_tag
|
||||||
|
@ -116,6 +116,7 @@ async fn generate(
|
||||||
let span = tracing::Span::current();
|
let span = tracing::Span::current();
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
let compute_characters = req.0.inputs.chars().count();
|
||||||
let mut add_prompt = None;
|
let mut add_prompt = None;
|
||||||
if req.0.parameters.return_full_text.unwrap_or(false) {
|
if req.0.parameters.return_full_text.unwrap_or(false) {
|
||||||
add_prompt = Some(req.0.inputs.clone());
|
add_prompt = Some(req.0.inputs.clone());
|
||||||
|
@ -147,6 +148,15 @@ async fn generate(
|
||||||
|
|
||||||
// Headers
|
// Headers
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
|
||||||
|
headers.insert(
|
||||||
|
"x-compute-time",
|
||||||
|
total_time.as_millis().to_string().parse().unwrap(),
|
||||||
|
);
|
||||||
|
headers.insert(
|
||||||
|
"x-compute-characters",
|
||||||
|
compute_characters.to_string().parse().unwrap(),
|
||||||
|
);
|
||||||
headers.insert(
|
headers.insert(
|
||||||
"x-total-time",
|
"x-total-time",
|
||||||
total_time.as_millis().to_string().parse().unwrap(),
|
total_time.as_millis().to_string().parse().unwrap(),
|
||||||
|
@ -239,10 +249,22 @@ async fn generate(
|
||||||
async fn generate_stream(
|
async fn generate_stream(
|
||||||
infer: Extension<Infer>,
|
infer: Extension<Infer>,
|
||||||
req: Json<GenerateRequest>,
|
req: Json<GenerateRequest>,
|
||||||
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
|
) -> (
|
||||||
|
HeaderMap,
|
||||||
|
Sse<impl Stream<Item = Result<Event, Infallible>>>,
|
||||||
|
) {
|
||||||
let span = tracing::Span::current();
|
let span = tracing::Span::current();
|
||||||
let start_time = Instant::now();
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
let compute_characters = req.0.inputs.chars().count();
|
||||||
|
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
|
||||||
|
headers.insert(
|
||||||
|
"x-compute-characters",
|
||||||
|
compute_characters.to_string().parse().unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
let stream = async_stream::stream! {
|
let stream = async_stream::stream! {
|
||||||
// Inference
|
// Inference
|
||||||
let mut end_reached = false;
|
let mut end_reached = false;
|
||||||
|
@ -360,7 +382,7 @@ async fn generate_stream(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Sse::new(stream).keep_alive(KeepAlive::default())
|
(headers, Sse::new(stream).keep_alive(KeepAlive::default()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Prometheus metrics scrape endpoint
|
/// Prometheus metrics scrape endpoint
|
||||||
|
|
Loading…
Reference in New Issue