feat(router): add cors allow origin options (#73)
This commit is contained in:
parent
c720555adc
commit
6796d38c6d
|
@ -2275,6 +2275,7 @@ dependencies = [
|
||||||
"tokenizers",
|
"tokenizers",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
|
"tower-http",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-opentelemetry",
|
"tracing-opentelemetry",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
|
|
|
@ -53,6 +53,8 @@ struct Args {
|
||||||
json_output: bool,
|
json_output: bool,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
otlp_endpoint: Option<String>,
|
otlp_endpoint: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
cors_allow_origin: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> ExitCode {
|
fn main() -> ExitCode {
|
||||||
|
@ -85,6 +87,7 @@ fn main() -> ExitCode {
|
||||||
disable_custom_kernels,
|
disable_custom_kernels,
|
||||||
json_output,
|
json_output,
|
||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
|
cors_allow_origin,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
// Signal handler
|
// Signal handler
|
||||||
|
@ -320,6 +323,12 @@ fn main() -> ExitCode {
|
||||||
argv.push(otlp_endpoint);
|
argv.push(otlp_endpoint);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CORS origins
|
||||||
|
for origin in cors_allow_origin.into_iter() {
|
||||||
|
argv.push("--cors-allow-origin".to_string());
|
||||||
|
argv.push(origin);
|
||||||
|
}
|
||||||
|
|
||||||
let mut webserver = match Popen::create(
|
let mut webserver = match Popen::create(
|
||||||
&argv,
|
&argv,
|
||||||
PopenConfig {
|
PopenConfig {
|
||||||
|
|
|
@ -32,6 +32,7 @@ thiserror = "1.0.38"
|
||||||
tokenizers = "0.13.2"
|
tokenizers = "0.13.2"
|
||||||
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||||
tokio-stream = "0.1.11"
|
tokio-stream = "0.1.11"
|
||||||
|
tower-http = { version = "0.3.5", features = ["cors"] }
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-opentelemetry = "0.18.0"
|
tracing-opentelemetry = "0.18.0"
|
||||||
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
|
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
/// Text Generation Inference webserver entrypoint
|
/// Text Generation Inference webserver entrypoint
|
||||||
|
use axum::http::HeaderValue;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||||
use opentelemetry::sdk::trace;
|
use opentelemetry::sdk::trace;
|
||||||
|
@ -10,6 +11,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
use text_generation_client::ShardedClient;
|
use text_generation_client::ShardedClient;
|
||||||
use text_generation_router::server;
|
use text_generation_router::server;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
use tower_http::cors::AllowOrigin;
|
||||||
use tracing_subscriber::layer::SubscriberExt;
|
use tracing_subscriber::layer::SubscriberExt;
|
||||||
use tracing_subscriber::util::SubscriberInitExt;
|
use tracing_subscriber::util::SubscriberInitExt;
|
||||||
use tracing_subscriber::{EnvFilter, Layer};
|
use tracing_subscriber::{EnvFilter, Layer};
|
||||||
|
@ -42,6 +44,8 @@ struct Args {
|
||||||
json_output: bool,
|
json_output: bool,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
otlp_endpoint: Option<String>,
|
otlp_endpoint: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
cors_allow_origin: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<(), std::io::Error> {
|
fn main() -> Result<(), std::io::Error> {
|
||||||
|
@ -61,12 +65,24 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
validation_workers,
|
validation_workers,
|
||||||
json_output,
|
json_output,
|
||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
|
cors_allow_origin,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
if validation_workers == 0 {
|
if validation_workers == 0 {
|
||||||
panic!("validation_workers must be > 0");
|
panic!("validation_workers must be > 0");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CORS allowed origins
|
||||||
|
// Map over the Option, then parse each origin String into a HeaderValue (panics on an invalid value)
|
||||||
|
// Finally, convert to AllowOrigin
|
||||||
|
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
|
||||||
|
AllowOrigin::list(
|
||||||
|
cors_allow_origin
|
||||||
|
.iter()
|
||||||
|
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
// Download and instantiate tokenizer
|
// Download and instantiate tokenizer
|
||||||
// This will only be used to validate payloads
|
// This will only be used to validate payloads
|
||||||
//
|
//
|
||||||
|
@ -107,6 +123,7 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
tokenizer,
|
tokenizer,
|
||||||
validation_workers,
|
validation_workers,
|
||||||
addr,
|
addr,
|
||||||
|
cors_allow_origin,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -5,11 +5,11 @@ use crate::{
|
||||||
Infer, StreamDetails, StreamResponse, Token, Validation,
|
Infer, StreamDetails, StreamResponse, Token, Validation,
|
||||||
};
|
};
|
||||||
use axum::extract::Extension;
|
use axum::extract::Extension;
|
||||||
use axum::http::{HeaderMap, StatusCode};
|
use axum::http::{HeaderMap, Method, StatusCode};
|
||||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
use axum::response::IntoResponse;
|
use axum::response::IntoResponse;
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{get, post};
|
||||||
use axum::{Json, Router};
|
use axum::{http, Json, Router};
|
||||||
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
|
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
|
||||||
|
@ -20,6 +20,7 @@ use tokenizers::Tokenizer;
|
||||||
use tokio::signal;
|
use tokio::signal;
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
use tokio_stream::StreamExt;
|
use tokio_stream::StreamExt;
|
||||||
|
use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||||
use tracing::{info_span, instrument, Instrument};
|
use tracing::{info_span, instrument, Instrument};
|
||||||
use utoipa::OpenApi;
|
use utoipa::OpenApi;
|
||||||
use utoipa_swagger_ui::SwaggerUi;
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
@ -334,6 +335,7 @@ pub async fn run(
|
||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
addr: SocketAddr,
|
addr: SocketAddr,
|
||||||
|
allow_origin: Option<AllowOrigin>,
|
||||||
) {
|
) {
|
||||||
// OpenAPI documentation
|
// OpenAPI documentation
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
|
@ -391,6 +393,13 @@ pub async fn run(
|
||||||
.install_recorder()
|
.install_recorder()
|
||||||
.expect("failed to install metrics recorder");
|
.expect("failed to install metrics recorder");
|
||||||
|
|
||||||
|
// CORS layer
|
||||||
|
let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
|
||||||
|
let cors_layer = CorsLayer::new()
|
||||||
|
.allow_methods([Method::GET, Method::POST])
|
||||||
|
.allow_headers([http::header::CONTENT_TYPE])
|
||||||
|
.allow_origin(allow_origin);
|
||||||
|
|
||||||
// Create router
|
// Create router
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||||
|
@ -402,7 +411,8 @@ pub async fn run(
|
||||||
.layer(Extension(infer))
|
.layer(Extension(infer))
|
||||||
.route("/metrics", get(metrics))
|
.route("/metrics", get(metrics))
|
||||||
.layer(Extension(prom_handle))
|
.layer(Extension(prom_handle))
|
||||||
.layer(opentelemetry_tracing_layer());
|
.layer(opentelemetry_tracing_layer())
|
||||||
|
.layer(cors_layer);
|
||||||
|
|
||||||
// Run server
|
// Run server
|
||||||
axum::Server::bind(&addr)
|
axum::Server::bind(&addr)
|
||||||
|
|
Loading…
Reference in New Issue