diff --git a/router/src/server.rs b/router/src/server.rs
index bfeee375..9af94951 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -32,25 +32,25 @@ use utoipa_swagger_ui::SwaggerUi;
 
 /// Generate tokens if `stream == false` or a stream of token if `stream == true`
 #[utoipa::path(
-    post,
-    tag = "Text Generation Inference",
-    path = "/",
-    request_body = CompatGenerateRequest,
-    responses(
-        (status = 200, description = "Generated Text",
-            content(
-                ("application/json" = GenerateResponse),
-                ("text/event-stream" = StreamResponse),
-            )),
-        (status = 424, description = "Generation Error", body = ErrorResponse,
-            example = json ! ({"error": "Request failed during generation"})),
-        (status = 429, description = "Model is overloaded", body = ErrorResponse,
-            example = json ! ({"error": "Model is overloaded"})),
-        (status = 422, description = "Input validation error", body = ErrorResponse,
-            example = json ! ({"error": "Input validation error"})),
-        (status = 500, description = "Incomplete generation", body = ErrorResponse,
-            example = json ! ({"error": "Incomplete generation"})),
-    )
+post,
+tag = "Text Generation Inference",
+path = "/",
+request_body = CompatGenerateRequest,
+responses(
+(status = 200, description = "Generated Text",
+content(
+("application/json" = GenerateResponse),
+("text/event-stream" = StreamResponse),
+)),
+(status = 424, description = "Generation Error", body = ErrorResponse,
+example = json ! ({"error": "Request failed during generation"})),
+(status = 429, description = "Model is overloaded", body = ErrorResponse,
+example = json ! ({"error": "Model is overloaded"})),
+(status = 422, description = "Input validation error", body = ErrorResponse,
+example = json ! ({"error": "Input validation error"})),
+(status = 500, description = "Incomplete generation", body = ErrorResponse,
+example = json ! ({"error": "Incomplete generation"})),
+)
 )]
 #[instrument(skip(infer, req))]
 async fn compat_generate(
@@ -79,10 +79,10 @@ async fn compat_generate(
 
 /// Text Generation Inference endpoint info
 #[utoipa::path(
-    get,
-    tag = "Text Generation Inference",
-    path = "/info",
-    responses((status = 200, description = "Served model info", body = Info))
+get,
+tag = "Text Generation Inference",
+path = "/info",
+responses((status = 200, description = "Served model info", body = Info))
 )]
 #[instrument]
 async fn get_model_info(info: Extension<Info>) -> Json<Info> {
@@ -90,14 +90,14 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
 }
 
 #[utoipa::path(
-    get,
-    tag = "Text Generation Inference",
-    path = "/health",
-    responses(
-        (status = 200, description = "Everything is working fine"),
-        (status = 503, description = "Text generation inference is down", body = ErrorResponse,
-            example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
-    )
+get,
+tag = "Text Generation Inference",
+path = "/health",
+responses(
+(status = 200, description = "Everything is working fine"),
+(status = 503, description = "Text generation inference is down", body = ErrorResponse,
+example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
+)
 )]
 #[instrument(skip(health))]
 /// Health check method
@@ -116,33 +116,33 @@ async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
 
 /// Generate tokens
 #[utoipa::path(
-    post,
-    tag = "Text Generation Inference",
-    path = "/generate",
-    request_body = GenerateRequest,
-    responses(
-        (status = 200, description = "Generated Text", body = GenerateResponse),
-        (status = 424, description = "Generation Error", body = ErrorResponse,
-            example = json ! ({"error": "Request failed during generation"})),
-        (status = 429, description = "Model is overloaded", body = ErrorResponse,
-            example = json ! ({"error": "Model is overloaded"})),
-        (status = 422, description = "Input validation error", body = ErrorResponse,
-            example = json ! ({"error": "Input validation error"})),
-        (status = 500, description = "Incomplete generation", body = ErrorResponse,
-            example = json ! ({"error": "Incomplete generation"})),
-    )
+post,
+tag = "Text Generation Inference",
+path = "/generate",
+request_body = GenerateRequest,
+responses(
+(status = 200, description = "Generated Text", body = GenerateResponse),
+(status = 424, description = "Generation Error", body = ErrorResponse,
+example = json ! ({"error": "Request failed during generation"})),
+(status = 429, description = "Model is overloaded", body = ErrorResponse,
+example = json ! ({"error": "Model is overloaded"})),
+(status = 422, description = "Input validation error", body = ErrorResponse,
+example = json ! ({"error": "Input validation error"})),
+(status = 500, description = "Incomplete generation", body = ErrorResponse,
+example = json ! ({"error": "Incomplete generation"})),
+)
 )]
 #[instrument(
-    skip_all,
-    fields(
-        parameters = ?req.0.parameters,
-        total_time,
-        validation_time,
-        queue_time,
-        inference_time,
-        time_per_token,
-        seed,
-    )
+skip_all,
+fields(
+parameters = ?req.0.parameters,
+total_time,
+validation_time,
+queue_time,
+inference_time,
+time_per_token,
+seed,
+)
 )]
 async fn generate(
@@ -297,38 +297,38 @@ async fn generate(
 
 /// Generate a stream of token using Server-Sent Events
 #[utoipa::path(
-    post,
-    tag = "Text Generation Inference",
-    path = "/generate_stream",
-    request_body = GenerateRequest,
-    responses(
-        (status = 200, description = "Generated Text", body = StreamResponse,
-            content_type = "text/event-stream"),
-        (status = 424, description = "Generation Error", body = ErrorResponse,
-            example = json ! ({"error": "Request failed during generation"}),
-            content_type = "text/event-stream"),
-        (status = 429, description = "Model is overloaded", body = ErrorResponse,
-            example = json ! ({"error": "Model is overloaded"}),
-            content_type = "text/event-stream"),
-        (status = 422, description = "Input validation error", body = ErrorResponse,
-            example = json ! ({"error": "Input validation error"}),
-            content_type = "text/event-stream"),
-        (status = 500, description = "Incomplete generation", body = ErrorResponse,
-            example = json ! ({"error": "Incomplete generation"}),
-            content_type = "text/event-stream"),
-    )
+post,
+tag = "Text Generation Inference",
+path = "/generate_stream",
+request_body = GenerateRequest,
+responses(
+(status = 200, description = "Generated Text", body = StreamResponse,
+content_type = "text/event-stream"),
+(status = 424, description = "Generation Error", body = ErrorResponse,
+example = json ! ({"error": "Request failed during generation"}),
+content_type = "text/event-stream"),
+(status = 429, description = "Model is overloaded", body = ErrorResponse,
+example = json ! ({"error": "Model is overloaded"}),
+content_type = "text/event-stream"),
+(status = 422, description = "Input validation error", body = ErrorResponse,
+example = json ! ({"error": "Input validation error"}),
+content_type = "text/event-stream"),
+(status = 500, description = "Incomplete generation", body = ErrorResponse,
+example = json ! ({"error": "Incomplete generation"}),
+content_type = "text/event-stream"),
+)
 )]
 #[instrument(
-    skip_all,
-    fields(
-        parameters = ?req.0.parameters,
-        total_time,
-        validation_time,
-        queue_time,
-        inference_time,
-        time_per_token,
-        seed,
-    )
+skip_all,
+fields(
+parameters = ?req.0.parameters,
+total_time,
+validation_time,
+queue_time,
+inference_time,
+time_per_token,
+seed,
+)
 )]
 async fn generate_stream(
     infer: Extension<Infer>,
@@ -493,10 +493,10 @@ async fn generate_stream(
 
 /// Prometheus metrics scrape endpoint
 #[utoipa::path(
-    get,
-    tag = "Text Generation Inference",
-    path = "/metrics",
-    responses((status = 200, description = "Prometheus Metrics", body = String))
+get,
+tag = "Text Generation Inference",
+path = "/metrics",
+responses((status = 200, description = "Prometheus Metrics", body = String))
 )]
 async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
     prom_handle.render()
@@ -529,41 +529,41 @@ pub async fn run(
     // OpenAPI documentation
     #[derive(OpenApi)]
     #[openapi(
-        paths(
-            health,
-            get_model_info,
-            compat_generate,
-            generate,
-            generate_stream,
-            metrics,
-        ),
-        components(
-            schemas(
-                Info,
-                CompatGenerateRequest,
-                GenerateRequest,
-                GenerateParameters,
-                PrefillToken,
-                Token,
-                GenerateResponse,
-                BestOfSequence,
-                Details,
-                FinishReason,
-                StreamResponse,
-                StreamDetails,
-                ErrorResponse,
-            )
-        ),
-        tags(
-            (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
-        ),
-        info(
-            title = "Text Generation Inference",
-            license(
-                name = "Apache 2.0",
-                url = "https://www.apache.org/licenses/LICENSE-2.0"
-            )
-        )
+    paths(
+    health,
+    get_model_info,
+    compat_generate,
+    generate,
+    generate_stream,
+    metrics,
+    ),
+    components(
+    schemas(
+    Info,
+    CompatGenerateRequest,
+    GenerateRequest,
+    GenerateParameters,
+    PrefillToken,
+    Token,
+    GenerateResponse,
+    BestOfSequence,
+    Details,
+    FinishReason,
+    StreamResponse,
+    StreamDetails,
+    ErrorResponse,
+    )
+    ),
+    tags(
+    (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
+    ),
+    info(
+    title = "Text Generation Inference",
+    license(
+    name = "Apache 2.0",
+    url = "https://www.apache.org/licenses/LICENSE-2.0"
+    )
+    )
     )]
     struct ApiDoc;
 
@@ -683,10 +683,10 @@ pub async fn run(
         // Prometheus metrics route
         .route("/metrics", get(metrics))
         .layer(Extension(info))
-        .layer(Extension(health_ext))
+        .layer(Extension(health_ext.clone()))
         .layer(Extension(compat_return_full_text))
        .layer(Extension(infer))
-        .layer(Extension(prom_handle))
+        .layer(Extension(prom_handle.clone()))
         .layer(opentelemetry_tracing_layer())
         .layer(cors_layer);
 
@@ -712,6 +712,21 @@ pub async fn run(
 
     let listener = tunnel.listen().await.unwrap();
 
+    // Run prom metrics and health locally too
+    tokio::spawn(
+        axum::Server::bind(&addr)
+            .serve(
+                Router::new()
+                    .route("/health", get(health))
+                    .route("/metrics", get(metrics))
+                    .layer(Extension(health_ext))
+                    .layer(Extension(prom_handle))
+                    .into_make_service(),
+            )
+            //Wait until all requests are finished to shut down
+            .with_graceful_shutdown(shutdown_signal()),
+    );
+
     // Run server
     axum::Server::builder(listener)
         .serve(app.into_make_service())
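
Note on the final two hunks: `Extension` layers take ownership of their value, so `health_ext` and `prom_handle` are now cloned into the public app, leaving the originals free to move into the spawned local server. What follows is a minimal, self-contained sketch of that pattern, assuming axum 0.6 (`axum::Server` was removed in axum 0.7) and tokio; the `Health` type, the handler bodies, and the addresses and ports are illustrative placeholders, not the router's real ones.

    use std::net::SocketAddr;

    use axum::{routing::get, Extension, Router};

    // Hypothetical stand-in for the router's health extension; Extension
    // values must be Clone + Send + Sync + 'static to be shared by two apps.
    #[derive(Clone)]
    struct Health;

    async fn health(Extension(_health): Extension<Health>) -> &'static str {
        // The real handler drives an inference health check and can fail.
        "ok"
    }

    async fn metrics() -> &'static str {
        // The real handler renders the PrometheusHandle recorder.
        "# TYPE placeholder counter\n"
    }

    #[tokio::main]
    async fn main() {
        let health_ext = Health;

        // Local-only router: same handlers and shared state, but only the
        // observability routes.
        let local = Router::new()
            .route("/health", get(health))
            .route("/metrics", get(metrics))
            .layer(Extension(health_ext.clone()));

        // Spawn it in the background, mirroring the tokio::spawn in the diff.
        let local_addr = SocketAddr::from(([127, 0, 0, 1], 9000));
        tokio::spawn(axum::Server::bind(&local_addr).serve(local.into_make_service()));

        // Public app; in the router this carries many more routes and layers
        // and is served through the tunnel listener rather than a plain socket.
        let app = Router::new()
            .route("/health", get(health))
            .layer(Extension(health_ext));

        let public_addr = SocketAddr::from(([0, 0, 0, 0], 8080));
        axum::Server::bind(&public_addr)
            .serve(app.into_make_service())
            .await
            .unwrap();
    }

Serving `/health` and `/metrics` on a plain local socket keeps liveness probes and Prometheus scrapes reachable even while all public traffic goes through the tunnel listener, which appears to be the motivation for the spawned server in the last hunk.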