Easier defaults for models stemmed from configs.

This commit is contained in:
Nicolas Patry 2024-04-11 12:48:39 +00:00
parent 10d9083b2d
commit b83aab9bb3
5 changed files with 153 additions and 44 deletions

2
Cargo.lock generated
View File

@ -3452,7 +3452,9 @@ dependencies = [
"clap", "clap",
"ctrlc", "ctrlc",
"float_eq", "float_eq",
"hf-hub",
"nix", "nix",
"once_cell",
"reqwest", "reqwest",
"serde", "serde",
"serde_json", "serde_json",

View File

@ -60,9 +60,9 @@ Options:
[env: QUANTIZE=] [env: QUANTIZE=]
Possible values: Possible values:
- awq: 4 bit quantization. Requires a specific AWQ quantized model: https://hf.co/models?search=awq. Should replace GPTQ models wherever possible because of the better latency - awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
- eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from https://github.com/NetEase-FuXi/EETQ.git - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
- gptq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
- bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16 - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
- bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model
@ -128,23 +128,29 @@ Options:
[env: MAX_TOP_N_TOKENS=] [env: MAX_TOP_N_TOKENS=]
[default: 5] [default: 5]
```
## MAX_INPUT_TOKENS
```shell
--max-input-tokens <MAX_INPUT_TOKENS>
This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_position_embeddings - 1, 13383)
[env: MAX_INPUT_TOKENS=]
``` ```
## MAX_INPUT_LENGTH ## MAX_INPUT_LENGTH
```shell ```shell
--max-input-length <MAX_INPUT_LENGTH> --max-input-length <MAX_INPUT_LENGTH>
This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle Legacy version of [`Args::max_input_tokens`]
[env: MAX_INPUT_LENGTH=] [env: MAX_INPUT_LENGTH=]
[default: 1024]
``` ```
## MAX_TOTAL_TOKENS ## MAX_TOTAL_TOKENS
```shell ```shell
--max-total-tokens <MAX_TOTAL_TOKENS> --max-total-tokens <MAX_TOTAL_TOKENS>
This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_position_embeddings - 1, 16384)
[env: MAX_TOTAL_TOKENS=] [env: MAX_TOTAL_TOKENS=]
[default: 2048]
``` ```
## WAITING_SERVED_RATIO ## WAITING_SERVED_RATIO
@ -164,7 +170,6 @@ Options:
Limits the number of tokens for the prefill operation. Since this operation take the most memory and is compute bound, it is interesting to limit the number of requests that can be sent Limits the number of tokens for the prefill operation. Since this operation take the most memory and is compute bound, it is interesting to limit the number of requests that can be sent
[env: MAX_BATCH_PREFILL_TOKENS=] [env: MAX_BATCH_PREFILL_TOKENS=]
[default: 4096]
``` ```
## MAX_BATCH_TOTAL_TOKENS ## MAX_BATCH_TOTAL_TOKENS

View File

@ -9,7 +9,9 @@ homepage.workspace = true
[dependencies] [dependencies]
clap = { version = "4.4.5", features = ["derive", "env"] } clap = { version = "4.4.5", features = ["derive", "env"] }
ctrlc = { version = "3.4.1", features = ["termination"] } ctrlc = { version = "3.4.1", features = ["termination"] }
hf-hub = "0.3.2"
nix = { version = "0.28.0", features = ["signal"] } nix = { version = "0.28.0", features = ["signal"] }
once_cell = "1.19.0"
serde = { version = "1.0.188", features = ["derive"] } serde = { version = "1.0.188", features = ["derive"] }
serde_json = "1.0.107" serde_json = "1.0.107"
tracing = "0.1.37" tracing = "0.1.37"

View File

