parent e943a294bc
commit b4024edd45
@@ -10,8 +10,9 @@ use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use std::time::Duration;
-use text_generation_client::ShardedClient;
+use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::{server, HubModelInfo};
+use thiserror::Error;
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
@@ -70,7 +71,7 @@ struct Args {
     ngrok_password: Option<String>,
 }
 
-fn main() -> Result<(), std::io::Error> {
+fn main() -> Result<(), RouterError> {
     // Get args
     let args = Args::parse();
     // Pattern match configuration
@@ -149,8 +150,7 @@ fn main() -> Result<(), std::io::Error> {
     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
         .enable_all()
-        .build()
-        .unwrap()
+        .build()?
         .block_on(async {
             init_logging(otlp_endpoint, json_output);
 
@@ -192,17 +192,14 @@ fn main() -> Result<(), std::io::Error> {
             // Instantiate sharded client from the master unix socket
             let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
-                .expect("Could not connect to server");
+                .map_err(RouterError::Connection)?;
             // Clear the cache; useful if the webserver rebooted
             sharded_client
                 .clear_cache(None)
                 .await
-                .expect("Unable to clear cache");
+                .map_err(RouterError::Cache)?;
             // Get info from the shard
-            let shard_info = sharded_client
-                .info()
-                .await
-                .expect("Unable to get shard info");
+            let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
 
             // Warmup model
             tracing::info!("Warming up model");
@@ -213,7 +210,7 @@ fn main() -> Result<(), std::io::Error> {
                     max_batch_total_tokens,
                 )
                 .await
-                .expect("Unable to warmup model");
+                .map_err(RouterError::Warmup)?;
             tracing::info!("Connected");
 
             let addr = match hostname.parse() {
@@ -249,7 +246,7 @@ fn main() -> Result<(), std::io::Error> {
                 ngrok_username,
                 ngrok_password,
             )
-            .await;
+            .await?;
             Ok(())
         })
 }
@@ -331,3 +328,19 @@ pub async fn get_model_info(
     }
     None
 }
+
+#[derive(Debug, Error)]
+enum RouterError {
+    #[error("Unable to connect to the Python model shards: {0}")]
+    Connection(ClientError),
+    #[error("Unable to clear the Python model shards cache: {0}")]
+    Cache(ClientError),
+    #[error("Unable to get the Python model shards info: {0}")]
+    Info(ClientError),
+    #[error("Unable to warmup the Python model shards: {0}")]
+    Warmup(ClientError),
+    #[error("Tokio runtime failed to start: {0}")]
+    Tokio(#[from] std::io::Error),
+    #[error("Axum webserver failed: {0}")]
+    Axum(#[from] axum::BoxError),
+}
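For readers following the pattern above: the router's `.expect(...)` panics become typed errors that `main` can return. Below is a minimal, self-contained Rust sketch of the same thiserror idiom with hypothetical names (DemoError, call_backend), not the router's actual code; it assumes a `thiserror = "1"` dependency.

use thiserror::Error;

#[derive(Debug, Error)]
enum DemoError {
    // Wrapped explicitly at the call site, like `.map_err(RouterError::Connection)?` above.
    #[error("backend call failed: {0}")]
    Backend(String),
    // Converted automatically by `?`, like the `Tokio(#[from] std::io::Error)` variant above.
    #[error("I/O failed: {0}")]
    Io(#[from] std::io::Error),
}

fn call_backend() -> Result<(), String> {
    Err("connection refused".to_string())
}

fn main() -> Result<(), DemoError> {
    // `#[from]` + `?`: a std::io::Error becomes DemoError::Io with no extra code.
    let _config = std::fs::read_to_string("Cargo.toml")?;
    // No `From` impl for String here, so the variant is named explicitly with map_err.
    call_backend().map_err(DemoError::Backend)?;
    Ok(())
}

The `#[error(...)]` attribute gives each variant its Display message, so every failure path in `main` now surfaces as a descriptive `RouterError` instead of a panic from `.expect(...)`.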
@@ -527,7 +527,7 @@ pub async fn run(
     ngrok_domain: Option<String>,
     ngrok_username: Option<String>,
     ngrok_password: Option<String>,
-) {
+) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
     #[openapi(
@@ -726,8 +726,7 @@ pub async fn run(
             .serve(app.into_make_service())
             //Wait until all requests are finished to shut down
             .with_graceful_shutdown(shutdown_signal())
-            .await
-            .unwrap();
+            .await?;
     }
     #[cfg(not(feature = "ngrok"))]
     {
@@ -744,9 +743,9 @@ pub async fn run(
             .serve(app.into_make_service())
             // Wait until all requests are finished to shut down
             .with_graceful_shutdown(shutdown_signal())
-            .await
-            .unwrap();
+            .await?;
     }
+    Ok(())
 }
 
 /// Shutdown signal handler
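The `run` signature change above pairs with the `.await?;` lines that replace `.unwrap()`: the serve error is now propagated to `main`, which converts it into `RouterError::Axum` through `#[from]`. A minimal sketch of that shape, assuming axum 0.6 with hyper's graceful-shutdown API (the stack used here) and illustrative names; it is not the router's actual server setup.

// Cargo.toml (assumed): axum = "0.6", tokio = { version = "1", features = ["full"] }
use axum::{routing::get, BoxError, Router};
use std::net::SocketAddr;

async fn run_server(addr: SocketAddr) -> Result<(), BoxError> {
    let app = Router::new().route("/health", get(|| async { "ok" }));
    // Previously the equivalent of `.await.unwrap()`; the error now bubbles up
    // to the caller instead of panicking inside the webserver task.
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .with_graceful_shutdown(shutdown_signal())
        .await?;
    Ok(())
}

async fn shutdown_signal() {
    // When this future resolves, the server stops accepting new connections
    // and finishes in-flight requests before `serve` returns.
    tokio::signal::ctrl_c().await.expect("failed to listen for Ctrl+C");
}

#[tokio::main]
async fn main() {
    if let Err(err) = run_server(SocketAddr::from(([127, 0, 0, 1], 3000))).await {
        eprintln!("webserver failed: {err}");
    }
}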
@@ -256,6 +256,11 @@ class BloomAttention(nn.Module):
         self.beta = 1.0
 
         process_group = weights.process_group
+        if self.num_heads % process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {process_group.size()}"
+            )
         self.num_heads = self.num_heads // process_group.size()
         self.query_key_value = TensorParallelColumnLinear.load(
             config=config,
@@ -112,6 +112,11 @@ class FlashLlamaAttention(torch.nn.Module):
 
         self.softmax_scale = self.head_size**-0.5
 
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.query_key_value = TensorParallelColumnLinear.load_multi(
             config,
@@ -95,6 +95,12 @@ class FlashNeoxAttention(torch.nn.Module):
         self.num_heads = num_heads
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
+
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
 
         self.rotary_emb = PositionRotaryEmbedding.load(
@@ -118,6 +118,12 @@ class FlashRWAttention(torch.nn.Module):
             dim=self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
+
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
 
         self.query_key_value = TensorParallelColumnLinear.load(
@@ -208,7 +208,11 @@ class FlashMQAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
 
-        assert self.num_heads % weights.process_group.size() == 0
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
 
         self.softmax_scale = self.head_size ** (-0.5)
@@ -319,6 +319,12 @@ class MultiheadAttention(nn.Module):
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = config.attn_config["attn_pdrop"]
+
+        if self.n_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.n_heads = self.n_heads // weights.process_group.size()
         self.Wqkv = load_col(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
@@ -154,7 +154,12 @@ class GPTNeoXAttention(nn.Module):
             torch.tensor(self.head_size, dtype=torch.float32)
         ).to(torch.get_default_dtype())
 
-        assert self.num_attention_heads % weights.process_group.size() == 0
+        if self.num_attention_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_attention_heads` must be divisible by `num_shards` "
+                f"(got `num_attention_heads`: {self.num_attention_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_attention_heads = (
             self.num_attention_heads // weights.process_group.size()
         )
@@ -147,7 +147,11 @@ class OPTAttention(nn.Module):
         self.is_decoder = is_decoder
 
         process_group = weights.process_group
-        assert self.num_heads % process_group.size() == 0
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // process_group.size()
         self.embed_dim = self.embed_dim // process_group.size()
 
@@ -246,6 +246,11 @@ class T5Attention(nn.Module):
         self.o = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.o", weights=weights, bias=False
         )
+        if self.n_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.n_heads = self.n_heads // process_group.size()
         self.inner_dim = self.inner_dim // process_group.size()
 
@@ -727,12 +727,11 @@ class FlashCausalLM(Model):
             )
             _, batch = self.generate_token(batch)
         except Exception as e:
-            logger.exception(
+            raise RuntimeError(
                 f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
                 f"prefill tokens. "
                 f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
-            )
-            raise e
+            ) from e
         del batch
         torch.cuda.empty_cache()
 