feat(router): add ngrok integration (#453)

This commit is contained in:
OlivierDehaene 2023-06-16 16:25:11 +02:00 committed by GitHub
parent 5ce89059f8
commit f59fb8b630
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 571 additions and 231 deletions

657
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -229,6 +229,26 @@ struct Args {
#[clap(long, env)]
watermark_delta: Option<f32>,
/// Enable ngrok tunneling
#[clap(long, env)]
ngrok: bool,
/// ngrok authentication token
#[clap(long, env)]
ngrok_authtoken: Option<String>,
/// ngrok domain name where the axum webserver will be available at
#[clap(long, env)]
ngrok_domain: Option<String>,
/// ngrok basic auth username
#[clap(long, env)]
ngrok_username: Option<String>,
/// ngrok basic auth password
#[clap(long, env)]
ngrok_password: Option<String>,
/// Display a lot of information about your runtime environment
#[clap(long, short, action)]
env: bool,
@ -845,6 +865,30 @@ fn spawn_webserver(
argv.push(origin);
}
// Ngrok
if args.ngrok {
let authtoken = args.ngrok_authtoken.ok_or_else(|| {
tracing::error!("`ngrok-authtoken` must be set when using ngrok tunneling");
LauncherError::WebserverCannotStart
})?;
argv.push("--ngrok".to_string());
argv.push("--ngrok-authtoken".to_string());
argv.push(authtoken);
if let Some(domain) = args.ngrok_domain {
argv.push("--ngrok-domain".to_string());
argv.push(domain);
}
if let (Some(username), Some(password)) = (args.ngrok_username, args.ngrok_password) {
argv.push("--ngrok-username".to_string());
argv.push(username);
argv.push("--ngrok-password".to_string());
argv.push(password);
}
}
// Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();

View File

@ -40,6 +40,11 @@ tracing-opentelemetry = "0.18.0"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
utoipa = { version = "3.0.1", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
ngrok = { version = "0.12.3", features = ["axum"], optional = true }
[build-dependencies]
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
[features]
default = ["ngrok"]
ngrok = ["dep:ngrok"]

View File

@ -56,6 +56,16 @@ struct Args {
otlp_endpoint: Option<String>,
#[clap(long, env)]
cors_allow_origin: Option<Vec<String>>,
#[clap(long, env)]
ngrok: bool,
#[clap(long, env)]
ngrok_authtoken: Option<String>,
#[clap(long, env)]
ngrok_domain: Option<String>,
#[clap(long, env)]
ngrok_username: Option<String>,
#[clap(long, env)]
ngrok_password: Option<String>,
}
fn main() -> Result<(), std::io::Error> {
@ -80,6 +90,11 @@ fn main() -> Result<(), std::io::Error> {
json_output,
otlp_endpoint,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_domain,
ngrok_username,
ngrok_password,
} = args;
if validation_workers == 0 {
@ -198,6 +213,11 @@ fn main() -> Result<(), std::io::Error> {
validation_workers,
addr,
cors_allow_origin,
ngrok,
ngrok_authtoken,
ngrok_domain,
ngrok_username,
ngrok_password,
)
.await;
Ok(())

View File

@ -49,7 +49,7 @@ impl Queue {
// Send append command to the background task managing the state
// Unwrap is safe here
self.queue_sender
.send(QueueCommand::Append(entry, Span::current()))
.send(QueueCommand::Append(Box::new(entry), Span::current()))
.unwrap();
}
@ -85,7 +85,7 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueComma
while let Ok(cmd) = receiver.recv_async().await {
match cmd {
QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(entry));
span.in_scope(|| state.append(*entry));
metrics::increment_gauge!("tgi_queue_size", 1.0);
}
QueueCommand::NextBatch {
@ -256,7 +256,7 @@ type NextBatch = (IntMap<u64, Entry>, Batch, Span);
#[derive(Debug)]
enum QueueCommand {
Append(Entry, Span),
Append(Box<Entry>, Span),
NextBatch {
min_size: Option<usize>,
token_budget: u32,

View File

@ -1,5 +1,5 @@
use crate::health::Health;
/// HTTP Server logic
use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
@ -520,6 +520,11 @@ pub async fn run(
validation_workers: usize,
addr: SocketAddr,
allow_origin: Option<AllowOrigin>,
ngrok: bool,
ngrok_authtoken: Option<String>,
ngrok_domain: Option<String>,
ngrok_username: Option<String>,
ngrok_password: Option<String>,
) {
// OpenAPI documentation
#[derive(OpenApi)]
@ -683,13 +688,61 @@ pub async fn run(
.layer(opentelemetry_tracing_layer())
.layer(cors_layer);
// Run server
axum::Server::bind(&addr)
.serve(app.into_make_service())
// Wait until all requests are finished to shut down
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
if ngrok {
#[cfg(feature = "ngrok")]
{
use ngrok::config::TunnelBuilder;
use ngrok::tunnel::UrlTunnel;
let _ = addr;
let authtoken =
ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling");
let mut tunnel = ngrok::Session::builder()
.authtoken(authtoken)
.connect()
.await
.unwrap()
.http_endpoint();
if let Some(domain) = ngrok_domain {
tunnel = tunnel.domain(domain);
}
if let (Some(username), Some(password)) = (ngrok_username, ngrok_password) {
tunnel = tunnel.basic_auth(username, password);
}
let listener = tunnel.listen().await.unwrap();
// Run server
tracing::info!("Ingress URL: {:?}", listener.url());
axum::Server::builder(listener)
.serve(app.into_make_service())
//Wait until all requests are finished to shut down
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
}
#[cfg(not(feature = "ngrok"))]
{
let _ngrok_authtoken = ngrok_authtoken;
let _ngrok_domain = ngrok_domain;
let _ngrok_username = ngrok_username;
let _ngrok_password = ngrok_password;
panic!("`text-generation-router` was compiled without the `ngrok` feature");
}
} else {
// Run server
axum::Server::bind(&addr)
.serve(app.into_make_service())
// Wait until all requests are finished to shut down
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
}
}
/// Shutdown signal handler

View File

@ -47,7 +47,6 @@ def load_multi_mqa(
shape = slice_.get_shape()
block_size = (shape[0] - 2 * head_size) // world_size
assert (shape[0] - 2 * head_size) % world_size == 0
q_tensor = slice_[start:stop]
start = rank * block_size
stop = (rank + 1) * block_size
q_tensor = slice_[start:stop]