@ -22,14 +22,14 @@ mod env_runtime;
#[derive(Clone, Copy, Debug, ValueEnum)] #[derive(Clone, Copy, Debug, ValueEnum)]
enum Quantization { enum Quantization {
/// 4 bit quantization. Requires a specific AWQ quantized model: /// 4 bit quantization. Requires a specific AWQ quantized model:
/// https://hf.co/models?search=awq. /// <https://hf.co/models?search=awq>.
/// Should replace GPTQ models wherever possible because of the better latency /// Should replace GPTQ models wherever possible because of the better latency
Awq, Awq,
/// 8 bit quantization, doesn't require specific model. /// 8 bit quantization, doesn't require specific model.
/// Should be a drop-in replacement to bitsandbytes with much better performance. /// Should be a drop-in replacement to bitsandbytes with much better performance.
/// Kernels are from https://github.com/NetEase-FuXi/EETQ.git /// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
Eetq, Eetq,
/// 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. /// 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>.
/// text-generation-inference will use exllama (faster) kernels wherever possible, and use /// text-generation-inference will use exllama (faster) kernels wherever possible, and use
/// triton kernel (wider support) when it's not. /// triton kernel (wider support) when it's not.
/// AWQ has faster kernels. /// AWQ has faster kernels.
@ -206,8 +206,13 @@ struct Args {
/// for users. The larger this value, the longer prompt users can send which /// for users. The larger this value, the longer prompt users can send which
/// can impact the overall memory required to handle the load. /// can impact the overall memory required to handle the load.
/// Please note that some models have a finite range of sequence they can handle. /// Please note that some models have a finite range of sequence they can handle.
#[clap(default_value = "1024", long, env)] /// Default to min(max_position_embeddings - 1, 13383)
max_input_length: usize, #[clap(long, env)]
max_input_tokens: Option<usize>,
/// Legacy version of [`Args::max_input_tokens`].
#[clap(long, env)]
max_input_length: Option<usize>,
/// This is the most important value to set as it defines the "memory budget" /// This is the most important value to set as it defines the "memory budget"
/// of running clients requests. /// of running clients requests.
@ -217,8 +222,9 @@ struct Args {
/// `1511` max_new_tokens. /// `1511` max_new_tokens.
/// The larger this value, the larger amount each request will be in your RAM /// The larger this value, the larger amount each request will be in your RAM
/// and the less effective batching can be. /// and the less effective batching can be.
#[clap(default_value = "2048", long, env)] /// Default to min(max_position_embeddings - 1, 16384)
max_total_tokens: usize, #[clap(long, env)]
max_total_tokens: Option<usize>,
/// This represents the ratio of waiting queries vs running queries where /// This represents the ratio of waiting queries vs running queries where
/// you want to start considering pausing the running queries to include the waiting /// you want to start considering pausing the running queries to include the waiting
@ -236,8 +242,8 @@ struct Args {
/// Limits the number of tokens for the prefill operation. /// Limits the number of tokens for the prefill operation.
/// Since this operation take the most memory and is compute bound, it is interesting /// Since this operation take the most memory and is compute bound, it is interesting
/// to limit the number of requests that can be sent. /// to limit the number of requests that can be sent.
#[clap(default_value = "4096", long, env)] #[clap(long, env)]
max_batch_prefill_tokens: u32, max_batch_prefill_tokens: Option<u32>,
/// **IMPORTANT** This is one critical control to allow maximum usage /// **IMPORTANT** This is one critical control to allow maximum usage
/// of the available hardware. /// of the available hardware.
@ -1045,6 +1051,9 @@ fn compute_type(num_shard: usize) -> Option<String> {
fn spawn_webserver( fn spawn_webserver(
num_shard: usize, num_shard: usize,
args: Args, args: Args,
max_input_tokens: usize,
max_total_tokens: usize,
max_batch_prefill_tokens: u32,
shutdown: Arc<AtomicBool>, shutdown: Arc<AtomicBool>,
shutdown_receiver: &mpsc::Receiver<()>, shutdown_receiver: &mpsc::Receiver<()>,
) -> Result<Child, LauncherError> { ) -> Result<Child, LauncherError> {
@ -1060,12 +1069,12 @@ fn spawn_webserver(
args.max_stop_sequences.to_string(), args.max_stop_sequences.to_string(),
"--max-top-n-tokens".to_string(), "--max-top-n-tokens".to_string(),
args.max_top_n_tokens.to_string(), args.max_top_n_tokens.to_string(),
"--max-input-length".to_string(), "--max-input-tokens".to_string(),
args.max_input_length.to_string(), max_input_tokens.to_string(),
"--max-total-tokens".to_string(), "--max-total-tokens".to_string(),
args.max_total_tokens.to_string(), max_total_tokens.to_string(),
"--max-batch-prefill-tokens".to_string(), "--max-batch-prefill-tokens".to_string(),
args.max_batch_prefill_tokens.to_string(), max_batch_prefill_tokens.to_string(),
"--waiting-served-ratio".to_string(), "--waiting-served-ratio".to_string(),
args.waiting_served_ratio.to_string(), args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(), "--max-waiting-tokens".to_string(),
@ -1253,16 +1262,99 @@ fn main() -> Result<(), LauncherError> {
tracing::info!("{:?}", args); tracing::info!("{:?}", args);
// Validate args use hf_hub::{api::sync::Api, Repo, RepoType};
if args.max_input_length >= args.max_total_tokens {
#[derive(Deserialize)]
struct Config {
max_position_embeddings: usize,
}
let config: Config = {
let model_id = args.model_id.clone();
let mut path = std::path::Path::new(&args.model_id).to_path_buf();
let filename = if !path.exists() {
// Assume it's a hub id
let api = Api::new().unwrap();
let repo = if let Some(ref revision) = args.revision {
api.repo(Repo::with_revision(
model_id,
RepoType::Model,
revision.to_string(),
))
} else {
api.model(model_id)
};
repo.get("config.json").unwrap()
} else {
path.push("config.json");
path
};
let content = std::fs::read_to_string(filename).unwrap();
let config: Config = serde_json::from_str(&content).unwrap();
let max_default = 2usize.pow(14);
let max_position_embeddings = if config.max_position_embeddings > max_default {
let max = config.max_position_embeddings;
tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max - 1, max - 1);
max_default
} else {
config.max_position_embeddings
};
Config {
max_position_embeddings,
}
};
let max_input_tokens = {
match (args.max_input_tokens, args.max_input_length) {
(Some(max_input_tokens), Some(max_input_length)) => {
return Err(LauncherError::ArgumentValidation( return Err(LauncherError::ArgumentValidation(
"`max_input_length` must be < `max_total_tokens`".to_string(), format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
)));
}
(Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
(None, None) => {
let value = config.max_position_embeddings - 1;
tracing::info!("Default `max_input_tokens` to {value}");
value
}
}
};
let max_total_tokens = {
match args.max_total_tokens {
Some(max_total_tokens) => max_total_tokens,
None => {
let value = config.max_position_embeddings;
tracing::info!("Default `max_total_tokens` to {value}");
value
}
}
};
let max_batch_prefill_tokens = {
// TODO get config.
match args.max_batch_prefill_tokens {
Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
None => {
let value = config.max_position_embeddings as u32 - 1;
tracing::info!("Default `max_batch_prefill_tokens` to {value}");
value
}
}
};
// Validate args
if max_input_tokens >= max_total_tokens {
return Err(LauncherError::ArgumentValidation(
"`max_input_tokens must be < `max_total_tokens`".to_string(),
)); ));
} }
if args.max_input_length as u32 > args.max_batch_prefill_tokens { if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(LauncherError::ArgumentValidation(format!( return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}", "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
args.max_batch_prefill_tokens, args.max_input_length max_batch_prefill_tokens, max_input_tokens
))); )));
} }
@ -1284,16 +1376,16 @@ fn main() -> Result<(), LauncherError> {
} }
if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if args.max_batch_prefill_tokens > *max_batch_total_tokens { if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!( return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_batch_prefill_tokens, max_batch_total_tokens max_batch_prefill_tokens, max_batch_total_tokens
))); )));
} }
if args.max_total_tokens as u32 > *max_batch_total_tokens { if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!( return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
args.max_total_tokens, max_batch_total_tokens max_total_tokens, max_batch_total_tokens
))); )));
} }
} }
@ -1354,7 +1446,15 @@ fn main() -> Result<(), LauncherError> {
return Ok(()); return Ok(());
} }
let mut webserver = spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver) let mut webserver = spawn_webserver(
num_shard,
args,
max_input_tokens,
max_total_tokens,
max_batch_prefill_tokens,
shutdown.clone(),
&shutdown_receiver,
)
.map_err(|err| { .map_err(|err| {
shutdown_shards(shutdown.clone(), &shutdown_receiver); shutdown_shards(shutdown.clone(), &shutdown_receiver);
err err

View File

@ -35,7 +35,7 @@ struct Args {
#[clap(default_value = "5", long, env)] #[clap(default_value = "5", long, env)]
max_top_n_tokens: u32, max_top_n_tokens: u32,
#[clap(default_value = "1024", long, env)] #[clap(default_value = "1024", long, env)]
max_input_length: usize, max_input_tokens: usize,
#[clap(default_value = "2048", long, env)] #[clap(default_value = "2048", long, env)]
max_total_tokens: usize, max_total_tokens: usize,
#[clap(default_value = "1.2", long, env)] #[clap(default_value = "1.2", long, env)]
@ -90,7 +90,7 @@ async fn main() -> Result<(), RouterError> {
max_best_of, max_best_of,
max_stop_sequences, max_stop_sequences,
max_top_n_tokens, max_top_n_tokens,
max_input_length, max_input_tokens,
max_total_tokens, max_total_tokens,
waiting_served_ratio, waiting_served_ratio,
max_batch_prefill_tokens, max_batch_prefill_tokens,
@ -118,13 +118,13 @@ async fn main() -> Result<(), RouterError> {
init_logging(otlp_endpoint, json_output); init_logging(otlp_endpoint, json_output);
// Validate args // Validate args
if max_input_length >= max_total_tokens { if max_input_tokens >= max_total_tokens {
return Err(RouterError::ArgumentValidation( return Err(RouterError::ArgumentValidation(
"`max_input_length` must be < `max_total_tokens`".to_string(), "`max_input_tokens` must be < `max_total_tokens`".to_string(),
)); ));
} }
if max_input_length as u32 > max_batch_prefill_tokens { if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"))); return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
} }
if validation_workers == 0 { if validation_workers == 0 {
@ -311,7 +311,7 @@ async fn main() -> Result<(), RouterError> {
tracing::info!("Warming up model"); tracing::info!("Warming up model");
let max_supported_batch_total_tokens = match sharded_client let max_supported_batch_total_tokens = match sharded_client
.warmup( .warmup(
max_input_length as u32, max_input_tokens as u32,
max_batch_prefill_tokens, max_batch_prefill_tokens,
max_total_tokens as u32, max_total_tokens as u32,
max_batch_size, max_batch_size,
@ -374,7 +374,7 @@ async fn main() -> Result<(), RouterError> {
max_best_of, max_best_of,
max_stop_sequences, max_stop_sequences,
max_top_n_tokens, max_top_n_tokens,
max_input_length, max_input_tokens,
max_total_tokens, max_total_tokens,
waiting_served_ratio, waiting_served_ratio,
max_batch_prefill_tokens, max_batch_prefill_tokens,