feat: better errors for warmup and TP (#575)

Close #571
OlivierDehaene 2023-07-10 14:47:15 +02:00 committed by GitHub
parent e943a294bc
commit b4024edd45
12 changed files with 80 additions and 23 deletions

View File

@@ -10,8 +10,9 @@ use opentelemetry_otlp::WithExportConfig;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::Path;
use std::time::Duration;
use text_generation_client::ShardedClient;
use text_generation_client::{ClientError, ShardedClient};
use text_generation_router::{server, HubModelInfo};
use thiserror::Error;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tower_http::cors::AllowOrigin;
use tracing_subscriber::layer::SubscriberExt;
@@ -70,7 +71,7 @@ struct Args {
ngrok_password: Option<String>,
}
fn main() -> Result<(), std::io::Error> {
fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();
// Pattern match configuration
@@ -149,8 +150,7 @@ fn main() -> Result<(), std::io::Error> {
// Launch Tokio runtime
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.build()?
.block_on(async {
init_logging(otlp_endpoint, json_output);
@@ -192,17 +192,14 @@ fn main() -> Result<(), std::io::Error> {
// Instantiate sharded client from the master unix socket
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.expect("Could not connect to server");
.map_err(RouterError::Connection)?;
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.expect("Unable to clear cache");
.map_err(RouterError::Cache)?;
// Get info from the shard
let shard_info = sharded_client
.info()
.await
.expect("Unable to get shard info");
let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
// Warmup model
tracing::info!("Warming up model");
@@ -213,7 +210,7 @@ fn main() -> Result<(), std::io::Error> {
max_batch_total_tokens,
)
.await
.expect("Unable to warmup model");
.map_err(RouterError::Warmup)?;
tracing::info!("Connected");
let addr = match hostname.parse() {
@@ -249,7 +246,7 @@ fn main() -> Result<(), std::io::Error> {
ngrok_username,
ngrok_password,
)
.await;
.await?;
Ok(())
})
}
@@ -331,3 +328,19 @@ pub async fn get_model_info(
}
None
}
#[derive(Debug, Error)]
enum RouterError {
#[error("Unable to connect to the Python model shards: {0}")]
Connection(ClientError),
#[error("Unable to clear the Python model shards cache: {0}")]
Cache(ClientError),
#[error("Unable to get the Python model shards info: {0}")]
Info(ClientError),
#[error("Unable to warmup the Python model shards: {0}")]
Warmup(ClientError),
#[error("Tokio runtime failed to start: {0}")]
Tokio(#[from] std::io::Error),
#[error("Axum webserver failed: {0}")]
Axum(#[from] axum::BoxError),
}
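The hunks above replace the `.expect(...)` panics in the router's `main` with a `thiserror`-derived `RouterError` and `?` propagation, so each startup step (connect, cache clear, info, warmup) fails with an error naming that step. Below is a minimal, self-contained sketch of the same pattern; the `ClientError` stand-in and the `connect`/`warmup` stubs are hypothetical placeholders, not the router's real types:

```rust
use thiserror::Error;

// Hypothetical stand-in for `text_generation_client::ClientError`,
// used only to keep this sketch self-contained.
#[derive(Debug, Error)]
#[error("{0}")]
struct ClientError(String);

#[derive(Debug, Error)]
enum RouterError {
    #[error("Unable to connect to the Python model shards: {0}")]
    Connection(ClientError),
    #[error("Unable to warmup the Python model shards: {0}")]
    Warmup(ClientError),
    #[error("Tokio runtime failed to start: {0}")]
    Tokio(#[from] std::io::Error),
}

// Hypothetical stubs standing in for the sharded-client calls made in `main`.
async fn connect() -> Result<(), ClientError> {
    Ok(())
}

async fn warmup() -> Result<(), ClientError> {
    Ok(())
}

fn main() -> Result<(), RouterError> {
    tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        // `?` converts the std::io::Error via the #[from] impl on RouterError::Tokio
        .build()?
        .block_on(async {
            // `map_err` tags the low-level error with the failing step,
            // and `?` bubbles it up instead of panicking.
            connect().await.map_err(RouterError::Connection)?;
            warmup().await.map_err(RouterError::Warmup)?;
            Ok(())
        })
}
```

With this shape, a warmup failure surfaces from `main` as a typed `RouterError::Warmup(...)` instead of aborting the process through a panic.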

View File

@@ -527,7 +527,7 @@ pub async fn run(
ngrok_domain: Option<String>,
ngrok_username: Option<String>,
ngrok_password: Option<String>,
) {
) -> Result<(), axum::BoxError> {
// OpenAPI documentation
#[derive(OpenApi)]
#[openapi(
@@ -726,8 +726,7 @@ pub async fn run(
.serve(app.into_make_service())
//Wait until all requests are finished to shut down
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
.await?;
}
#[cfg(not(feature = "ngrok"))]
{
@@ -744,9 +743,9 @@ pub async fn run(
.serve(app.into_make_service())
// Wait until all requests are finished to shut down
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
.await?;
}
Ok(())
}
/// Shutdown signal handler
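`run` now returns `Result<(), axum::BoxError>` and forwards the serve error with `?` instead of `.unwrap()`, which is what lets `main` absorb it through the `Axum(#[from] axum::BoxError)` variant. A stripped-down sketch of that signature, assuming the axum 0.6-style `axum::Server` API used in this file; the `/health` route and `addr` parameter are placeholders:

```rust
use axum::{routing::get, Router};
use std::net::SocketAddr;

/// Resolves once Ctrl+C is received, letting the server drain in-flight requests.
async fn shutdown_signal() {
    let _ = tokio::signal::ctrl_c().await;
}

pub async fn run(addr: SocketAddr) -> Result<(), axum::BoxError> {
    // Placeholder router; the real `run` builds the full API and docs.
    let app = Router::new().route("/health", get(|| async { "ok" }));

    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        // Wait until all requests are finished to shut down
        .with_graceful_shutdown(shutdown_signal())
        // The serve error (hyper::Error) converts into axum::BoxError through `?`
        .await?;

    Ok(())
}
```

The `?` needs no extra mapping on the server side because `hyper::Error` converts into the boxed error type directly.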

View File

@@ -256,6 +256,11 @@ class BloomAttention(nn.Module):
self.beta = 1.0
process_group = weights.process_group
if self.num_heads % process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {process_group.size()}"
)
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear.load(
config=config,

View File

@@ -112,6 +112,11 @@ class FlashLlamaAttention(torch.nn.Module):
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = TensorParallelColumnLinear.load_multi(
config,

View File

@@ -95,6 +95,12 @@ class FlashNeoxAttention(torch.nn.Module):
self.num_heads = num_heads
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.rotary_emb = PositionRotaryEmbedding.load(

View File

@@ -118,6 +118,12 @@ class FlashRWAttention(torch.nn.Module):
dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5)
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = TensorParallelColumnLinear.load(

View File

@@ -208,7 +208,11 @@ class FlashMQAttention(torch.nn.Module):
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
assert self.num_heads % weights.process_group.size() == 0
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.softmax_scale = self.head_size ** (-0.5)

View File

@@ -319,6 +319,12 @@ class MultiheadAttention(nn.Module):
if self.softmax_scale is None:
self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
self.attn_dropout_p = config.attn_config["attn_pdrop"]
if self.n_heads % weights.process_group.size() != 0:
raise ValueError(
f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.n_heads = self.n_heads // weights.process_group.size()
self.Wqkv = load_col(
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias

View File

@@ -154,7 +154,12 @@ class GPTNeoXAttention(nn.Module):
torch.tensor(self.head_size, dtype=torch.float32)
).to(torch.get_default_dtype())
assert self.num_attention_heads % weights.process_group.size() == 0
if self.num_attention_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_attention_heads` must be divisible by `num_shards` "
f"(got `num_attention_heads`: {self.num_attention_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_attention_heads = (
self.num_attention_heads // weights.process_group.size()
)

View File

@@ -147,7 +147,11 @@ class OPTAttention(nn.Module):
self.is_decoder = is_decoder
process_group = weights.process_group
assert self.num_heads % process_group.size() == 0
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // process_group.size()
self.embed_dim = self.embed_dim // process_group.size()

View File

@@ -246,6 +246,11 @@ class T5Attention(nn.Module):
self.o = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.o", weights=weights, bias=False
)
if self.n_heads % weights.process_group.size() != 0:
raise ValueError(
f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.n_heads = self.n_heads // process_group.size()
self.inner_dim = self.inner_dim // process_group.size()

View File

@@ -727,12 +727,11 @@ class FlashCausalLM(Model):
)
_, batch = self.generate_token(batch)
except Exception as e:
logger.exception(
raise RuntimeError(
f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
f"prefill tokens. "
f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
)
raise e
) from e
del batch
torch.cuda.empty_cache()