Attempt at automatic max batch prefill.

Nicolas Patry 2024-11-04 10:59:07 +01:00
parent a5593ba83e
commit 6624cbe7f2
No known key found for this signature in database
GPG Key ID: D2920555C90F704C
1 changed file with 115 additions and 42 deletions


@@ -29,6 +29,26 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter};
mod env_runtime;
mod gpu;
fn compute_optimal(config: Option<&Config>, compute: Option<&ComputeType>) -> Option<usize> {
if let (Some(config), Some(compute)) = (config, compute) {
if let (Some(f16_max_compute), Some(model_compute)) = (compute.f16_flop(), config.flop()) {
tracing::debug!("MAx compute {f16_max_compute} model compute {model_compute}");
let optimal_size = (f16_max_compute / model_compute) as usize;
if optimal_size > 100 {
// Ignore calculations that are too low
// Most likely an error
Some(optimal_size)
} else {
None
}
} else {
None
}
} else {
None
}
}
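// A minimal sanity check of the heuristic above, with assumed, illustrative
// numbers (not part of this commit): peak f16 throughput divided by the
// model's per-token FLOPs gives the token budget the ratio represents.
#[cfg(test)]
mod compute_optimal_sketch {
    #[test]
    fn ratio_of_hardware_to_model_flops() {
        let f16_max_compute: u64 = 121 * 10u64.pow(12); // L4 peak f16 FLOP/s
        let model_compute: u64 = 16 * 10u64.pow(9); // ~2 FLOPs/param for an assumed ~8B model
        let optimal_size = (f16_max_compute / model_compute) as usize;
        assert_eq!(optimal_size, 7562); // comfortably above the `> 100` floor
    }
}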
fn get_config(
model_id: &str,
revision: &Option<String>,
@@ -144,7 +164,10 @@ struct RawConfig {
quantization_config: Option<QuantizationConfig>,
n_embd: Option<usize>,
hidden_size: Option<usize>,
intermediate_size: Option<usize>,
num_attention_heads: Option<usize>,
num_key_value_heads: Option<usize>,
num_hidden_layers: Option<usize>,
head_dim: Option<usize>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: Option<bool>,
@@ -155,19 +178,42 @@ struct QuantizationConfig {
quant_method: Option<Quantization>,
}
#[derive(Deserialize)]
#[derive(Debug, Deserialize)]
struct VisionConfig {}
#[derive(Deserialize)]
#[derive(Debug, Deserialize)]
struct Config {
max_position_embeddings: Option<usize>,
quantize: Option<Quantization>,
head_dim: Option<usize>,
num_heads: Option<usize>,
num_kv_heads: Option<usize>,
num_layers: Option<usize>,
intermediate_size: Option<usize>,
hidden_size: Option<usize>,
model_type: Option<String>,
vision_config: Option<VisionConfig>,
is_encoder_decoder: bool,
}
impl Config {
fn flop(&self) -> Option<u64> {
let num_heads = self.num_heads? as u64;
let num_kv_heads = self.num_kv_heads? as u64;
let head_dim = self.head_dim? as u64;
let hidden_size = self.hidden_size? as u64;
let intermediate_size = self.intermediate_size? as u64;
let num_layers = self.num_layers? as u64;
// Q/K/V projections: Q uses num_heads, K and V use num_kv_heads each
let attn_flops = 2 * (num_heads + 2 * num_kv_heads) * head_dim * hidden_size;
// Attention output projection (sized here by num_kv_heads)
let o_flops = 2 * num_kv_heads * head_dim * hidden_size;
// MLP gate/up/down projections
let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size;
let layer_flops = attn_flops + o_flops + gate_up_down_flops;
let total = layer_flops * num_layers;
Some(total)
}
}
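// Worked example of `flop()` using assumed Llama-3-8B-style hyperparameters
// (illustrative numbers, not taken from this commit):
#[cfg(test)]
mod flop_sketch {
    #[test]
    fn per_token_flops_for_an_assumed_8b_config() {
        let (num_heads, num_kv_heads, head_dim) = (32u64, 8u64, 128u64);
        let (hidden_size, intermediate_size, num_layers) = (4096u64, 14336u64, 32u64);
        let attn_flops = 2 * (num_heads + 2 * num_kv_heads) * head_dim * hidden_size; // 50_331_648
        let o_flops = 2 * num_kv_heads * head_dim * hidden_size; // 8_388_608
        let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size; // 352_321_536
        let total = (attn_flops + o_flops + gate_up_down_flops) * num_layers;
        assert_eq!(total, 13_153_337_344); // ~1.3e10 FLOPs per token
    }
}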
impl From<RawConfig> for Config {
fn from(other: RawConfig) -> Self {
let max_position_embeddings = other
@@ -175,22 +221,21 @@ impl From<RawConfig> for Config {
.or(other.max_seq_len)
.or(other.n_positions);
let quantize = other.quantization_config.and_then(|q| q.quant_method);
let head_dim = other.head_dim.or_else(|| {
match (other.hidden_size, other.n_embd, other.num_attention_heads) {
(Some(hidden_size), _, Some(num_attention_heads))
if hidden_size % num_attention_heads == 0 =>
{
Some(hidden_size / num_attention_heads)
}
// Legacy
(_, Some(hidden_size), Some(num_attention_heads))
let hidden_size = other.hidden_size.or(other.n_embd);
let head_dim = other
.head_dim
.or_else(|| match (hidden_size, other.num_attention_heads) {
(Some(hidden_size), Some(num_attention_heads))
if hidden_size % num_attention_heads == 0 =>
{
Some(hidden_size / num_attention_heads)
}
_ => None,
}
});
});
let num_heads = other.num_attention_heads;
let num_layers = other.num_hidden_layers;
let num_kv_heads = other.num_key_value_heads.or(other.num_attention_heads);
let intermediate_size = other.intermediate_size;
let model_type = other.model_type;
let vision_config = other.vision_config;
let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
@@ -201,6 +246,11 @@ impl From<RawConfig> for Config {
model_type,
vision_config,
is_encoder_decoder,
hidden_size,
num_heads,
num_kv_heads,
intermediate_size,
num_layers,
}
}
}
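// Mapping sketch (assumed, abbreviated config.json; field names follow the
// usual Hugging Face conventions):
//   { "hidden_size": 4096, "num_attention_heads": 32,
//     "num_key_value_heads": 8, "num_hidden_layers": 32,
//     "intermediate_size": 14336 }
// yields head_dim = 4096 / 32 = 128 and num_kv_heads = 8, i.e. everything
// `Config::flop` needs.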
@@ -1423,7 +1473,32 @@ fn spawn_shards(
Ok(())
}
fn compute_type(num_shard: usize) -> Option<String> {
#[derive(Debug)]
struct ComputeType {
count: usize,
card: String,
}
impl ComputeType {
fn f16_flop(&self) -> Option<u64> {
match &self.card[..] {
// https://www.nvidia.com/en-us/data-center/l4/
"nvidia-l4" => Some(121 * 10u64.pow(12)),
card => {
tracing::warn!("Unkown compute for card {card}");
None
}
}
}
}
impl From<ComputeType> for OsString {
fn from(value: ComputeType) -> Self {
format!("{}-{}", value.count, value.card).into()
}
}
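// Usage sketch with assumed values: a two-shard L4 deployment keeps the old
// "{count}-{card}" string format when exported to the environment.
#[cfg(test)]
mod compute_type_sketch {
    use super::*;
    use std::ffi::OsString;

    #[test]
    fn l4_lookup_and_env_format() {
        let compute = ComputeType { count: 2, card: "nvidia-l4".to_string() };
        assert_eq!(compute.f16_flop(), Some(121 * 10u64.pow(12)));
        let env: OsString = compute.into();
        assert_eq!(env, OsString::from("2-nvidia-l4"));
    }
}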
fn compute_type(num_shard: usize) -> Option<ComputeType> {
let output = Command::new("nvidia-smi")
.args(["--query-gpu=gpu_name", "--format=csv"])
.output()
@@ -1431,8 +1506,10 @@ fn compute_type(num_shard: usize) -> Option<String> {
let output = String::from_utf8(output.stdout).ok()?;
let fullname = output.split('\n').nth(1)?;
let cardname = fullname.replace(' ', "-").to_lowercase();
let compute_type = format!("{num_shard}-{cardname}");
Some(compute_type)
Some(ComputeType {
count: num_shard,
card: cardname,
})
}
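// Assumed `nvidia-smi --query-gpu=gpu_name --format=csv` output (the exact
// card name varies per machine):
//   name
//   NVIDIA L4
// `nth(1)` skips the CSV header row, and the name is normalized to
// "nvidia-l4" so it can be matched in `ComputeType::f16_flop`.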
fn spawn_webserver(
@@ -1682,26 +1759,22 @@ fn main() -> Result<(), LauncherError> {
let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
let quantize = config.as_ref().and_then(|c| c.quantize);
// Quantization usually means you're even more RAM constrained.
let max_default = 4096;
let max_position_embeddings = if let Some(config) = &config {
if let Some(max_position_embeddings) = config.max_position_embeddings {
if max_position_embeddings > max_default {
max_default
} else {
max_position_embeddings
}
} else {
max_default
}
} else {
max_default
};
let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
std::env::set_var("PREFIX_CACHING", prefix_caching);
std::env::set_var("ATTENTION", attention);
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
if num_shard > 1 {
if matches!(args.quantize, Some(Quantization::Exl2)) {
return Err(LauncherError::ArgumentValidation(
"Sharding is currently not supported with `exl2` quantization".into(),
));
}
tracing::info!("Sharding model on {num_shard} processes");
}
let max_input_tokens = {
match (args.max_input_tokens, args.max_input_length) {
(Some(max_input_tokens), Some(max_input_length)) => {
@@ -1721,9 +1794,19 @@ fn main() -> Result<(), LauncherError> {
Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
None => {
// TODO figure out hardware optimal value
let value = 4096.min(max_position_embeddings as u32);
let compute_type = compute_type(num_shard);
tracing::info!("Compute type {compute_type:?}");
tracing::info!("Config {config:?}");
let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
let default = compute_optimal.unwrap_or(4096);
let max_position_embeddings = config.and_then(|c| c.max_position_embeddings);
let value = if let Some(max_position_embeddings) = max_position_embeddings {
default.min(max_position_embeddings)
} else {
default
};
tracing::info!("Default `max_batch_prefill_tokens` to {value}");
value
value as u32
}
}
};
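// Resolution order, illustrated with assumed numbers: an L4 (121e12 FLOP/s)
// paired with the ~1.3e10 FLOPs/token config sketched above gives
// compute_optimal ≈ 9199, which is then clamped to max_position_embeddings;
// with an unknown card (or an incomplete config) the default stays 4096.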
@@ -1778,16 +1861,6 @@ fn main() -> Result<(), LauncherError> {
);
}
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
if num_shard > 1 {
if matches!(args.quantize, Some(Quantization::Exl2)) {
return Err(LauncherError::ArgumentValidation(
"Sharding is currently not supported with `exl2` quantization".into(),
));
}
tracing::info!("Sharding model on {num_shard} processes");
}
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if let Some(max_total_tokens) = max_total_tokens {
if max_total_tokens as u32 > *max_batch_total_tokens {