feat: better errors for warmup and TP (#575)

Close #571
OlivierDehaene 2023-07-10 14:47:15 +02:00 committed by GitHub
parent e943a294bc
commit b4024edd45
12 changed files with 80 additions and 23 deletions

View File

@@ -10,8 +10,9 @@ use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use std::time::Duration;
-use text_generation_client::ShardedClient;
+use text_generation_client::{ClientError, ShardedClient};
 use text_generation_router::{server, HubModelInfo};
+use thiserror::Error;
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
@@ -70,7 +71,7 @@ struct Args {
     ngrok_password: Option<String>,
 }
-fn main() -> Result<(), std::io::Error> {
+fn main() -> Result<(), RouterError> {
     // Get args
     let args = Args::parse();
     // Pattern match configuration
@@ -149,8 +150,7 @@ fn main() -> Result<(), std::io::Error> {
     // Launch Tokio runtime
     tokio::runtime::Builder::new_multi_thread()
         .enable_all()
-        .build()
-        .unwrap()
+        .build()?
         .block_on(async {
             init_logging(otlp_endpoint, json_output);
@@ -192,17 +192,14 @@ fn main() -> Result<(), std::io::Error> {
             // Instantiate sharded client from the master unix socket
             let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
                 .await
-                .expect("Could not connect to server");
+                .map_err(RouterError::Connection)?;
             // Clear the cache; useful if the webserver rebooted
             sharded_client
                 .clear_cache(None)
                 .await
-                .expect("Unable to clear cache");
+                .map_err(RouterError::Cache)?;
             // Get info from the shard
-            let shard_info = sharded_client
-                .info()
-                .await
-                .expect("Unable to get shard info");
+            let shard_info = sharded_client.info().await.map_err(RouterError::Info)?;
             // Warmup model
             tracing::info!("Warming up model");
@@ -213,7 +210,7 @@ fn main() -> Result<(), std::io::Error> {
                     max_batch_total_tokens,
                 )
                 .await
-                .expect("Unable to warmup model");
+                .map_err(RouterError::Warmup)?;
             tracing::info!("Connected");
             let addr = match hostname.parse() {
@@ -249,7 +246,7 @@ fn main() -> Result<(), std::io::Error> {
                 ngrok_username,
                 ngrok_password,
             )
-            .await;
+            .await?;
             Ok(())
         })
 }
@@ -331,3 +328,19 @@ pub async fn get_model_info(
     }
     None
 }
+#[derive(Debug, Error)]
+enum RouterError {
+    #[error("Unable to connect to the Python model shards: {0}")]
+    Connection(ClientError),
+    #[error("Unable to clear the Python model shards cache: {0}")]
+    Cache(ClientError),
+    #[error("Unable to get the Python model shards info: {0}")]
+    Info(ClientError),
+    #[error("Unable to warmup the Python model shards: {0}")]
+    Warmup(ClientError),
+    #[error("Tokio runtime failed to start: {0}")]
+    Tokio(#[from] std::io::Error),
+    #[error("Axum webserver failed: {0}")]
+    Axum(#[from] axum::BoxError),
+}

View File

@@ -527,7 +527,7 @@ pub async fn run(
     ngrok_domain: Option<String>,
     ngrok_username: Option<String>,
     ngrok_password: Option<String>,
-) {
+) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
     #[openapi(
@@ -726,8 +726,7 @@ pub async fn run(
             .serve(app.into_make_service())
            //Wait until all requests are finished to shut down
             .with_graceful_shutdown(shutdown_signal())
-            .await
-            .unwrap();
+            .await?;
     }
     #[cfg(not(feature = "ngrok"))]
     {
@@ -744,9 +743,9 @@ pub async fn run(
             .serve(app.into_make_service())
             // Wait until all requests are finished to shut down
             .with_graceful_shutdown(shutdown_signal())
-            .await
-            .unwrap();
+            .await?;
     }
+    Ok(())
 }
 /// Shutdown signal handler

View File

@@ -256,6 +256,11 @@ class BloomAttention(nn.Module):
         self.beta = 1.0
         process_group = weights.process_group
+        if self.num_heads % process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {process_group.size()}"
+            )
         self.num_heads = self.num_heads // process_group.size()
         self.query_key_value = TensorParallelColumnLinear.load(
             config=config,

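The hunk above, and the near-identical hunks in the remaining attention modules below, replace silent integer division (or a bare assert) with an explicit ValueError when the head count does not split evenly across tensor-parallel shards. A minimal standalone sketch of the guard, assuming a hypothetical check_divisible helper and plain integers in place of weights.process_group:

def check_divisible(num_heads: int, num_shards: int) -> int:
    """Return the per-shard head count, mirroring the guard added in this commit."""
    if num_heads % num_shards != 0:
        raise ValueError(
            f"`num_heads` must be divisible by `num_shards` "
            f"(got `num_heads`: {num_heads} and `num_shards`: {num_shards})"
        )
    # Each tensor-parallel shard owns an equal slice of the attention heads.
    return num_heads // num_shards


# 32 heads split cleanly across 4 shards...
assert check_divisible(num_heads=32, num_shards=4) == 8
# ...but not across 3; the constructor now fails fast with a readable message
# instead of producing mismatched weight shapes later on.
# check_divisible(num_heads=32, num_shards=3)  # raises ValueError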
View File

@@ -112,6 +112,11 @@ class FlashLlamaAttention(torch.nn.Module):
         self.softmax_scale = self.head_size**-0.5
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.query_key_value = TensorParallelColumnLinear.load_multi(
             config,

View File

@@ -95,6 +95,12 @@ class FlashNeoxAttention(torch.nn.Module):
         self.num_heads = num_heads
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.rotary_emb = PositionRotaryEmbedding.load(

View File

@@ -118,6 +118,12 @@ class FlashRWAttention(torch.nn.Module):
             dim=self.head_size, base=10000.0, device=weights.device
         )
         self.softmax_scale = self.head_size ** (-0.5)
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.query_key_value = TensorParallelColumnLinear.load(

View File

@@ -208,7 +208,11 @@ class FlashMQAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
-        assert self.num_heads % weights.process_group.size() == 0
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // weights.process_group.size()
         self.softmax_scale = self.head_size ** (-0.5)

View File

@@ -319,6 +319,12 @@ class MultiheadAttention(nn.Module):
         if self.softmax_scale is None:
             self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
         self.attn_dropout_p = config.attn_config["attn_pdrop"]
+        if self.n_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.n_heads = self.n_heads // weights.process_group.size()
         self.Wqkv = load_col(
             config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias

View File

@@ -154,7 +154,12 @@ class GPTNeoXAttention(nn.Module):
             torch.tensor(self.head_size, dtype=torch.float32)
         ).to(torch.get_default_dtype())
-        assert self.num_attention_heads % weights.process_group.size() == 0
+        if self.num_attention_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_attention_heads` must be divisible by `num_shards` "
+                f"(got `num_attention_heads`: {self.num_attention_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_attention_heads = (
             self.num_attention_heads // weights.process_group.size()
         )

View File

@@ -147,7 +147,11 @@ class OPTAttention(nn.Module):
         self.is_decoder = is_decoder
         process_group = weights.process_group
-        assert self.num_heads % process_group.size() == 0
+        if self.num_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.num_heads = self.num_heads // process_group.size()
         self.embed_dim = self.embed_dim // process_group.size()

View File

@@ -246,6 +246,11 @@ class T5Attention(nn.Module):
         self.o = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.o", weights=weights, bias=False
         )
+        if self.n_heads % weights.process_group.size() != 0:
+            raise ValueError(
+                f"`n_heads` must be divisible by `num_shards` (got `n_heads`: {self.n_heads} "
+                f"and `num_shards`: {weights.process_group.size()}"
+            )
         self.n_heads = self.n_heads // process_group.size()
         self.inner_dim = self.inner_dim // process_group.size()

View File

@@ -727,12 +727,11 @@ class FlashCausalLM(Model):
             )
             _, batch = self.generate_token(batch)
         except Exception as e:
-            logger.exception(
+            raise RuntimeError(
                 f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} "
                 f"prefill tokens. "
                 f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`"
-            )
-            raise e
+            ) from e
         del batch
         torch.cuda.empty_cache()
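The warmup change swaps logger.exception plus a bare re-raise for raise RuntimeError(...) from e, so the actionable hint and the original out-of-memory error travel together up to the caller (and on to the router's new RouterError::Warmup variant). A minimal sketch of the idiom, with a hypothetical warmup wrapper and a simulated allocation failure standing in for the real generate_token call:

def run_prefill(max_total_tokens: int) -> None:
    # Stand-in for the real prefill call; simulate an allocator failure.
    raise MemoryError(f"could not allocate {max_total_tokens} tokens")


def warmup(max_total_tokens: int) -> None:
    try:
        run_prefill(max_total_tokens)
    except Exception as e:
        # `from e` chains the original traceback as __cause__, so both the
        # hint and the underlying OOM detail are reported.
        raise RuntimeError(
            f"Not enough memory to handle {max_total_tokens} total tokens. "
            f"You need to decrease `--max-batch-total-tokens` "
            f"or `--max-batch-prefill-tokens`"
        ) from e


try:
    warmup(16000)
except RuntimeError as err:
    print(err)            # actionable message
    print(err.__cause__)  # original MemoryError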