feat(router): add cors allow origin options (#73)

This commit is contained in:
OlivierDehaene 2023-02-17 18:22:00 +01:00 committed by GitHub
parent c720555adc
commit 6796d38c6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 41 additions and 3 deletions

1
Cargo.lock generated
View File

@ -2275,6 +2275,7 @@ dependencies = [
"tokenizers",
"tokio",
"tokio-stream",
"tower-http",
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",

View File

@ -53,6 +53,8 @@ struct Args {
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(long, env)]
cors_allow_origin: Vec<String>,
}
fn main() -> ExitCode {
@ -85,6 +87,7 @@ fn main() -> ExitCode {
disable_custom_kernels,
json_output,
otlp_endpoint,
cors_allow_origin,
} = args;
// Signal handler
@ -320,6 +323,12 @@ fn main() -> ExitCode {
argv.push(otlp_endpoint);
}
// CORS origins
for origin in cors_allow_origin.into_iter() {
argv.push("--cors-allow-origin".to_string());
argv.push(origin);
}
let mut webserver = match Popen::create(
&argv,
PopenConfig {

View File

@ -32,6 +32,7 @@ thiserror = "1.0.38"
tokenizers = "0.13.2"
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tokio-stream = "0.1.11"
tower-http = { version = "0.3.5", features = ["cors"] }
tracing = "0.1.37"
tracing-opentelemetry = "0.18.0"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }

View File

@ -1,4 +1,5 @@
/// Text Generation Inference webserver entrypoint
use axum::http::HeaderValue;
use clap::Parser;
use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::trace;
@ -10,6 +11,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use text_generation_client::ShardedClient;
use text_generation_router::server;
use tokenizers::Tokenizer;
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
@ -42,6 +44,8 @@ struct Args {
json_output: bool,
#[clap(long, env)]
otlp_endpoint: Option<String>,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
}
fn main() -> Result<(), std::io::Error> {
@ -61,12 +65,24 @@ fn main() -> Result<(), std::io::Error> {
validation_workers,
json_output,
otlp_endpoint,
cors_allow_origin,
} = args;
if validation_workers == 0 {
panic!("validation_workers must be > 0");
}
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowOrigin
let cors_allow_origin: Option<AllowOrigin> = cors_allow_origin.map(|cors_allow_origin| {
AllowOrigin::list(
cors_allow_origin
.iter()
.map(|origin| origin.parse::<HeaderValue>().unwrap()),
)
});
// Download and instantiate tokenizer
// This will only be used to validate payloads
//
@ -107,6 +123,7 @@ fn main() -> Result<(), std::io::Error> {
tokenizer,
validation_workers,
addr,
cors_allow_origin,
)
.await;
Ok(())

View File

@ -5,11 +5,11 @@ use crate::{
Infer, StreamDetails, StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, StatusCode};
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse;
use axum::routing::{get, post};
use axum::{Json, Router};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::Stream;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
@ -20,6 +20,7 @@ use tokenizers::Tokenizer;
use tokio::signal;
use tokio::time::Instant;
use tokio_stream::StreamExt;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
@ -334,6 +335,7 @@ pub async fn run(
tokenizer: Tokenizer,
validation_workers: usize,
addr: SocketAddr,
allow_origin: Option<AllowOrigin>,
) {
// OpenAPI documentation
#[derive(OpenApi)]
@ -391,6 +393,13 @@ pub async fn run(
.install_recorder()
.expect("failed to install metrics recorder");
// CORS layer
let allow_origin = allow_origin.unwrap_or(AllowOrigin::any());
let cors_layer = CorsLayer::new()
.allow_methods([Method::GET, Method::POST])
.allow_headers([http::header::CONTENT_TYPE])
.allow_origin(allow_origin);
// Create router
let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
@ -402,7 +411,8 @@ pub async fn run(
.layer(Extension(infer))
.route("/metrics", get(metrics))
.layer(Extension(prom_handle))
.layer(opentelemetry_tracing_layer());
.layer(opentelemetry_tracing_layer())
.layer(cors_layer);
// Run server
axum::Server::bind(&addr)