feat(router): add cors allow origin options (#73)

This commit is contained in:
OlivierDehaene 2023-02-17 18:22:00 +01:00 committed by GitHub
parent c720555adc
commit 6796d38c6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 41 additions and 3 deletions

1
Cargo.lock generated
View File

@ -2275,6 +2275,7 @@ dependencies = [
"tokenizers", "tokenizers",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tower-http",
"tracing", "tracing",
"tracing-opentelemetry", "tracing-opentelemetry",
"tracing-subscriber", "tracing-subscriber",

View File

@ -53,6 +53,8 @@ struct Args {
json_output: bool, json_output: bool,
#[clap(long, env)] #[clap(long, env)]
otlp_endpoint: Option<String>, otlp_endpoint: Option<String>,
#[clap(long, env)]
cors_allow_origin: Vec<String>,
} }
fn main() -> ExitCode { fn main() -> ExitCode {
@ -85,6 +87,7 @@ fn main() -> ExitCode {
disable_custom_kernels, disable_custom_kernels,
json_output, json_output,
otlp_endpoint, otlp_endpoint,
cors_allow_origin,
} = args; } = args;
// Signal handler // Signal handler
@ -320,6 +323,12 @@ fn main() -> ExitCode {
argv.push(otlp_endpoint); argv.push(otlp_endpoint);
} }
// CORS origins
for origin in cors_allow_origin.into_iter() {
argv.push("--cors-allow-origin".to_string());
argv.push(origin);
}
let mut webserver = match Popen::create( let mut webserver = match Popen::create(
&argv, &argv,
PopenConfig { PopenConfig {

View File

@ -32,6 +32,7 @@ thiserror = "1.0.38"
tokenizers = "0.13.2" tokenizers = "0.13.2"
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.11" tokio-stream = "0.1.11"
tower-http = { version = "0.3.5", features = ["cors"] }
tracing = "0.1.37" tracing = "0.1.37"
tracing-opentelemetry = "0.18.0" tracing-opentelemetry = "0.18.0"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }

View File

@ -1,4 +1,5 @@
/// Text Generation Inference webserver entrypoint /// Text Generation Inference webserver entrypoint
use axum::http::HeaderValue;
use clap::Parser; use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace; use opentelemetry::sdk::trace;
@ -10,6 +11,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use text_generation_client::ShardedClient; use text_generation_client::ShardedClient;
use text_generation_router::server; use text_generation_router::server;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer}; use tracing_subscriber::{EnvFilter, Layer};
@ -42,6 +44,8 @@ struct Args {
json_output: bool, json_output: bool,
#[clap(long, env)] #[clap(long, env)]
otlp_endpoint: Option<String>, otlp_endpoint: Option<String>,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
} }
fn main() -> Result<(), std::io::Error> { fn main() -> Result<(), std::io::Error> {
@ -61,12 +65,24 @@ fn main() -> Result<(), std::io::Error> {
validation_workers, validation_workers,
json_output, json_output,
otlp_endpoint, otlp_endpoint,
cors_allow_origin,
} = args; } = args;
if validation_workers == 0 { if validation_workers == 0 {
panic!("validation_workers must be > 0"); panic!("validation_workers must be > 0");
} }
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowOrigin
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
AllowOrigin::list(
cors_allow_origin
.iter()
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
)
});
// Download and instantiate tokenizer // Download and instantiate tokenizer
// This will only be used to validate payloads // This will only be used to validate payloads
// //
@ -107,6 +123,7 @@ fn main() -> Result<(), std::io::Error> {
tokenizer, tokenizer,
validation_workers, validation_workers,
addr, addr,
cors_allow_origin,
) )
.await; .await;
Ok(()) Ok(())

View File

@ -5,11 +5,11 @@ use crate::{
Infer, StreamDetails, StreamResponse, Token, Validation, Infer, StreamDetails, StreamResponse, Token, Validation,
}; };
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{Json, Router}; use axum::{http, Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer; use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::Stream; use futures::Stream;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle}; use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
@ -20,6 +20,7 @@ use tokenizers::Tokenizer;
use tokio::signal; use tokio::signal;
use tokio::time::Instant; use tokio::time::Instant;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument}; use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi; use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi; use utoipa_swagger_ui::SwaggerUi;
@ -334,6 +335,7 @@ pub async fn run(
tokenizer: Tokenizer, tokenizer: Tokenizer,
validation_workers: usize, validation_workers: usize,
addr: SocketAddr, addr: SocketAddr,
allow_origin: Option<AllowOrigin>,
) { ) {
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
@ -391,6 +393,13 @@ pub async fn run(
.install_recorder() .install_recorder()
.expect("failed to install metrics recorder"); .expect("failed to install metrics recorder");
// CORS layer
let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
let cors_layer = CorsLayer::new()
.allow_methods([Method::GET, Method::POST])
.allow_headers([http::header::CONTENT_TYPE])
.allow_origin(allow_origin);
// Create router // Create router
let app = Router::new() let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi())) .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
@ -402,7 +411,8 @@ pub async fn run(
.layer(Extension(infer)) .layer(Extension(infer))
.route("/metrics", get(metrics)) .route("/metrics", get(metrics))
.layer(Extension(prom_handle)) .layer(Extension(prom_handle))
.layer(opentelemetry_tracing_layer()); .layer(opentelemetry_tracing_layer())
.layer(cors_layer);
// Run server // Run server
axum::Server::bind(&addr) axum::Server::bind(&addr)