parent
e943a294bc
commit
b4024edd45
|
@ -10,8 +10,9 @@ use opentelemetry_otlp::WithExportConfig;
|
|||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
use text_generation_client::ShardedClient;
|
||||
use text_generation_client::{ClientError, ShardedClient};
|
||||
use text_generation_router::{server, HubModelInfo};
|
||||
use thiserror::Error;
|
||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||
use tower_http::cors::AllowOrigin;
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
|
@ -70,7 +71,7 @@ struct Args {
|
|||
ngrok_password: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<(), std::io::Error> {
|
||||
fn main() -> Result<(), RouterError> {
|
||||
// Get args
|
||||
let args = Args::parse();
|
||||
// Pattern match configuration
|
||||
|
@ -149,8 +150,7 @@ fn main() -> Result<(), std::io::Error> {
|
|||
// Launch Tokio runtime
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap()
|
||||
.build()?
|
||||
.block_on(async {
|
||||
init_logging(otlp_endpoint, json_output);
|
||||
|
||||
|
@ -192,17 +192,14 @@ fn main() -> Result<(), std::io::Error> {
|
|||
// Instantiate sharded client from the master unix socket
|
||||
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||
.await
|
||||
.expect("Could not connect to server");
|
||||
.map_err(RouterError::Connection)?;
|
||||
// Clear the cache; useful if the webserver rebooted
|
||||
sharded_client
|
||||
.clear_cache(None)
|
||||
.await
|
||||
.expect("Unable to clear cache");
|
||||
.map_err(RouterError::Cache)?;
|
||||
// Get info from the shard
|
||||
let shard_info = sharded_client
|
||||
.info()
|
||||
.await
|
||||
.expect("Unable to get shard info");
|
||||
let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
|
||||
|
||||
// Warmup model
|
||||
tracing::info!("Warming up model");
|
||||
|
@ -213,7 +210,7 @@ fn main() -> Result<(), std::io::Error> {
|
|||
max_batch_total_tokens,
|
||||
)
|
||||
.await
|
||||
.expect("Unable to warmup model");
|
||||
.map_err(RouterError::Warmup)?;
|
||||
tracing::info!("Connected");
|
||||
|
||||
let addr = match hostname.parse() {
|
||||
|
@ -249,7 +246,7 @@ fn main() -> Result<(), std::io::Error> {
|
|||
ngrok_username,
|
||||
ngrok_password,
|
||||
)
|
||||
.await;
|
||||
.await?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
@ -331,3 +328,19 @@ pub async fn get_model_info(
|
|||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum RouterError {
|
||||
#[error("Unable to connect to the Python model shards: {0}")]
|
||||
Connection(ClientError),
|
||||
#[error("Unable to clear the Python model shards cache: {0}")]
|
||||
Cache(ClientError),
|
||||
#[error("Unable to get the Python model shards info: {0}")]
|
||||
Info(ClientError),
|
||||
#[error("Unable to warmup the Python model shards: {0}")]
|
||||
Warmup(ClientError),
|
||||
#[error("Tokio runtime failed to start: {0}")]
|
||||
Tokio(#[from] std::io::Error),
|
||||
#[error("Axum webserver failed: {0}")]
|
||||
Axum(#[from] axum::BoxError),
|
||||
}
|
||||
|
|
|
@ -527,7 +527,7 @@ pub async fn run(
|
|||
ngrok_domain: Option<String>,
|
||||
ngrok_username: Option<String>,
|
||||
ngrok_password: Option<String>,
|
||||
) {
|
||||
) -> Result<(), axum::BoxError> {
|
||||
// OpenAPI documentation
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
|
@ -726,8 +726,7 @@ pub async fn run(
|
|||
.serve(app.into_make_service())
|
||||
//Wait until all requests are finished to shut down
|
||||
.with_graceful_shutdown(shutdown_signal())
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
}
|
||||
#[cfg(not(feature = "ngrok"))]
|
||||
{
|
||||
|
@ -744,9 +743,9 @@ pub async fn run(
|
|||
.serve(app.into_make_service())
|
||||
// Wait until all requests are finished to shut down
|
||||
.with_graceful_shutdown(shutdown_signal())
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Shutdown signal handler
|
||||
|
|
|
@ -256,6 +256,11 @@ class BloomAttention(nn.Module):
|
|||
self.beta = 1.0
|
||||
|
||||
process_group = weights.process_group
|
||||
if self.num_heads % process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // process_group.size()
|
||||
self.query_key_value = TensorParallelColumnLinear.load(
|
||||
config=config,
|
||||
|
|
|
@ -112,6 +112,11 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
|
||||
self.softmax_scale = self.head_size**-0.5
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.query_key_value = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
|
|
|
@ -95,6 +95,12 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = hidden_size // num_heads
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.load(
|
||||
|
|
|
@ -118,6 +118,12 @@ class FlashRWAttention(torch.nn.Module):
|
|||
dim=self.head_size, base=10000.0, device=weights.device
|
||||
)
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
|
||||
self.query_key_value = TensorParallelColumnLinear.load(
|
||||
|
|
|
@ -208,7 +208,11 @@ class FlashMQAttention(torch.nn.Module):
|
|||
self.hidden_size = hidden_size
|
||||
self.head_size = hidden_size // num_heads
|
||||
|
||||
assert self.num_heads % weights.process_group.size() == 0
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
|
|
@ -319,6 +319,12 @@ class MultiheadAttention(nn.Module):
|
|||
if self.softmax_scale is None:
|
||||
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
|
||||
self.attn_dropout_p = config.attn_config["attn_pdrop"]
|
||||
|
||||
if self.n_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.n_heads = self.n_heads // weights.process_group.size()
|
||||
self.Wqkv = load_col(
|
||||
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
|
||||
|
|
|
@ -154,7 +154,12 @@ class GPTNeoXAttention(nn.Module):
|
|||
torch.tensor(self.head_size, dtype=torch.float32)
|
||||
).to(torch.get_default_dtype())
|
||||
|
||||
assert self.num_attention_heads % weights.process_group.size() == 0
|
||||
if self.num_attention_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_attention_heads` must be divisible by `num_shards` "
|
||||
f"(got `num_attention_heads`: {self.num_attention_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_attention_heads = (
|
||||
self.num_attention_heads // weights.process_group.size()
|
||||
)
|
||||
|
|
|
@ -147,7 +147,11 @@ class OPTAttention(nn.Module):
|
|||
self.is_decoder = is_decoder
|
||||
|
||||
process_group = weights.process_group
|
||||
assert self.num_heads % process_group.size() == 0
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // process_group.size()
|
||||
self.embed_dim = self.embed_dim // process_group.size()
|
||||
|
||||
|
|
|
@ -246,6 +246,11 @@ class T5Attention(nn.Module):
|
|||
self.o = TensorParallelRowLinear.load(
|
||||
config, prefix=f"{prefix}.o", weights=weights, bias=False
|
||||
)
|
||||
if self.n_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.n_heads = self.n_heads // process_group.size()
|
||||
self.inner_dim = self.inner_dim // process_group.size()
|
||||
|
||||
|
|
|
@ -727,12 +727,11 @@ class FlashCausalLM(Model):
|
|||
)
|
||||
_, batch = self.generate_token(batch)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
raise RuntimeError(
|
||||
f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
|
||||
f"prefill tokens. "
|
||||
f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
|
||||
)
|
||||
raise e
|
||||
) from e
|
||||
del batch
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
|
Loading…
Reference in New Issue