diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 77f88490..ce98876f 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -349,6 +349,12 @@ Options: --cors-allow-origin [env: CORS_ALLOW_ORIGIN=] +``` +## API_KEY +```shell + --api-key + [env: API_KEY=] + ``` ## WATERMARK_GAMMA ```shell diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 228b0e79..0e1405d4 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -422,6 +422,10 @@ struct Args { #[clap(long, env)] cors_allow_origin: Vec, + + #[clap(long, env)] + api_key: Option, + #[clap(long, env)] watermark_gamma: Option, #[clap(long, env)] @@ -1271,6 +1275,11 @@ fn spawn_webserver( router_args.push(origin); } + // API Key + if let Some(api_key) = args.api_key { + router_args.push("--api-key".to_string()); + router_args.push(api_key); + } // Ngrok if args.ngrok { router_args.push("--ngrok".to_string()); diff --git a/router/src/main.rs b/router/src/main.rs index bfc77913..36879aa4 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -77,6 +77,8 @@ struct Args { #[clap(long, env)] cors_allow_origin: Option>, #[clap(long, env)] + api_key: Option, + #[clap(long, env)] ngrok: bool, #[clap(long, env)] ngrok_authtoken: Option, @@ -127,6 +129,7 @@ async fn main() -> Result<(), RouterError> { otlp_endpoint, otlp_service_name, cors_allow_origin, + api_key, ngrok, ngrok_authtoken, ngrok_edge, @@ -446,6 +449,7 @@ async fn main() -> Result<(), RouterError> { validation_workers, addr, cors_allow_origin, + api_key, ngrok, ngrok_authtoken, ngrok_edge, diff --git a/router/src/server.rs b/router/src/server.rs index c56c39a3..0fd5aade 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -37,6 +37,7 @@ use futures::stream::StreamExt; use futures::stream::{FuturesOrdered, FuturesUnordered}; use futures::Stream; use futures::TryStreamExt; +use http::header::AUTHORIZATION; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use serde_json::Value; use std::convert::Infallible; @@ -1417,6 +1418,7 @@ pub async fn run( validation_workers: usize, addr: SocketAddr, allow_origin: Option, + api_key: Option, ngrok: bool, _ngrok_authtoken: Option, _ngrok_edge: Option, @@ -1810,16 +1812,42 @@ pub async fn run( let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc); // Define base and health routes - let base_routes = Router::new() + let mut base_routes = Router::new() .route("/", post(compat_generate)) - .route("/", get(health)) - .route("/info", get(get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/chat/completions", post(chat_completions)) .route("/v1/completions", post(completions)) .route("/vertex", post(vertex_compatibility)) - .route("/tokenize", post(tokenize)) + .route("/tokenize", post(tokenize)); + + if let Some(api_key) = api_key { + let mut prefix = "Bearer ".to_string(); + prefix.push_str(&api_key); + + // Leak to allow FnMut + let api_key: &'static str = prefix.leak(); + + let auth = move |headers: HeaderMap, + request: axum::extract::Request, + next: axum::middleware::Next| async move { + match headers.get(AUTHORIZATION) { + Some(token) => match token.to_str() { + Ok(token_str) if token_str.to_lowercase() == api_key.to_lowercase() => { + let response = next.run(request).await; + Ok(response) + } + _ => Err(StatusCode::UNAUTHORIZED), + }, + None => Err(StatusCode::UNAUTHORIZED), + } + }; + + base_routes = base_routes.layer(axum::middleware::from_fn(auth)) + } + let info_routes = Router::new() + .route("/", get(health)) + .route("/info", get(get_model_info)) .route("/health", get(health)) .route("/ping", get(health)) .route("/metrics", get(metrics)); @@ -1838,6 +1866,7 @@ pub async fn run( let mut app = Router::new() .merge(swagger_ui) .merge(base_routes) + .merge(info_routes) .merge(aws_sagemaker_route); #[cfg(feature = "google")]