feat(router): add ngrok integration (#453)
This commit is contained in:
parent
5ce89059f8
commit
f59fb8b630
File diff suppressed because it is too large
Load Diff
|
@ -229,6 +229,26 @@ struct Args {
|
|||
#[clap(long, env)]
|
||||
watermark_delta: Option<f32>,
|
||||
|
||||
/// Enable ngrok tunneling
|
||||
#[clap(long, env)]
|
||||
ngrok: bool,
|
||||
|
||||
/// ngrok authentication token
|
||||
#[clap(long, env)]
|
||||
ngrok_authtoken: Option<String>,
|
||||
|
||||
/// ngrok domain name where the axum webserver will be available at
|
||||
#[clap(long, env)]
|
||||
ngrok_domain: Option<String>,
|
||||
|
||||
/// ngrok basic auth username
|
||||
#[clap(long, env)]
|
||||
ngrok_username: Option<String>,
|
||||
|
||||
/// ngrok basic auth password
|
||||
#[clap(long, env)]
|
||||
ngrok_password: Option<String>,
|
||||
|
||||
/// Display a lot of information about your runtime environment
|
||||
#[clap(long, short, action)]
|
||||
env: bool,
|
||||
|
@ -845,6 +865,30 @@ fn spawn_webserver(
|
|||
argv.push(origin);
|
||||
}
|
||||
|
||||
// Ngrok
|
||||
if args.ngrok {
|
||||
let authtoken = args.ngrok_authtoken.ok_or_else(|| {
|
||||
tracing::error!("`ngrok-authtoken` must be set when using ngrok tunneling");
|
||||
LauncherError::WebserverCannotStart
|
||||
})?;
|
||||
|
||||
argv.push("--ngrok".to_string());
|
||||
argv.push("--ngrok-authtoken".to_string());
|
||||
argv.push(authtoken);
|
||||
|
||||
if let Some(domain) = args.ngrok_domain {
|
||||
argv.push("--ngrok-domain".to_string());
|
||||
argv.push(domain);
|
||||
}
|
||||
|
||||
if let (Some(username), Some(password)) = (args.ngrok_username, args.ngrok_password) {
|
||||
argv.push("--ngrok-username".to_string());
|
||||
argv.push(username);
|
||||
argv.push("--ngrok-password".to_string());
|
||||
argv.push(password);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy current process env
|
||||
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
|
||||
|
||||
|
|
|
@ -40,6 +40,11 @@ tracing-opentelemetry = "0.18.0"
|
|||
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
|
||||
utoipa = { version = "3.0.1", features = ["axum_extras"] }
|
||||
utoipa-swagger-ui = { version = "3.0.2", features = ["axum"] }
|
||||
ngrok = { version = "0.12.3", features = ["axum"], optional = true }
|
||||
|
||||
[build-dependencies]
|
||||
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
|
||||
|
||||
[features]
|
||||
default = ["ngrok"]
|
||||
ngrok = ["dep:ngrok"]
|
|
@ -56,6 +56,16 @@ struct Args {
|
|||
otlp_endpoint: Option<String>,
|
||||
#[clap(long, env)]
|
||||
cors_allow_origin: Option<Vec<String>>,
|
||||
#[clap(long, env)]
|
||||
ngrok: bool,
|
||||
#[clap(long, env)]
|
||||
ngrok_authtoken: Option<String>,
|
||||
#[clap(long, env)]
|
||||
ngrok_domain: Option<String>,
|
||||
#[clap(long, env)]
|
||||
ngrok_username: Option<String>,
|
||||
#[clap(long, env)]
|
||||
ngrok_password: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), std::io::Error> {
|
||||
|
@ -80,6 +90,11 @@ fn main() -> Result<(), std::io::Error> {
|
|||
json_output,
|
||||
otlp_endpoint,
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_domain,
|
||||
ngrok_username,
|
||||
ngrok_password,
|
||||
} = args;
|
||||
|
||||
if validation_workers == 0 {
|
||||
|
@ -198,6 +213,11 @@ fn main() -> Result<(), std::io::Error> {
|
|||
validation_workers,
|
||||
addr,
|
||||
cors_allow_origin,
|
||||
ngrok,
|
||||
ngrok_authtoken,
|
||||
ngrok_domain,
|
||||
ngrok_username,
|
||||
ngrok_password,
|
||||
)
|
||||
.await;
|
||||
Ok(())
|
||||
|
|
|
@ -49,7 +49,7 @@ impl Queue {
|
|||
// Send append command to the background task managing the state
|
||||
// Unwrap is safe here
|
||||
self.queue_sender
|
||||
.send(QueueCommand::Append(entry, Span::current()))
|
||||
.send(QueueCommand::Append(Box::new(entry), Span::current()))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
|
@ -85,7 +85,7 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueComma
|
|||
while let Ok(cmd) = receiver.recv_async().await {
|
||||
match cmd {
|
||||
QueueCommand::Append(entry, span) => {
|
||||
span.in_scope(|| state.append(entry));
|
||||
span.in_scope(|| state.append(*entry));
|
||||
metrics::increment_gauge!("tgi_queue_size", 1.0);
|
||||
}
|
||||
QueueCommand::NextBatch {
|
||||
|
@ -256,7 +256,7 @@ type NextBatch = (IntMap<u64, Entry>, Batch, Span);
|
|||
|
||||
#[derive(Debug)]
|
||||
enum QueueCommand {
|
||||
Append(Entry, Span),
|
||||
Append(Box<Entry>, Span),
|
||||
NextBatch {
|
||||
min_size: Option<usize>,
|
||||
token_budget: u32,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::health::Health;
|
||||
/// HTTP Server logic
|
||||
use crate::health::Health;
|
||||
use crate::infer::{InferError, InferResponse, InferStreamResponse};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
|
@ -520,6 +520,11 @@ pub async fn run(
|
|||
validation_workers: usize,
|
||||
addr: SocketAddr,
|
||||
allow_origin: Option<AllowOrigin>,
|
||||
ngrok: bool,
|
||||
ngrok_authtoken: Option<String>,
|
||||
ngrok_domain: Option<String>,
|
||||
ngrok_username: Option<String>,
|
||||
ngrok_password: Option<String>,
|
||||
) {
|
||||
// OpenAPI documentation
|
||||
#[derive(OpenApi)]
|
||||
|
@ -683,6 +688,53 @@ pub async fn run(
|
|||
.layer(opentelemetry_tracing_layer())
|
||||
.layer(cors_layer);
|
||||
|
||||
if ngrok {
|
||||
#[cfg(feature = "ngrok")]
|
||||
{
|
||||
use ngrok::config::TunnelBuilder;
|
||||
use ngrok::tunnel::UrlTunnel;
|
||||
|
||||
let _ = addr;
|
||||
|
||||
let authtoken =
|
||||
ngrok_authtoken.expect("`ngrok-authtoken` must be set when using ngrok tunneling");
|
||||
|
||||
let mut tunnel = ngrok::Session::builder()
|
||||
.authtoken(authtoken)
|
||||
.connect()
|
||||
.await
|
||||
.unwrap()
|
||||
.http_endpoint();
|
||||
|
||||
if let Some(domain) = ngrok_domain {
|
||||
tunnel = tunnel.domain(domain);
|
||||
}
|
||||
|
||||
if let (Some(username), Some(password)) = (ngrok_username, ngrok_password) {
|
||||
tunnel = tunnel.basic_auth(username, password);
|
||||
}
|
||||
|
||||
let listener = tunnel.listen().await.unwrap();
|
||||
|
||||
// Run server
|
||||
tracing::info!("Ingress URL: {:?}", listener.url());
|
||||
axum::Server::builder(listener)
|
||||
.serve(app.into_make_service())
|
||||
//Wait until all requests are finished to shut down
|
||||
.with_graceful_shutdown(shutdown_signal())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
#[cfg(not(feature = "ngrok"))]
|
||||
{
|
||||
let _ngrok_authtoken = ngrok_authtoken;
|
||||
let _ngrok_domain = ngrok_domain;
|
||||
let _ngrok_username = ngrok_username;
|
||||
let _ngrok_password = ngrok_password;
|
||||
|
||||
panic!("`text-generation-router` was compiled without the `ngrok` feature");
|
||||
}
|
||||
} else {
|
||||
// Run server
|
||||
axum::Server::bind(&addr)
|
||||
.serve(app.into_make_service())
|
||||
|
@ -691,6 +743,7 @@ pub async fn run(
|
|||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
/// Shutdown signal handler
|
||||
async fn shutdown_signal() {
|
||||
|
|
|
@ -47,7 +47,6 @@ def load_multi_mqa(
|
|||
shape = slice_.get_shape()
|
||||
block_size = (shape[0] - 2 * head_size) // world_size
|
||||
assert (shape[0] - 2 * head_size) % world_size == 0
|
||||
q_tensor = slice_[start:stop]
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
q_tensor = slice_[start:stop]
|
||||
|
|
Loading…
Reference in New Issue