feat(server): add local prom and health routes if running w/ ngrok

OlivierDehaene 2023-07-21 16:56:30 +02:00
parent 15b3e9ffb0
commit 1da642bd0e
1 changed file with 142 additions and 127 deletions
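In short: when the router runs behind an ngrok tunnel, the main axum server is bound to the tunnel listener rather than a local socket, which breaks local Prometheus scrapes of /metrics and health probes of /health. The fix spawns a second, locally bound server that serves only those two routes, reusing the existing health and Prometheus extensions (hence the new .clone() calls where the main router is built) and shutting down gracefully with the rest of the process.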


@@ -32,25 +32,25 @@ use utoipa_swagger_ui::SwaggerUi;

/// Generate tokens if `stream == false` or a stream of token if `stream == true`
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/",
    request_body = CompatGenerateRequest,
    responses(
        (status = 200, description = "Generated Text",
            content(
                ("application/json" = GenerateResponse),
                ("text/event-stream" = StreamResponse),
            )),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json ! ({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json ! ({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json ! ({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json ! ({"error": "Incomplete generation"})),
    )
)]
#[instrument(skip(infer, req))]
async fn compat_generate(
@@ -79,10 +79,10 @@ async fn compat_generate(

/// Text Generation Inference endpoint info
#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/info",
    responses((status = 200, description = "Served model info", body = Info))
)]
#[instrument]
async fn get_model_info(info: Extension<Info>) -> Json<Info> {
@@ -90,14 +90,14 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
}

#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/health",
    responses(
        (status = 200, description = "Everything is working fine"),
        (status = 503, description = "Text generation inference is down", body = ErrorResponse,
            example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
    )
)]
#[instrument(skip(health))]
/// Health check method
@@ -116,33 +116,33 @@ async fn health(mut health: Extension<Health>) -> Result<(), (StatusCode, Json<E

/// Generate tokens
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/generate",
    request_body = GenerateRequest,
    responses(
        (status = 200, description = "Generated Text", body = GenerateResponse),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json ! ({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json ! ({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json ! ({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json ! ({"error": "Incomplete generation"})),
    )
)]
#[instrument(
    skip_all,
    fields(
-        parameters = ?req.0.parameters,
+        parameters = ? req.0.parameters,
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn generate(
    infer: Extension<Infer>,
@@ -297,38 +297,38 @@ async fn generate(

/// Generate a stream of token using Server-Sent Events
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/generate_stream",
    request_body = GenerateRequest,
    responses(
        (status = 200, description = "Generated Text", body = StreamResponse,
            content_type = "text/event-stream"),
        (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json ! ({"error": "Request failed during generation"}),
            content_type = "text/event-stream"),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json ! ({"error": "Model is overloaded"}),
            content_type = "text/event-stream"),
        (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json ! ({"error": "Input validation error"}),
            content_type = "text/event-stream"),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json ! ({"error": "Incomplete generation"}),
            content_type = "text/event-stream"),
    )
)]
#[instrument(
    skip_all,
    fields(
-        parameters = ?req.0.parameters,
+        parameters = ? req.0.parameters,
        total_time,
        validation_time,
        queue_time,
        inference_time,
        time_per_token,
        seed,
    )
)]
async fn generate_stream(
    infer: Extension<Infer>,
@@ -493,10 +493,10 @@ async fn generate_stream(

/// Prometheus metrics scrape endpoint
#[utoipa::path(
    get,
    tag = "Text Generation Inference",
    path = "/metrics",
    responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
    prom_handle.render()
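For context, PrometheusHandle comes from the metrics-exporter-prometheus crate, and render() returns the accumulated metrics in the Prometheus text exposition format. A minimal way to obtain such a handle, sketched here as typical usage of that crate rather than code from this diff, is:

use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};

// Installs the global metrics recorder and returns the handle whose
// `render()` produces the text exposition format served at /metrics.
fn build_prom_handle() -> PrometheusHandle {
    PrometheusBuilder::new()
        .install_recorder()
        .expect("failed to install Prometheus recorder")
}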
@@ -529,41 +529,41 @@ pub async fn run(

    // OpenAPI documentation
    #[derive(OpenApi)]
    #[openapi(
        paths(
            health,
            get_model_info,
            compat_generate,
            generate,
            generate_stream,
            metrics,
        ),
        components(
            schemas(
                Info,
                CompatGenerateRequest,
                GenerateRequest,
                GenerateParameters,
                PrefillToken,
                Token,
                GenerateResponse,
                BestOfSequence,
                Details,
                FinishReason,
                StreamResponse,
                StreamDetails,
                ErrorResponse,
            )
        ),
        tags(
            (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
        ),
        info(
            title = "Text Generation Inference",
            license(
                name = "Apache 2.0",
                url = "https://www.apache.org/licenses/LICENSE-2.0"
            )
        )
    )]
    struct ApiDoc;
@@ -683,10 +683,10 @@ pub async fn run(
        // Prometheus metrics route
        .route("/metrics", get(metrics))
        .layer(Extension(info))
-       .layer(Extension(health_ext))
+       .layer(Extension(health_ext.clone()))
        .layer(Extension(compat_return_full_text))
        .layer(Extension(infer))
-       .layer(Extension(prom_handle))
+       .layer(Extension(prom_handle.clone()))
        .layer(opentelemetry_tracing_layer())
        .layer(cors_layer);
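The two new .clone() calls are what allow the ngrok branch below to reuse health_ext and prom_handle: layering an Extension moves the value into the router, so the originals must survive to be moved into the second, local-only router. A minimal self-contained sketch of the pattern (illustrative only, not code from this repository; Shared and build_routers are made-up names):

use axum::{routing::get, Extension, Router};

#[derive(Clone)]
struct Shared(&'static str);

fn build_routers(shared: Shared) -> (Router, Router) {
    // The main router takes a clone so `shared` can still be moved below.
    let main = Router::new()
        .route("/", get(|Extension(s): Extension<Shared>| async move { s.0 }))
        .layer(Extension(shared.clone()));
    // The local-only router consumes the original value.
    let local = Router::new()
        .route("/health", get(|Extension(s): Extension<Shared>| async move { s.0 }))
        .layer(Extension(shared));
    (main, local)
}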
@@ -712,6 +712,21 @@ pub async fn run(
        let listener = tunnel.listen().await.unwrap();

+       // Run prom metrics and health locally too
+       tokio::spawn(
+           axum::Server::bind(&addr)
+               .serve(
+                   Router::new()
+                       .route("/health", get(health))
+                       .route("/metrics", get(metrics))
+                       .layer(Extension(health_ext))
+                       .layer(Extension(prom_handle))
+                       .into_make_service(),
+               )
+               //Wait until all requests are finished to shut down
+               .with_graceful_shutdown(shutdown_signal()),
+           );
+
        // Run server
        axum::Server::builder(listener)
            .serve(app.into_make_service())
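Both the local server and the main server drain in-flight requests through the same shutdown_signal() future. That function is defined elsewhere in this file and is not part of the diff; a typical tokio implementation, given here as an assumption rather than the commit's actual code, would be:

use tokio::signal;

/// Resolves on Ctrl+C or, on Unix, SIGTERM; handed to
/// `with_graceful_shutdown` so the server stops accepting new
/// connections and finishes outstanding requests before exiting.
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {},
        _ = terminate => {},
    }
}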