diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 64f4f515..10fd561b 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -29,6 +29,26 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter};
 mod env_runtime;
 mod gpu;
 
+fn compute_optimal(config: Option<&Config>, compute: Option<&ComputeType>) -> Option<usize> {
+    if let (Some(config), Some(compute)) = (config, compute) {
+        if let (Some(f16_max_compute), Some(model_compute)) = (compute.f16_flop(), config.flop()) {
+            tracing::debug!("Max compute {f16_max_compute} model compute {model_compute}");
+            let optimal_size = (f16_max_compute / model_compute) as usize;
+            if optimal_size > 100 {
+                // Ignore calculations that are too low
+                // Most likely an error
+                Some(optimal_size)
+            } else {
+                None
+            }
+        } else {
+            None
+        }
+    } else {
+        None
+    }
+}
+
 fn get_config(
     model_id: &str,
     revision: &Option<String>,
@@ -144,7 +164,10 @@ struct RawConfig {
     quantization_config: Option<QuantizationConfig>,
     n_embd: Option<usize>,
     hidden_size: Option<usize>,
+    intermediate_size: Option<usize>,
     num_attention_heads: Option<usize>,
+    num_key_value_heads: Option<usize>,
+    num_hidden_layers: Option<usize>,
     head_dim: Option<usize>,
     vision_config: Option<VisionConfig>,
     is_encoder_decoder: Option<bool>,
@@ -155,19 +178,42 @@ struct QuantizationConfig {
     quant_method: Option<Quantization>,
 }
 
-#[derive(Deserialize)]
+#[derive(Debug, Deserialize)]
 struct VisionConfig {}
 
-#[derive(Deserialize)]
+#[derive(Debug, Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
     quantize: Option<Quantization>,
     head_dim: Option<usize>,
+    num_heads: Option<usize>,
+    num_kv_heads: Option<usize>,
+    num_layers: Option<usize>,
+    intermediate_size: Option<usize>,
+    hidden_size: Option<usize>,
     model_type: Option<String>,
     vision_config: Option<VisionConfig>,
     is_encoder_decoder: bool,
 }
 
+impl Config {
+    fn flop(&self) -> Option<u64> {
+        let num_heads = self.num_heads? as u64;
+        let num_kv_heads = self.num_kv_heads? as u64;
+        let head_dim = self.head_dim? as u64;
+        let hidden_size = self.hidden_size? as u64;
+        let intermediate_size = self.intermediate_size? as u64;
+        let num_layers = self.num_layers? as u64;
+
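+        // Rough per-token, per-layer FLOP count (2 FLOPs per multiply-accumulate)
+        // for the attention projections and the gate/up/down MLP projections;
+        // this is a sizing heuristic, not an exact count.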
+        let attn_flops = 2 * (num_heads + 2 * num_kv_heads) * head_dim * hidden_size;
+        let o_flops = 2 * num_kv_heads * head_dim * hidden_size;
+        let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size;
+        let layer_flops = attn_flops + o_flops + gate_up_down_flops;
+        let total = layer_flops * num_layers;
+        Some(total)
+    }
+}
+
 impl From<RawConfig> for Config {
     fn from(other: RawConfig) -> Self {
         let max_position_embeddings = other
@@ -175,22 +221,21 @@ impl From<RawConfig> for Config {
             .or(other.max_seq_len)
             .or(other.n_positions);
         let quantize = other.quantization_config.and_then(|q| q.quant_method);
-        let head_dim = other.head_dim.or_else(|| {
-            match (other.hidden_size, other.n_embd, other.num_attention_heads) {
-                (Some(hidden_size), _, Some(num_attention_heads))
-                    if hidden_size % num_attention_heads == 0 =>
-                {
-                    Some(hidden_size / num_attention_heads)
-                }
-                // Legacy
-                (_, Some(hidden_size), Some(num_attention_heads))
+        let hidden_size = other.hidden_size.or(other.n_embd);
+        let head_dim = other
+            .head_dim
+            .or_else(|| match (hidden_size, other.num_attention_heads) {
+                (Some(hidden_size), Some(num_attention_heads))
                     if hidden_size % num_attention_heads == 0 =>
                 {
                     Some(hidden_size / num_attention_heads)
                 }
                 _ => None,
-            }
-        });
+            });
+        let num_heads = other.num_attention_heads;
+        let num_layers = other.num_hidden_layers;
+        let num_kv_heads = other.num_key_value_heads.or(other.num_attention_heads);
+        let intermediate_size = other.intermediate_size;
         let model_type = other.model_type;
         let vision_config = other.vision_config;
         let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
@@ -201,6 +246,11 @@ impl From<RawConfig> for Config {
             model_type,
             vision_config,
             is_encoder_decoder,
+            hidden_size,
+            num_heads,
+            num_kv_heads,
+            intermediate_size,
+            num_layers,
         }
     }
 }
@@ -1423,7 +1473,32 @@ fn spawn_shards(
     Ok(())
 }
 
-fn compute_type(num_shard: usize) -> Option<String> {
+#[derive(Debug)]
+struct ComputeType {
+    count: usize,
+    card: String,
+}
+
+impl ComputeType {
+    fn f16_flop(&self) -> Option<u64> {
+        match &self.card[..] {
+            // https://www.nvidia.com/en-us/data-center/l4/
+            "nvidia-l4" => Some(121 * 10u64.pow(12)),
+            card => {
+                tracing::warn!("Unknown compute for card {card}");
+                None
+            }
+        }
+    }
+}
+
+impl From<ComputeType> for OsString {
+    fn from(value: ComputeType) -> Self {
+        format!("{}-{}", value.count, value.card).into()
+    }
+}
+
+fn compute_type(num_shard: usize) -> Option<ComputeType> {
     let output = Command::new("nvidia-smi")
         .args(["--query-gpu=gpu_name", "--format=csv"])
         .output()
@@ -1431,8 +1506,10 @@ fn compute_type(num_shard: usize) -> Option<String> {
     let output = String::from_utf8(output.stdout).ok()?;
     let fullname = output.split('\n').nth(1)?;
     let cardname = fullname.replace(' ', "-").to_lowercase();
-    let compute_type = format!("{num_shard}-{cardname}");
-    Some(compute_type)
+    Some(ComputeType {
+        count: num_shard,
+        card: cardname,
+    })
 }
 
 fn spawn_webserver(
@@ -1682,26 +1759,22 @@ fn main() -> Result<(), LauncherError> {
     let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
     let quantize = config.as_ref().and_then(|c| c.quantize);
     // Quantization usually means you're even more RAM constrained.
-    let max_default = 4096;
-    let max_position_embeddings = if let Some(config) = &config {
-        if let Some(max_position_embeddings) = config.max_position_embeddings {
-            if max_position_embeddings > max_default {
-                max_default
-            } else {
-                max_position_embeddings
-            }
-        } else {
-            max_default
-        }
-    } else {
-        max_default
-    };
 
     let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
     tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
     std::env::set_var("PREFIX_CACHING", prefix_caching);
     std::env::set_var("ATTENTION", attention);
 
+    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
+    if num_shard > 1 {
+        if matches!(args.quantize, Some(Quantization::Exl2)) {
+            return Err(LauncherError::ArgumentValidation(
+                "Sharding is currently not supported with `exl2` quantization".into(),
+            ));
+        }
+        tracing::info!("Sharding model on {num_shard} processes");
+    }
+
     let max_input_tokens = {
         match (args.max_input_tokens, args.max_input_length) {
             (Some(max_input_tokens), Some(max_input_length)) => {
@@ -1721,9 +1794,19 @@ fn main() -> Result<(), LauncherError> {
            Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
            None => {
                // TODO figure out hardware optimal value
-               let value = 4096.min(max_position_embeddings as u32);
+               let compute_type = compute_type(num_shard);
+               tracing::info!("Compute type {compute_type:?}");
+               tracing::info!("Config {config:?}");
+               let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
+               let default = compute_optimal.unwrap_or(4096);
+               let max_position_embeddings = config.and_then(|c| c.max_position_embeddings);
+               let value = if let Some(max_position_embeddings) = max_position_embeddings {
+                   default.min(max_position_embeddings)
+               } else {
+                   default
+               };
                tracing::info!("Default `max_batch_prefill_tokens` to {value}");
-               value
+               value as u32
            }
        }
    };
@@ -1778,16 +1861,6 @@ fn main() -> Result<(), LauncherError> {
         );
     }
 
-    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
-    if num_shard > 1 {
-        if matches!(args.quantize, Some(Quantization::Exl2)) {
-            return Err(LauncherError::ArgumentValidation(
-                "Sharding is currently not supported with `exl2` quantization".into(),
-            ));
-        }
-        tracing::info!("Sharding model on {num_shard} processes");
-    }
-
     if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
         if let Some(max_total_tokens) = max_total_tokens {
             if max_total_tokens as u32 > *max_batch_total_tokens {