Run ci api key (#2315)
* Add API_Key for auth and conditionally add authorisation for non-info/health endpoints.
* Change name to info routes.
* Fix comment.
* Convert strings to lowercase for case-insensitive comparison.
* Convert header to string.
* Fixes and update docs.
* Update docs again.
* Revert wrong update.
---------
Co-authored-by: Kevin Duffy <kevin.duffy94@gmail.com>
This commit is contained in:
parent
fd2e06316d
commit
583d37a2f8
|
@ -349,6 +349,12 @@ Options:
|
|||
--cors-allow-origin <CORS_ALLOW_ORIGIN>
|
||||
[env: CORS_ALLOW_ORIGIN=]
|
||||
|
||||
```
|
||||
## API_KEY
|
||||
```shell
|
||||
--api-key <API_KEY>
|
||||
[env: API_KEY=]
|
||||
|
||||
```
|
||||
## WATERMARK_GAMMA
|
||||
```shell
|
||||
|
|
|
@ -422,6 +422,10 @@ struct Args {
|
|||
|
||||
#[clap(long, env)]
|
||||
cors_allow_origin: Vec<String>,
|
||||
|
||||
#[clap(long, env)]
|
||||
api_key: Option<String>,
|
||||
|
||||
#[clap(long, env)]
|
||||
watermark_gamma: Option<f32>,
|
||||
#[clap(long, env)]
|
||||
|
@ -1271,6 +1275,11 @@ fn spawn_webserver(
|
|||
router_args.push(origin);
|
||||
}
|
||||
|
||||
// API Key
|
||||
if let Some(api_key) = args.api_key {
|
||||
router_args.push("--api-key".to_string());
|
||||
router_args.push(api_key);
|
||||
}
|
||||
// Ngrok
|
||||
if args.ngrok {
|
||||
router_args.push("--ngrok".to_string());
|
||||
|
|
|
@ -77,6 +77,8 @@ struct Args {
|
|||
#[clap(long, env)]
|
||||
cors_allow_origin: Option<Vec<String>>,
|
||||
#[clap(long, env)]
|
||||
api_key: Option<String>,
|
||||
#[clap(long, env)]
|
||||
ngrok: bool,
|
||||
#[clap(long, env)]
|
||||
ngrok_authtoken: Option<String>,
|
||||
|
@ -127,6 +129,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
otlp_endpoint,
|
||||
otlp_service_name,
|
||||
cors_allow_origin,
|
||||
api_key,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
|
@ -446,6 +449,7 @@ async fn main() -> Result<(), RouterError> {
|
|||
validation_workers,
|
||||
addr,
|
||||
cors_allow_origin,
|
||||
api_key,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_edge,
|
||||
|
|
|
@ -37,6 +37,7 @@ use futures::stream::StreamExt;
|
|||
use futures::stream::{FuturesOrdered, FuturesUnordered};
|
||||
use futures::Stream;
|
||||
use futures::TryStreamExt;
|
||||
use http::header::AUTHORIZATION;
|
||||
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
|
||||
use serde_json::Value;
|
||||
use std::convert::Infallible;
|
||||
|
@ -1417,6 +1418,7 @@ pub async fn run(
|
|||
validation_workers: usize,
|
||||
addr: SocketAddr,
|
||||
allow_origin: Option<AllowOrigin>,
|
||||
api_key: Option<String>,
|
||||
ngrok: bool,
|
||||
_ngrok_authtoken: Option<String>,
|
||||
_ngrok_edge: Option<String>,
|
||||
|
@ -1810,16 +1812,42 @@ pub async fn run(
|
|||
let swagger_ui = SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc);
|
||||
|
||||
// Define base and health routes
|
||||
let base_routes = Router::new()
|
||||
let mut base_routes = Router::new()
|
||||
.route("/", post(compat_generate))
|
||||
.route("/", get(health))
|
||||
.route("/info", get(get_model_info))
|
||||
.route("/generate", post(generate))
|
||||
.route("/generate_stream", post(generate_stream))
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
.route("/v1/completions", post(completions))
|
||||
.route("/vertex", post(vertex_compatibility))
|
||||
.route("/tokenize", post(tokenize))
|
||||
.route("/tokenize", post(tokenize));
|
||||
|
||||
if let Some(api_key) = api_key {
|
||||
let mut prefix = "Bearer ".to_string();
|
||||
prefix.push_str(&api_key);
|
||||
|
||||
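// Leak the "Bearer <key>" string to obtain a `&'static str`: it is `Copy`, so the
// `move` closure below can be invoked for every request without cloning the key.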
|
||||
let api_key: &'static str = prefix.leak();
|
||||
|
||||
let auth = move |headers: HeaderMap,
|
||||
request: axum::extract::Request,
|
||||
next: axum::middleware::Next| async move {
|
||||
match headers.get(AUTHORIZATION) {
|
||||
Some(token) => match token.to_str() {
|
||||
Ok(token_str) if token_str.to_lowercase() == api_key.to_lowercase() => {
|
||||
let response = next.run(request).await;
|
||||
Ok(response)
|
||||
}
|
||||
_ => Err(StatusCode::UNAUTHORIZED),
|
||||
},
|
||||
None => Err(StatusCode::UNAUTHORIZED),
|
||||
}
|
||||
};
|
||||
|
||||
base_routes = base_routes.layer(axum::middleware::from_fn(auth))
|
||||
}
|
||||
let info_routes = Router::new()
|
||||
.route("/", get(health))
|
||||
.route("/info", get(get_model_info))
|
||||
.route("/health", get(health))
|
||||
.route("/ping", get(health))
|
||||
.route("/metrics", get(metrics));
|
||||
|
@ -1838,6 +1866,7 @@ pub async fn run(
|
|||
let mut app = Router::new()
|
||||
.merge(swagger_ui)
|
||||
.merge(base_routes)
|
||||
.merge(info_routes)
|
||||
.merge(aws_sagemaker_route);
|
||||
|
||||
#[cfg(feature = "google")]
|
||||
|
|
Loading…
Reference in New Issue