feat(router): add ngrok integration (#453)
This commit is contained in:
parent
5ce89059f8
commit
f59fb8b630
File diff suppressed because it is too large
Load Diff
|
@ -229,6 +229,26 @@ struct Args {
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
watermark_delta: Option<f32>,
|
watermark_delta: Option<f32>,
|
||||||
|
|
||||||
|
/// Enable ngrok tunneling
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok: bool,
|
||||||
|
|
||||||
|
/// ngrok authentication token
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok_authtoken: Option<String>,
|
||||||
|
|
||||||
|
/// ngrok domain name where the axum webserver will be available at
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok_domain: Option<String>,
|
||||||
|
|
||||||
|
/// ngrok basic auth username
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok_username: Option<String>,
|
||||||
|
|
||||||
|
/// ngrok basic auth password
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok_password: Option<String>,
|
||||||
|
|
||||||
/// Display a lot of information about your runtime environment
|
/// Display a lot of information about your runtime environment
|
||||||
#[clap(long, short, action)]
|
#[clap(long, short, action)]
|
||||||
env: bool,
|
env: bool,
|
||||||
|
@ -845,6 +865,30 @@ fn spawn_webserver(
|
||||||
argv.push(origin);
|
argv.push(origin);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ngrok
|
||||||
|
if args.ngrok {
|
||||||
|
let authtoken = args.ngrok_authtoken.ok_or_else(|| {
|
||||||
|
tracing::error!("`ngrok-authtoken` must be set when using ngrok tunneling");
|
||||||
|
LauncherError::WebserverCannotStart
|
||||||
|
})?;
|
||||||
|
|
||||||
|
argv.push("--ngrok".to_string());
|
||||||
|
argv.push("--ngrok-authtoken".to_string());
|
||||||
|
argv.push(authtoken);
|
||||||
|
|
||||||
|
if let Some(domain) = args.ngrok_domain {
|
||||||
|
argv.push("--ngrok-domain".to_string());
|
||||||
|
argv.push(domain);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let (Some(username), Some(password)) = (args.ngrok_username, args.ngrok_password) {
|
||||||
|
argv.push("--ngrok-username".to_string());
|
||||||
|
argv.push(username);
|
||||||
|
argv.push("--ngrok-password".to_string());
|
||||||
|
argv.push(password);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Copy current process env
|
// Copy current process env
|
||||||
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
|
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,11 @@ tracing-opentelemetry = "0.18.0"
|
||||||
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
|
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
|
||||||
utoipa = { version = "3.0.1", features = ["axum_extras"] }
|
utoipa = { version = "3.0.1", features = ["axum_extras"] }
|
||||||
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
|
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
|
||||||
|
ngrok = { version = "0.12.3", features = ["axum"], optional = true }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
|
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["ngrok"]
|
||||||
|
ngrok = ["dep:ngrok"]
|
|
@ -56,6 +56,16 @@ struct Args {
|
||||||
otlp_endpoint: Option<String>,
|
otlp_endpoint: Option<String>,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
cors_allow_origin: Option<Vec<String>>,
|
cors_allow_origin: Option<Vec<String>>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok: bool,
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok_authtoken: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok_domain: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok_username: Option<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
ngrok_password: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<(), std::io::Error> {
|
fn main() -> Result<(), std::io::Error> {
|
||||||
|
@ -80,6 +90,11 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
json_output,
|
json_output,
|
||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
cors_allow_origin,
|
cors_allow_origin,
|
||||||
|
ngrok,
|
||||||
|
ngrok_authtoken,
|
||||||
|
ngrok_domain,
|
||||||
|
ngrok_username,
|
||||||
|
ngrok_password,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
if validation_workers == 0 {
|
if validation_workers == 0 {
|
||||||
|
@ -198,6 +213,11 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
validation_workers,
|
validation_workers,
|
||||||
addr,
|
addr,
|
||||||
cors_allow_origin,
|
cors_allow_origin,
|
||||||
|
ngrok,
|
||||||
|
ngrok_authtoken,
|
||||||
|
ngrok_domain,
|
||||||
|
ngrok_username,
|
||||||
|
ngrok_password,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -49,7 +49,7 @@ impl Queue {
|
||||||
// Send append command to the background task managing the state
|
// Send append command to the background task managing the state
|
||||||
// Unwrap is safe here
|
// Unwrap is safe here
|
||||||
self.queue_sender
|
self.queue_sender
|
||||||
.send(QueueCommand::Append(entry, Span::current()))
|
.send(QueueCommand::Append(Box::new(entry), Span::current()))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueComma
|
||||||
while let Ok(cmd) = receiver.recv_async().await {
|
while let Ok(cmd) = receiver.recv_async().await {
|
||||||
match cmd {
|
match cmd {
|
||||||
QueueCommand::Append(entry, span) => {
|
QueueCommand::Append(entry, span) => {
|
||||||
span.in_scope(|| state.append(entry));
|
span.in_scope(|| state.append(*entry));
|
||||||
metrics::increment_gauge!("tgi_queue_size", 1.0);
|
metrics::increment_gauge!("tgi_queue_size", 1.0);
|
||||||
}
|
}
|
||||||
QueueCommand::NextBatch {
|
QueueCommand::NextBatch {
|
||||||
|
@ -256,7 +256,7 @@ type NextBatch = (IntMap<u64, Entry>, Batch, Span);
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
enum QueueCommand {
|
enum QueueCommand {
|
||||||
Append(Entry, Span),
|
Append(Box<Entry>, Span),
|
||||||
NextBatch {
|
NextBatch {
|
||||||
min_size: Option<usize>,
|
min_size: Option<usize>,
|
||||||
token_budget: u32,
|
token_budget: u32,
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::health::Health;
|
|
||||||
/// HTTP Server logic
|
/// HTTP Server logic
|
||||||
|
use crate::health::Health;
|
||||||
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
||||||
use crate::validation::ValidationError;
|
use crate::validation::ValidationError;
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -520,6 +520,11 @@ pub async fn run(
|
||||||
validation_workers: usize,
|
validation_workers: usize,
|
||||||
addr: SocketAddr,
|
addr: SocketAddr,
|
||||||
allow_origin: Option<AllowOrigin>,
|
allow_origin: Option<AllowOrigin>,
|
||||||
|
ngrok: bool,
|
||||||
|
ngrok_authtoken: Option<String>,
|
||||||
|
ngrok_domain: Option<String>,
|
||||||
|
ngrok_username: Option<String>,
|
||||||
|
ngrok_password: Option<String>,
|
||||||
) {
|
) {
|
||||||
// OpenAPI documentation
|
// OpenAPI documentation
|
||||||
#[derive(OpenApi)]
|
#[derive(OpenApi)]
|
||||||
|
@ -683,13 +688,61 @@ pub async fn run(
|
||||||
.layer(opentelemetry_tracing_layer())
|
.layer(opentelemetry_tracing_layer())
|
||||||
.layer(cors_layer);
|
.layer(cors_layer);
|
||||||
|
|
||||||
// Run server
|
if ngrok {
|
||||||
axum::Server::bind(&addr)
|
#[cfg(feature = "ngrok")]
|
||||||
.serve(app.into_make_service())
|
{
|
||||||
// Wait until all requests are finished to shut down
|
use ngrok::config::TunnelBuilder;
|
||||||
.with_graceful_shutdown(shutdown_signal())
|
use ngrok::tunnel::UrlTunnel;
|
||||||
.await
|
|
||||||
.unwrap();
|
let _ = addr;
|
||||||
|
|
||||||
|
let authtoken =
|
||||||
|
ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling");
|
||||||
|
|
||||||
|
let mut tunnel = ngrok::Session::builder()
|
||||||
|
.authtoken(authtoken)
|
||||||
|
.connect()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.http_endpoint();
|
||||||
|
|
||||||
|
if let Some(domain) = ngrok_domain {
|
||||||
|
tunnel = tunnel.domain(domain);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let (Some(username), Some(password)) = (ngrok_username, ngrok_password) {
|
||||||
|
tunnel = tunnel.basic_auth(username, password);
|
||||||
|
}
|
||||||
|
|
||||||
|
let listener = tunnel.listen().await.unwrap();
|
||||||
|
|
||||||
|
// Run server
|
||||||
|
tracing::info!("Ingress URL: {:?}", listener.url());
|
||||||
|
axum::Server::builder(listener)
|
||||||
|
.serve(app.into_make_service())
|
||||||
|
//Wait until all requests are finished to shut down
|
||||||
|
.with_graceful_shutdown(shutdown_signal())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "ngrok"))]
|
||||||
|
{
|
||||||
|
let _ngrok_authtoken = ngrok_authtoken;
|
||||||
|
let _ngrok_domain = ngrok_domain;
|
||||||
|
let _ngrok_username = ngrok_username;
|
||||||
|
let _ngrok_password = ngrok_password;
|
||||||
|
|
||||||
|
panic!("`text-generation-router` was compiled without the `ngrok` feature");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Run server
|
||||||
|
axum::Server::bind(&addr)
|
||||||
|
.serve(app.into_make_service())
|
||||||
|
// Wait until all requests are finished to shut down
|
||||||
|
.with_graceful_shutdown(shutdown_signal())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Shutdown signal handler
|
/// Shutdown signal handler
|
||||||
|
|
|
@ -47,7 +47,6 @@ def load_multi_mqa(
|
||||||
shape = slice_.get_shape()
|
shape = slice_.get_shape()
|
||||||
block_size = (shape[0] - 2 * head_size) // world_size
|
block_size = (shape[0] - 2 * head_size) // world_size
|
||||||
assert (shape[0] - 2 * head_size) % world_size == 0
|
assert (shape[0] - 2 * head_size) % world_size == 0
|
||||||
q_tensor = slice_[start:stop]
|
|
||||||
start = rank * block_size
|
start = rank * block_size
|
||||||
stop = (rank + 1) * block_size
|
stop = (rank + 1) * block_size
|
||||||
q_tensor = slice_[start:stop]
|
q_tensor = slice_[start:stop]
|
||||||
|
|
Loading…
Reference in New Issue