feat(router): add cors allow origin options (#73)
This commit is contained in:
parent
c720555adc
commit
6796d38c6d
|
@ -2275,6 +2275,7 @@ dependencies = [
|
|||
"tokenizers",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
"tracing-opentelemetry",
|
||||
"tracing-subscriber",
|
||||
|
|
|
@ -53,6 +53,8 @@ struct Args {
|
|||
json_output: bool,
|
||||
#[clap(long, env)]
|
||||
otlp_endpoint: Option<String>,
|
||||
#[clap(long, env)]
|
||||
cors_allow_origin: Vec<String>,
|
||||
}
|
||||
|
||||
fn main() -> ExitCode {
|
||||
|
@ -85,6 +87,7 @@ fn main() -> ExitCode {
|
|||
disable_custom_kernels,
|
||||
json_output,
|
||||
otlp_endpoint,
|
||||
cors_allow_origin,
|
||||
} = args;
|
||||
|
||||
// Signal handler
|
||||
|
@ -320,6 +323,12 @@ fn main() -> ExitCode {
|
|||
argv.push(otlp_endpoint);
|
||||
}
|
||||
|
||||
// CORS origins
|
||||
for origin in cors_allow_origin.into_iter() {
|
||||
argv.push("--cors-allow-origin".to_string());
|
||||
argv.push(origin);
|
||||
}
|
||||
|
||||
let mut webserver = match Popen::create(
|
||||
&argv,
|
||||
PopenConfig {
|
||||
|
|
|
@ -32,6 +32,7 @@ thiserror = "1.0.38"
|
|||
tokenizers = "0.13.2"
|
||||
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||
tokio-stream = "0.1.11"
|
||||
tower-http = { version = "0.3.5", features = ["cors"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-opentelemetry = "0.18.0"
|
||||
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
/// Text Generation Inference webserver entrypoint
|
||||
use axum::http::HeaderValue;
|
||||
use clap::Parser;
|
||||
use opentelemetry::sdk::propagation::TraceContextPropagator;
|
||||
use opentelemetry::sdk::trace;
|
||||
|
@ -10,6 +11,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
|||
use text_generation_client::ShardedClient;
|
||||
use text_generation_router::server;
|
||||
use tokenizers::Tokenizer;
|
||||
use tower_http::cors::AllowOrigin;
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use tracing_subscriber::{EnvFilter, Layer};
|
||||
|
@ -42,6 +44,8 @@ struct Args {
|
|||
json_output: bool,
|
||||
#[clap(long, env)]
|
||||
otlp_endpoint: Option<String>,
|
||||
#[clap(long, env)]
|
||||
cors_allow_origin: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), std::io::Error> {
|
||||
|
@ -61,12 +65,24 @@ fn main() -> Result<(), std::io::Error> {
|
|||
validation_workers,
|
||||
json_output,
|
||||
otlp_endpoint,
|
||||
cors_allow_origin,
|
||||
} = args;
|
||||
|
||||
if validation_workers == 0 {
|
||||
panic!("validation_workers must be > 0");
|
||||
}
|
||||
|
||||
// CORS allowed origins
|
||||
// map to go inside the option and then map to parse from String to HeaderValue
|
||||
// Finally, convert to AllowOrigin
|
||||
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
|
||||
AllowOrigin::list(
|
||||
cors_allow_origin
|
||||
.iter()
|
||||
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
|
||||
)
|
||||
});
|
||||
|
||||
// Download and instantiate tokenizer
|
||||
// This will only be used to validate payloads
|
||||
//
|
||||
|
@ -107,6 +123,7 @@ fn main() -> Result<(), std::io::Error> {
|
|||
tokenizer,
|
||||
validation_workers,
|
||||
addr,
|
||||
cors_allow_origin,
|
||||
)
|
||||
.await;
|
||||
Ok(())
|
||||
|
|
|
@ -5,11 +5,11 @@ use crate::{
|
|||
Infer, StreamDetails, StreamResponse, Token, Validation,
|
||||
};
|
||||
use axum::extract::Extension;
|
||||
use axum::http::{HeaderMap, StatusCode};
|
||||
use axum::http::{HeaderMap, Method, StatusCode};
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
use axum::{http, Json, Router};
|
||||
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
|
||||
use futures::Stream;
|
||||
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
|
||||
|
@ -20,6 +20,7 @@ use tokenizers::Tokenizer;
|
|||
use tokio::signal;
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::StreamExt;
|
||||
use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||
use tracing::{info_span, instrument, Instrument};
|
||||
use utoipa::OpenApi;
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
@ -334,6 +335,7 @@ pub async fn run(
|
|||
tokenizer: Tokenizer,
|
||||
validation_workers: usize,
|
||||
addr: SocketAddr,
|
||||
allow_origin: Option<AllowOrigin>,
|
||||
) {
|
||||
// OpenAPI documentation
|
||||
#[derive(OpenApi)]
|
||||
|
@ -391,6 +393,13 @@ pub async fn run(
|
|||
.install_recorder()
|
||||
.expect("failed to install metrics recorder");
|
||||
|
||||
// CORS layer
|
||||
let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
|
||||
let cors_layer = CorsLayer::new()
|
||||
.allow_methods([Method::GET, Method::POST])
|
||||
.allow_headers([http::header::CONTENT_TYPE])
|
||||
.allow_origin(allow_origin);
|
||||
|
||||
// Create router
|
||||
let app = Router::new()
|
||||
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
|
||||
|
@ -402,7 +411,8 @@ pub async fn run(
|
|||
.layer(Extension(infer))
|
||||
.route("/metrics", get(metrics))
|
||||
.layer(Extension(prom_handle))
|
||||
.layer(opentelemetry_tracing_layer());
|
||||
.layer(opentelemetry_tracing_layer())
|
||||
.layer(cors_layer);
|
||||
|
||||
// Run server
|
||||
axum::Server::bind(&addr)
|
||||
|
|
Loading…
Reference in New Issue