Attempt at automatic max batch prefill.
This commit is contained in:
parent a5593ba83e
commit 6624cbe7f2
@@ -29,6 +29,26 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter};
 mod env_runtime;
+mod gpu;
+
+fn compute_optimal(config: Option<&Config>, compute: Option<&ComputeType>) -> Option<usize> {
+    if let (Some(config), Some(compute)) = (config, compute) {
+        if let (Some(f16_max_compute), Some(model_compute)) = (compute.f16_flop(), config.flop()) {
+            tracing::debug!("Max compute {f16_max_compute} model compute {model_compute}");
+            let optimal_size = (f16_max_compute / model_compute) as usize;
+            if optimal_size > 100 {
+                // Ignore calculations that are too low
+                // Most likely an error
+                Some(optimal_size)
+            } else {
+                None
+            }
+        } else {
+            None
+        }
+    } else {
+        None
+    }
+}
 
 fn get_config(
     model_id: &str,
     revision: &Option<String>,
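As a back-of-the-envelope illustration (not part of the diff), dividing an L4's hard-coded 121 TFLOPS of f16 compute by a per-token prefill cost on the order of a few GFLOPs lands in the tens of thousands of tokens, and the `> 100` guard only filters out ratios small enough to look like a bogus config. A minimal sketch, with the per-token cost picked purely for illustration:

// Illustrative arithmetic only; the 4 GFLOPs-per-token figure is an assumption,
// not a value read from any real model config.
fn compute_optimal_example() {
    let f16_max_compute: u64 = 121 * 10u64.pow(12); // nvidia-l4 peak f16, as hard-coded above
    let model_compute: u64 = 4 * 10u64.pow(9);      // assumed prefill cost per token
    let optimal_size = (f16_max_compute / model_compute) as usize;
    assert_eq!(optimal_size, 30_250); // suggested max batch prefill tokens
}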
@@ -144,7 +164,10 @@ struct RawConfig {
     quantization_config: Option<QuantizationConfig>,
     n_embd: Option<usize>,
     hidden_size: Option<usize>,
+    intermediate_size: Option<usize>,
     num_attention_heads: Option<usize>,
+    num_key_value_heads: Option<usize>,
+    num_hidden_layers: Option<usize>,
     head_dim: Option<usize>,
     vision_config: Option<VisionConfig>,
     is_encoder_decoder: Option<bool>,
@@ -155,19 +178,42 @@ struct QuantizationConfig {
     quant_method: Option<Quantization>,
 }
 
-#[derive(Deserialize)]
+#[derive(Debug, Deserialize)]
 struct VisionConfig {}
 
-#[derive(Deserialize)]
+#[derive(Debug, Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
     quantize: Option<Quantization>,
     head_dim: Option<usize>,
+    num_heads: Option<usize>,
+    num_kv_heads: Option<usize>,
+    num_layers: Option<usize>,
+    intermediate_size: Option<usize>,
+    hidden_size: Option<usize>,
     model_type: Option<String>,
     vision_config: Option<VisionConfig>,
     is_encoder_decoder: bool,
 }
 
+impl Config {
+    fn flop(&self) -> Option<u64> {
+        let num_heads = self.num_heads? as u64;
+        let num_kv_heads = self.num_kv_heads? as u64;
+        let head_dim = self.head_dim? as u64;
+        let hidden_size = self.hidden_size? as u64;
+        let intermediate_size = self.intermediate_size? as u64;
+        let num_layers = self.num_layers? as u64;
+
+        let attn_flops = 2 * (num_heads + 2 * num_kv_heads) * head_dim * hidden_size;
+        let o_flops = 2 * num_kv_heads * head_dim * hidden_size;
+        let gate_up_down_flops = 2 * 3 * hidden_size * intermediate_size;
+        let layer_flops = attn_flops + o_flops + gate_up_down_flops;
+        let total = layer_flops * num_layers;
+        Some(total)
+    }
+}
+
 impl From<RawConfig> for Config {
     fn from(other: RawConfig) -> Self {
         let max_position_embeddings = other
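For intuition, evaluating `flop()` with Llama-2-7B-style dimensions (the numbers below are assumed for illustration, not read from a real config file) gives roughly 13 GFLOPs per token:

// Worked example with assumed Llama-2-7B-like dimensions.
fn flop_example() {
    let config = Config {
        max_position_embeddings: Some(4096),
        quantize: None,
        head_dim: Some(128),
        num_heads: Some(32),
        num_kv_heads: Some(32),
        num_layers: Some(32),
        intermediate_size: Some(11008),
        hidden_size: Some(4096),
        model_type: None,
        vision_config: None,
        is_encoder_decoder: false,
    };
    // (2*(32 + 2*32)*128*4096 + 2*32*128*4096 + 2*3*4096*11008) * 32 layers
    assert_eq!(config.flop(), Some(12_952_010_752));
}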
@@ -175,22 +221,21 @@ impl From<RawConfig> for Config {
             .or(other.max_seq_len)
             .or(other.n_positions);
         let quantize = other.quantization_config.and_then(|q| q.quant_method);
-        let head_dim = other.head_dim.or_else(|| {
-            match (other.hidden_size, other.n_embd, other.num_attention_heads) {
-                (Some(hidden_size), _, Some(num_attention_heads))
-                    if hidden_size % num_attention_heads == 0 =>
-                {
-                    Some(hidden_size / num_attention_heads)
-                }
-                // Legacy
-                (_, Some(hidden_size), Some(num_attention_heads))
-                    if hidden_size % num_attention_heads == 0 =>
-                {
-                    Some(hidden_size / num_attention_heads)
-                }
-                _ => None,
-            }
-        });
+        let hidden_size = other.hidden_size.or(other.n_embd);
+        let head_dim = other
+            .head_dim
+            .or_else(|| match (hidden_size, other.num_attention_heads) {
+                (Some(hidden_size), Some(num_attention_heads))
+                    if hidden_size % num_attention_heads == 0 =>
+                {
+                    Some(hidden_size / num_attention_heads)
+                }
+                _ => None,
+            });
+        let num_heads = other.num_attention_heads;
+        let num_layers = other.num_hidden_layers;
+        let num_kv_heads = other.num_key_value_heads.or(other.num_attention_heads);
+        let intermediate_size = other.intermediate_size;
         let model_type = other.model_type;
         let vision_config = other.vision_config;
         let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
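The reworked fallback only matters when `head_dim` is absent from the config: a hypothetical model with `hidden_size` 4096 and 32 attention heads derives 4096 / 32 = 128, while a non-divisible pair leaves it as `None`. The same logic in isolation, with assumed values:

// Stand-alone sketch of the head_dim fallback (values are hypothetical).
fn derive_head_dim(
    head_dim: Option<usize>,
    hidden_size: Option<usize>,
    num_attention_heads: Option<usize>,
) -> Option<usize> {
    head_dim.or_else(|| match (hidden_size, num_attention_heads) {
        (Some(h), Some(n)) if h % n == 0 => Some(h / n),
        _ => None,
    })
}

fn head_dim_examples() {
    assert_eq!(derive_head_dim(None, Some(4096), Some(32)), Some(128));
    assert_eq!(derive_head_dim(Some(64), Some(4096), Some(32)), Some(64)); // explicit value wins
    assert_eq!(derive_head_dim(None, Some(4097), Some(32)), None); // not evenly divisible
}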
@@ -201,6 +246,11 @@ impl From<RawConfig> for Config {
             model_type,
             vision_config,
             is_encoder_decoder,
+            hidden_size,
+            num_heads,
+            num_kv_heads,
+            intermediate_size,
+            num_layers,
         }
     }
 }
@@ -1423,7 +1473,32 @@ fn spawn_shards(
     Ok(())
 }
 
-fn compute_type(num_shard: usize) -> Option<String> {
+#[derive(Debug)]
+struct ComputeType {
+    count: usize,
+    card: String,
+}
+
+impl ComputeType {
+    fn f16_flop(&self) -> Option<u64> {
+        match &self.card[..] {
+            // https://www.nvidia.com/en-us/data-center/l4/
+            "nvidia-l4" => Some(121 * 10u64.pow(12)),
+            card => {
+                tracing::warn!("Unknown compute for card {card}");
+                None
+            }
+        }
+    }
+}
+
+impl From<ComputeType> for OsString {
+    fn from(value: ComputeType) -> Self {
+        format!("{}-{}", value.count, value.card).into()
+    }
+}
+
+fn compute_type(num_shard: usize) -> Option<ComputeType> {
     let output = Command::new("nvidia-smi")
         .args(["--query-gpu=gpu_name", "--format=csv"])
         .output()
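Only the L4 is recognized so far; any other card name logs a warning and the heuristic is skipped. A small usage sketch (the "nvidia-t4" string is just an example of an unrecognized card):

// f16_flop knows only "nvidia-l4" at this point; anything else yields None.
fn compute_type_examples() {
    let l4 = ComputeType { count: 1, card: "nvidia-l4".to_string() };
    assert_eq!(l4.f16_flop(), Some(121 * 10u64.pow(12)));

    let t4 = ComputeType { count: 1, card: "nvidia-t4".to_string() };
    assert_eq!(t4.f16_flop(), None); // logs "Unknown compute for card nvidia-t4"

    // The OsString conversion keeps the "<count>-<card>" string the launcher used before.
    let env: std::ffi::OsString = l4.into();
    assert_eq!(env, std::ffi::OsString::from("1-nvidia-l4"));
}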
@@ -1431,8 +1506,10 @@ fn compute_type(num_shard: usize) -> Option<String> {
     let output = String::from_utf8(output.stdout).ok()?;
     let fullname = output.split('\n').nth(1)?;
     let cardname = fullname.replace(' ', "-").to_lowercase();
-    let compute_type = format!("{num_shard}-{cardname}");
-    Some(compute_type)
+    Some(ComputeType {
+        count: num_shard,
+        card: cardname,
+    })
 }
 
 fn spawn_webserver(
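The parsing keeps the old assumption about `nvidia-smi` output: a csv header row followed by one GPU name per line, of which only the first device is used, lowercased with spaces replaced by dashes. A sketch of just the string handling, on an assumed sample output:

// String handling only; assumes a header row then one name per GPU.
fn parse_card(output: &str) -> Option<String> {
    let fullname = output.split('\n').nth(1)?; // skip the csv header row
    Some(fullname.replace(' ', "-").to_lowercase())
}

fn parse_card_example() {
    let sample = "name\nNVIDIA L4\nNVIDIA L4\n"; // assumed two-GPU machine
    assert_eq!(parse_card(sample).as_deref(), Some("nvidia-l4"));
}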
@@ -1682,26 +1759,22 @@ fn main() -> Result<(), LauncherError> {
     let config: Option<Config> = get_config(&args.model_id, &args.revision).ok();
     let quantize = config.as_ref().and_then(|c| c.quantize);
-    // Quantization usually means you're even more RAM constrained.
-    let max_default = 4096;
-
-    let max_position_embeddings = if let Some(config) = &config {
-        if let Some(max_position_embeddings) = config.max_position_embeddings {
-            if max_position_embeddings > max_default {
-                max_default
-            } else {
-                max_position_embeddings
-            }
-        } else {
-            max_default
-        }
-    } else {
-        max_default
-    };
     let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters);
     tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}");
     std::env::set_var("PREFIX_CACHING", prefix_caching);
     std::env::set_var("ATTENTION", attention);
 
+    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
+    if num_shard > 1 {
+        if matches!(args.quantize, Some(Quantization::Exl2)) {
+            return Err(LauncherError::ArgumentValidation(
+                "Sharding is currently not supported with `exl2` quantization".into(),
+            ));
+        }
+        tracing::info!("Sharding model on {num_shard} processes");
+    }
+
     let max_input_tokens = {
         match (args.max_input_tokens, args.max_input_length) {
             (Some(max_input_tokens), Some(max_input_length)) => {
@@ -1721,9 +1794,19 @@ fn main() -> Result<(), LauncherError> {
             Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
             None => {
-                // TODO figure out hardware optimal value
-                let value = 4096.min(max_position_embeddings as u32);
+                let compute_type = compute_type(num_shard);
+                tracing::info!("Compute type {compute_type:?}");
+                tracing::info!("Config {config:?}");
+                let compute_optimal = compute_optimal(config.as_ref(), compute_type.as_ref());
+                let default = compute_optimal.unwrap_or(4096);
+                let max_position_embeddings = config.and_then(|c| c.max_position_embeddings);
+                let value = if let Some(max_position_embeddings) = max_position_embeddings {
+                    default.min(max_position_embeddings)
+                } else {
+                    default
+                };
                 tracing::info!("Default `max_batch_prefill_tokens` to {value}");
-                value
+                value as u32
             }
         }
     };
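Putting the pieces together: when `max_batch_prefill_tokens` is not set, the default becomes `compute_optimal(...)` if both the GPU and the model config are recognized, otherwise 4096, and in either case it is clamped to the model's `max_position_embeddings`. A sketch of that selection with assumed values (the 30_250 figure reuses the illustrative L4 example above):

// Fallback chain for the default value, with assumed example inputs.
fn default_max_batch_prefill_tokens(
    compute_optimal: Option<usize>,         // from compute_optimal(config, compute_type)
    max_position_embeddings: Option<usize>, // from the model config
) -> u32 {
    let default = compute_optimal.unwrap_or(4096);
    let value = match max_position_embeddings {
        Some(max) => default.min(max),
        None => default,
    };
    value as u32
}

fn default_examples() {
    assert_eq!(default_max_batch_prefill_tokens(Some(30_250), Some(4096)), 4096); // clamped to context
    assert_eq!(default_max_batch_prefill_tokens(None, Some(131_072)), 4096);      // unknown GPU or config
    assert_eq!(default_max_batch_prefill_tokens(Some(30_250), None), 30_250);     // no context limit known
}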
@@ -1778,16 +1861,6 @@ fn main() -> Result<(), LauncherError> {
         );
     }
 
-    let num_shard = find_num_shards(args.sharded, args.num_shard)?;
-    if num_shard > 1 {
-        if matches!(args.quantize, Some(Quantization::Exl2)) {
-            return Err(LauncherError::ArgumentValidation(
-                "Sharding is currently not supported with `exl2` quantization".into(),
-            ));
-        }
-        tracing::info!("Sharding model on {num_shard} processes");
-    }
-
     if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
         if let Some(max_total_tokens) = max_total_tokens {
             if max_total_tokens as u32 > *max_batch_total_tokens {