feat(router): add argument for hostname in router (#545) (#550)

# What does this PR do?

In title. Adds argument `--hostname` in router to support something like
`--hostname ::`. Tested with

```commandline
cargo run -- --port 8080 --hostname ::
curl -I -X GET 'http://[::1]:8080/health'  # failed before this commit
```

Trigger CI

---------

Co-authored-by: Phil Chen <philchen2000@gmail.com>
This commit is contained in:
OlivierDehaene 2023-07-05 18:28:45 +02:00 committed by GitHub
parent 31e2253ae7
commit 6f42942772
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 2 deletions

View File

@ -197,6 +197,10 @@ struct Args {
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
/// The IP address to listen on
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
/// The port to listen on.
#[clap(default_value = "3000", long, short, env)]
port: u16,
@ -874,6 +878,8 @@ fn spawn_webserver(
args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(),
args.max_waiting_tokens.to_string(),
"--hostname".to_string(),
args.hostname.to_string(),
"--port".to_string(),
args.port.to_string(),
"--master-shard-uds-path".to_string(),

View File

@ -40,6 +40,8 @@ struct Args {
max_batch_total_tokens: u32,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "0.0.0.0", long, env)]
hostname: String,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
@ -82,6 +84,7 @@ fn main() -> Result<(), std::io::Error> {
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
hostname,
port,
master_shard_uds_path,
tokenizer_name,
@ -213,8 +216,13 @@ fn main() -> Result<(), std::io::Error> {
.expect("Unable to warmup model");
tracing::info!("Connected");
// Binds on localhost
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port);
let addr = match hostname.parse() {
Ok(ip) => SocketAddr::new(ip, port),
Err(_) => {
tracing::warn!("Invalid hostname, defaulting to 0.0.0.0");
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), port)
}
};
// Run server
server::run(