Improve the defaults for the launcher (#1727)
# What does this PR do?

- Renamed `max_input_length` to `max_input_tokens` for consistency (backward compatible change; the launcher will error out if both are set).
- The launcher now uses the model config to default `max_input_tokens`, `max_total_tokens` and `max_batch_prefill_tokens`.
- Caps the default values at 4096 tokens in order to save VRAM on behalf of users (overridable by simply setting the values explicitly).

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @OlivierDehaene OR @Narsil
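Reviewer note: a rough sketch of the defaulting rules described above (illustrative only — names and structure are simplified, the real logic lives in the launcher's `main`, and the `max_batch_size` branch is omitted here):

```rust
/// Simplified restatement of the new defaults (assumed simplification, not the actual launcher code):
/// CLI/env values win; otherwise fall back to the model config, capped at 4096 to conserve VRAM.
fn resolve_defaults(
    cli_max_input_tokens: Option<usize>,
    cli_max_total_tokens: Option<usize>,
    config_max_position_embeddings: Option<usize>,
) -> (usize, usize, u32) {
    // Cap what we take from config.json; users can always override explicitly.
    let max_position_embeddings = config_max_position_embeddings.unwrap_or(4096).min(4096);
    let max_input_tokens = cli_max_input_tokens.unwrap_or(max_position_embeddings - 1);
    let max_total_tokens = cli_max_total_tokens.unwrap_or(max_position_embeddings);
    // Leave a bit of room over the input size for block alignment.
    let max_batch_prefill_tokens = (max_input_tokens + 50) as u32;
    (max_input_tokens, max_total_tokens, max_batch_prefill_tokens)
}

fn main() {
    // A model advertising 32k positions gets capped to the 4k defaults;
    // explicit CLI values always win.
    assert_eq!(resolve_defaults(None, None, Some(32768)), (4095, 4096, 4145));
    assert_eq!(resolve_defaults(Some(8192), Some(16384), Some(32768)), (8192, 16384, 8242));
}
```

Users who want the full context window back can still pass `--max-input-tokens`, `--max-total-tokens` and `--max-batch-prefill-tokens` explicitly, as documented in the updated launcher help below.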
@@ -3452,7 +3452,9 @@ dependencies = [
 "clap",
 "ctrlc",
 "float_eq",
+"hf-hub",
 "nix",
 "once_cell",
 "reqwest",
 "serde",
 "serde_json",
@@ -60,9 +60,9 @@ Options:
 [env: QUANTIZE=]

 Possible values:
- - awq: 4 bit quantization. Requires a specific AWQ quantized model: https://hf.co/models?search=awq. Should replace GPTQ models wherever possible because of the better latency
- - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from https://github.com/NetEase-FuXi/EETQ.git
- - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
+ - awq: 4 bit quantization. Requires a specific AWQ quantized model: <https://hf.co/models?search=awq>. Should replace GPTQ models wherever possible because of the better latency
+ - eetq: 8 bit quantization, doesn't require specific model. Should be a drop-in replacement to bitsandbytes with much better performance. Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
+ - gptq: 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>. text-generation-inference will use exllama (faster) kernels wherever possible, and use triton kernel (wider support) when it's not. AWQ has faster kernels
 - bitsandbytes: Bitsandbytes 8bit. Can be applied on any model, will cut the memory requirement in half, but it is known that the model will be much slower to run than the native f16
 - bitsandbytes-nf4: Bitsandbytes 4bit. Can be applied on any model, will cut the memory requirement by 4x, but it is known that the model will be much slower to run than the native f16
 - bitsandbytes-fp4: Bitsandbytes 4bit. nf4 should be preferred in most cases but maybe this one has better perplexity performance for you model

@@ -129,23 +129,29 @@ Options:
 [env: MAX_TOP_N_TOKENS=]
 [default: 5]

 ```
+## MAX_INPUT_TOKENS
+```shell
+--max-input-tokens <MAX_INPUT_TOKENS>
+ This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_position_embeddings - 1, 4095)
+
+ [env: MAX_INPUT_TOKENS=]
+
+```
 ## MAX_INPUT_LENGTH
 ```shell
 --max-input-length <MAX_INPUT_LENGTH>
- This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle
+ Legacy version of [`Args::max_input_tokens`]

 [env: MAX_INPUT_LENGTH=]
- [default: 1024]

 ```
 ## MAX_TOTAL_TOKENS
 ```shell
 --max-total-tokens <MAX_TOTAL_TOKENS>
- This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be
+ This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_position_embeddings, 4096)

 [env: MAX_TOTAL_TOKENS=]
- [default: 2048]

 ```
 ## WAITING_SERVED_RATIO

@@ -162,10 +168,9 @@ Options:
 ## MAX_BATCH_PREFILL_TOKENS
 ```shell
 --max-batch-prefill-tokens <MAX_BATCH_PREFILL_TOKENS>
- Limits the number of tokens for the prefill operation. Since this operation take the most memory and is compute bound, it is interesting to limit the number of requests that can be sent
+ Limits the number of tokens for the prefill operation. Since this operation take the most memory and is compute bound, it is interesting to limit the number of requests that can be sent. Default to `max_input_tokens + 50` to give a bit of room

 [env: MAX_BATCH_PREFILL_TOKENS=]
- [default: 4096]

 ```
 ## MAX_BATCH_TOTAL_TOKENS

@@ -210,10 +215,9 @@ Options:
 ## CUDA_GRAPHS
 ```shell
 --cuda-graphs <CUDA_GRAPHS>
- Specify the batch sizes to compute cuda graphs for. Use "0" to disable
+ Specify the batch sizes to compute cuda graphs for. Use "0" to disable. Default = "1,2,4,8,16,32"

 [env: CUDA_GRAPHS=]
- [default: 1,2,4,8,16,32,64,96,128]

 ```
 ## HOSTNAME
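The `--max-total-tokens` help text above describes a per-request budget; the arithmetic it implies, as a tiny illustrative check (not the actual router code):

```rust
/// Sketch of the per-request budget implied by `--max-total-tokens`:
/// prompt tokens plus requested new tokens must fit in the budget.
fn fits_budget(prompt_tokens: usize, max_new_tokens: usize, max_total_tokens: usize) -> bool {
    prompt_tokens + max_new_tokens <= max_total_tokens
}

fn main() {
    // The examples from the docs, with a budget of 1512 tokens:
    assert!(fits_budget(1000, 512, 1512));
    assert!(fits_budget(1, 1511, 1512));
    // Anything beyond the budget would be rejected.
    assert!(!fits_budget(1000, 513, 1512));
}
```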
@@ -3,7 +3,7 @@ import pytest

 @pytest.fixture(scope="module")
 def t5_sharded_handle(launcher):
-    with launcher("google/flan-t5-xxl", num_shard=2) as handle:
+    with launcher("google/flan-t5-xxl", num_shard=4) as handle:
         yield handle
@@ -9,7 +9,9 @@ homepage.workspace = true
 [dependencies]
 clap = { version = "4.4.5", features = ["derive", "env"] }
 ctrlc = { version = "3.4.1", features = ["termination"] }
+hf-hub = "0.3.2"
 nix = { version = "0.28.0", features = ["signal"] }
 once_cell = "1.19.0"
 serde = { version = "1.0.188", features = ["derive"] }
 serde_json = "1.0.107"
 tracing = "0.1.37"
@@ -1,4 +1,5 @@
 use clap::{Parser, ValueEnum};
+use hf_hub::{api::sync::Api, Repo, RepoType};
 use nix::sys::signal::{self, Signal};
 use nix::unistd::Pid;
 use serde::Deserialize;

@@ -19,17 +20,23 @@ use tracing_subscriber::EnvFilter;

 mod env_runtime;

+#[derive(Deserialize)]
+struct Config {
+    max_position_embeddings: Option<usize>,
+    max_seq_len: Option<usize>,
+}
+
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum Quantization {
     /// 4 bit quantization. Requires a specific AWQ quantized model:
-    /// https://hf.co/models?search=awq.
+    /// <https://hf.co/models?search=awq>.
     /// Should replace GPTQ models wherever possible because of the better latency
     Awq,
     /// 8 bit quantization, doesn't require specific model.
     /// Should be a drop-in replacement to bitsandbytes with much better performance.
-    /// Kernels are from https://github.com/NetEase-FuXi/EETQ.git
+    /// Kernels are from <https://github.com/NetEase-FuXi/EETQ.git>
     Eetq,
-    /// 4 bit quantization. Requires a specific GTPQ quantized model: https://hf.co/models?search=gptq.
+    /// 4 bit quantization. Requires a specific GTPQ quantized model: <https://hf.co/models?search=gptq>.
     /// text-generation-inference will use exllama (faster) kernels wherever possible, and use
     /// triton kernel (wider support) when it's not.
     /// AWQ has faster kernels.
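The new `Config` struct above is what the launcher deserializes out of a model's `config.json`. A quick standalone sketch of that lookup (the JSON payload is a made-up example, not any particular model's config; `serde`/`serde_json` are already launcher dependencies):

```rust
use serde::Deserialize;

#[derive(Deserialize)]
struct Config {
    max_position_embeddings: Option<usize>,
    max_seq_len: Option<usize>,
}

fn main() -> Result<(), serde_json::Error> {
    // Hypothetical config.json payload; unknown fields are ignored by serde by default.
    let content = r#"{ "model_type": "llama", "max_position_embeddings": 32768 }"#;
    let config: Config = serde_json::from_str(content)?;

    // Same preference order as the launcher: `max_position_embeddings` first,
    // falling back to `max_seq_len` (used by some model configs).
    let max = config.max_position_embeddings.or(config.max_seq_len);
    assert_eq!(max, Some(32768));
    Ok(())
}
```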
@@ -214,8 +221,13 @@ struct Args {
     /// for users. The larger this value, the longer prompt users can send which
     /// can impact the overall memory required to handle the load.
     /// Please note that some models have a finite range of sequence they can handle.
-    #[clap(default_value = "1024", long, env)]
-    max_input_length: usize,
+    /// Default to min(max_position_embeddings - 1, 4095)
+    #[clap(long, env)]
+    max_input_tokens: Option<usize>,
+
+    /// Legacy version of [`Args::max_input_tokens`].
+    #[clap(long, env)]
+    max_input_length: Option<usize>,

     /// This is the most important value to set as it defines the "memory budget"
     /// of running clients requests.

@@ -225,8 +237,9 @@ struct Args {
     /// `1511` max_new_tokens.
     /// The larger this value, the larger amount each request will be in your RAM
     /// and the less effective batching can be.
-    #[clap(default_value = "2048", long, env)]
-    max_total_tokens: usize,
+    /// Default to min(max_position_embeddings, 4096)
+    #[clap(long, env)]
+    max_total_tokens: Option<usize>,

     /// This represents the ratio of waiting queries vs running queries where
     /// you want to start considering pausing the running queries to include the waiting

@@ -244,8 +257,9 @@ struct Args {
     /// Limits the number of tokens for the prefill operation.
     /// Since this operation take the most memory and is compute bound, it is interesting
     /// to limit the number of requests that can be sent.
-    #[clap(default_value = "4096", long, env)]
-    max_batch_prefill_tokens: u32,
+    /// Default to `max_input_tokens + 50` to give a bit of room.
+    #[clap(long, env)]
+    max_batch_prefill_tokens: Option<u32>,

     /// **IMPORTANT** This is one critical control to allow maximum usage
     /// of the available hardware.

@@ -294,13 +308,9 @@ struct Args {

     /// Specify the batch sizes to compute cuda graphs for.
     /// Use "0" to disable.
-    #[clap(
-        long,
-        env,
-        value_delimiter = ',',
-        default_value = "1,2,4,8,16,32,64,96,128"
-    )]
-    cuda_graphs: Vec<usize>,
+    /// Default = "1,2,4,8,16,32"
+    #[clap(long, env, value_delimiter = ',')]
+    cuda_graphs: Option<Vec<usize>>,

     /// The IP address to listen on
     #[clap(default_value = "0.0.0.0", long, env)]
@@ -808,6 +818,14 @@ enum LauncherError {
     WebserverCannotStart,
 }

+impl core::fmt::Display for LauncherError {
+    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
+        write!(f, "{self:?}")
+    }
+}
+
+impl std::error::Error for LauncherError {}
+
 fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), LauncherError> {
     // Enter download tracing span
     let _span = tracing::span!(tracing::Level::INFO, "download").entered();

@@ -944,6 +962,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
 fn spawn_shards(
     num_shard: usize,
     args: &Args,
+    cuda_graphs: Vec<usize>,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
     shutdown_sender: mpsc::Sender<()>,

@@ -971,11 +990,7 @@ fn spawn_shards(
     let disable_custom_kernels = args.disable_custom_kernels;
     let watermark_gamma = args.watermark_gamma;
     let watermark_delta = args.watermark_delta;
-    let cuda_graphs: Vec<usize> = args
-        .cuda_graphs
-        .iter()
-        .filter_map(|&c| if c > 0 { Some(c) } else { None })
-        .collect();
+    let cuda_graphs_clone = cuda_graphs.clone();
     let cuda_memory_fraction = args.cuda_memory_fraction;
     let rope_scaling = args.rope_scaling;
     let rope_factor = args.rope_factor;

@@ -997,7 +1012,7 @@ fn spawn_shards(
             disable_custom_kernels,
             watermark_gamma,
             watermark_delta,
-            cuda_graphs,
+            cuda_graphs_clone,
             cuda_memory_fraction,
             rope_scaling,
             rope_factor,

@@ -1053,6 +1068,9 @@ fn compute_type(num_shard: usize) -> Option<String> {
 fn spawn_webserver(
     num_shard: usize,
     args: Args,
+    max_input_tokens: usize,
+    max_total_tokens: usize,
+    max_batch_prefill_tokens: u32,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
 ) -> Result<Child, LauncherError> {

@@ -1068,12 +1086,12 @@ fn spawn_webserver(
         args.max_stop_sequences.to_string(),
         "--max-top-n-tokens".to_string(),
         args.max_top_n_tokens.to_string(),
-        "--max-input-length".to_string(),
-        args.max_input_length.to_string(),
+        "--max-input-tokens".to_string(),
+        max_input_tokens.to_string(),
         "--max-total-tokens".to_string(),
-        args.max_total_tokens.to_string(),
+        max_total_tokens.to_string(),
         "--max-batch-prefill-tokens".to_string(),
-        args.max_batch_prefill_tokens.to_string(),
+        max_batch_prefill_tokens.to_string(),
         "--waiting-served-ratio".to_string(),
         args.waiting_served_ratio.to_string(),
         "--max-waiting-tokens".to_string(),
@@ -1251,18 +1269,128 @@ fn main() -> Result<(), LauncherError> {

     tracing::info!("{:?}", args);

-    // Validate args
-    if args.max_input_length >= args.max_total_tokens {
-        return Err(LauncherError::ArgumentValidation(
-            "`max_input_length` must be < `max_total_tokens`".to_string(),
-        ));
-    }
-    if args.max_input_length as u32 > args.max_batch_prefill_tokens {
-        return Err(LauncherError::ArgumentValidation(format!(
-            "`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {} and {}",
-            args.max_batch_prefill_tokens, args.max_input_length
-        )));
-    }
+    let get_max_position_embeddings = || -> Result<usize, Box<dyn std::error::Error>> {
+        let model_id = args.model_id.clone();
+        let mut path = std::path::Path::new(&args.model_id).to_path_buf();
+        let filename = if !path.exists() {
+            // Assume it's a hub id
+            let api = Api::new()?;
+            let repo = if let Some(ref revision) = args.revision {
+                api.repo(Repo::with_revision(
+                    model_id,
+                    RepoType::Model,
+                    revision.to_string(),
+                ))
+            } else {
+                api.model(model_id)
+            };
+            repo.get("config.json")?
+        } else {
+            path.push("config.json");
+            path
+        };
+
+        let content = std::fs::read_to_string(filename)?;
+        let config: Config = serde_json::from_str(&content)?;
+
+        // Quantization usually means you're even more RAM constrained.
+        let max_default = 4096;
+
+        let max_position_embeddings = match (config.max_position_embeddings, config.max_seq_len) {
+            (Some(max_position_embeddings), _) | (None, Some(max_position_embeddings)) => {
+                if max_position_embeddings > max_default {
+                    let max = max_position_embeddings;
+                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
+                    max_default
+                } else {
+                    max_position_embeddings
+                }
+            }
+            _ => {
+                return Err(Box::new(LauncherError::ArgumentValidation(
+                    "no max defined".to_string(),
+                )));
+            }
+        };
+        Ok(max_position_embeddings)
+    };
+    let max_position_embeddings: usize = get_max_position_embeddings().unwrap_or(4096);
+
+    let max_input_tokens = {
+        match (args.max_input_tokens, args.max_input_length) {
+            (Some(max_input_tokens), Some(max_input_length)) => {
+                return Err(LauncherError::ArgumentValidation(
+                    format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
+                )));
+            }
+            (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
+            (None, None) => {
+                let value = max_position_embeddings - 1;
+                tracing::info!("Default `max_input_tokens` to {value}");
+                value
+            }
+        }
+    };
+    let max_total_tokens = {
+        match args.max_total_tokens {
+            Some(max_total_tokens) => max_total_tokens,
+            None => {
+                let value = max_position_embeddings;
+                tracing::info!("Default `max_total_tokens` to {value}");
+                value
+            }
+        }
+    };
+    let max_batch_prefill_tokens = {
+        match args.max_batch_prefill_tokens {
+            Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
+            None => {
+                let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
+                    max_batch_size * max_input_tokens
+                } else {
+                    // Adding some edge in order to account for potential block_size alignement
+                    // issue.
+                    max_input_tokens + 50
+                } as u32;
+                tracing::info!("Default `max_batch_prefill_tokens` to {value}");
+                value
+            }
+        }
+    };
+
+    // Validate args
+    if max_input_tokens >= max_total_tokens {
+        return Err(LauncherError::ArgumentValidation(
+            "`max_input_tokens must be < `max_total_tokens`".to_string(),
+        ));
+    }
+    if max_input_tokens as u32 > max_batch_prefill_tokens {
+        return Err(LauncherError::ArgumentValidation(format!(
+            "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
+            max_batch_prefill_tokens, max_input_tokens
+        )));
+    }
+
+    let cuda_graphs = match (&args.cuda_graphs, &args.quantize) {
+        (Some(cuda_graphs), Some(_q)) => cuda_graphs.clone(),
+        #[allow(deprecated)]
+        (
+            None,
+            Some(
+                Quantization::Bitsandbytes
+                | Quantization::BitsandbytesNF4
+                | Quantization::BitsandbytesFP4,
+            ),
+        ) => {
+            tracing::info!("Bitsandbytes doesn't work with cuda graphs, deactivating them");
+            vec![]
+        }
+        _ => {
+            let cuda_graphs = vec![1, 2, 4, 8, 16, 32];
+            tracing::info!("Using default cuda graphs {cuda_graphs:?}");
+            cuda_graphs
+        }
+    };

     if args.validation_workers == 0 {
         return Err(LauncherError::ArgumentValidation(
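The `max_input_tokens` / `max_input_length` reconciliation above boils down to the following pattern (a trimmed-down restatement, using a plain `Result<usize, String>` instead of `LauncherError`):

```rust
/// Trimmed-down restatement of the launcher's flag reconciliation: error when
/// both the new and the legacy flag are set, otherwise take whichever is
/// present, otherwise fall back to the config-derived default.
fn resolve_max_input_tokens(
    max_input_tokens: Option<usize>,  // --max-input-tokens
    max_input_length: Option<usize>,  // legacy --max-input-length
    max_position_embeddings: usize,   // read from config.json, already capped
) -> Result<usize, String> {
    match (max_input_tokens, max_input_length) {
        (Some(new), Some(legacy)) => Err(format!(
            "Both `max_input_tokens` ({new}) and `max_input_length` ({legacy}) are set; \
             please only use `max_input_tokens`"
        )),
        (Some(value), None) | (None, Some(value)) => Ok(value),
        (None, None) => Ok(max_position_embeddings - 1),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn legacy_flag_still_works_but_not_with_both() {
        assert_eq!(resolve_max_input_tokens(None, Some(1024), 4096), Ok(1024));
        assert_eq!(resolve_max_input_tokens(None, None, 4096), Ok(4095));
        assert!(resolve_max_input_tokens(Some(2048), Some(1024), 4096).is_err());
    }
}
```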
@@ -1282,16 +1410,16 @@ fn main() -> Result<(), LauncherError> {
     }

     if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
-        if args.max_batch_prefill_tokens > *max_batch_total_tokens {
+        if max_batch_prefill_tokens > *max_batch_total_tokens {
             return Err(LauncherError::ArgumentValidation(format!(
                 "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-                args.max_batch_prefill_tokens, max_batch_total_tokens
+                max_batch_prefill_tokens, max_batch_total_tokens
             )));
         }
-        if args.max_total_tokens as u32 > *max_batch_total_tokens {
+        if max_total_tokens as u32 > *max_batch_total_tokens {
             return Err(LauncherError::ArgumentValidation(format!(
                 "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-                args.max_total_tokens, max_batch_total_tokens
+                max_total_tokens, max_batch_total_tokens
             )));
         }
     }

@@ -1338,6 +1466,7 @@ fn main() -> Result<(), LauncherError> {
     spawn_shards(
         num_shard,
         &args,
+        cuda_graphs,
         shutdown.clone(),
         &shutdown_receiver,
         shutdown_sender,

@@ -1352,7 +1481,15 @@ fn main() -> Result<(), LauncherError> {
         return Ok(());
     }

-    let mut webserver = spawn_webserver(num_shard, args, shutdown.clone(), &shutdown_receiver)
+    let mut webserver = spawn_webserver(
+        num_shard,
+        args,
+        max_input_tokens,
+        max_total_tokens,
+        max_batch_prefill_tokens,
+        shutdown.clone(),
+        &shutdown_receiver,
+    )
     .map_err(|err| {
         shutdown_shards(shutdown.clone(), &shutdown_receiver);
         err
@@ -35,7 +35,7 @@ struct Args {
     #[clap(default_value = "5", long, env)]
     max_top_n_tokens: u32,
     #[clap(default_value = "1024", long, env)]
-    max_input_length: usize,
+    max_input_tokens: usize,
     #[clap(default_value = "2048", long, env)]
     max_total_tokens: usize,
     #[clap(default_value = "1.2", long, env)]

@@ -90,7 +90,7 @@ async fn main() -> Result<(), RouterError> {
         max_best_of,
         max_stop_sequences,
         max_top_n_tokens,
-        max_input_length,
+        max_input_tokens,
         max_total_tokens,
         waiting_served_ratio,
         max_batch_prefill_tokens,

@@ -118,13 +118,13 @@ async fn main() -> Result<(), RouterError> {
     init_logging(otlp_endpoint, json_output);

     // Validate args
-    if max_input_length >= max_total_tokens {
+    if max_input_tokens >= max_total_tokens {
         return Err(RouterError::ArgumentValidation(
-            "`max_input_length` must be < `max_total_tokens`".to_string(),
+            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
         ));
     }
-    if max_input_length as u32 > max_batch_prefill_tokens {
-        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}")));
+    if max_input_tokens as u32 > max_batch_prefill_tokens {
+        return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
     }

     if validation_workers == 0 {
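The renamed router checks above enforce the same two invariants as the launcher; stated on their own (a sketch, not the router's actual error types):

```rust
/// The two invariants re-checked after the rename (sketch): an input prompt
/// must leave room to generate at least one token, and it must fit into a
/// single prefill batch.
fn check_token_limits(
    max_input_tokens: usize,
    max_total_tokens: usize,
    max_batch_prefill_tokens: u32,
) -> Result<(), String> {
    if max_input_tokens >= max_total_tokens {
        return Err("`max_input_tokens` must be < `max_total_tokens`".to_string());
    }
    if max_input_tokens as u32 > max_batch_prefill_tokens {
        return Err(format!(
            "`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"
        ));
    }
    Ok(())
}
```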
@@ -311,7 +311,7 @@ async fn main() -> Result<(), RouterError> {
     tracing::info!("Warming up model");
     let max_supported_batch_total_tokens = match sharded_client
         .warmup(
-            max_input_length as u32,
+            max_input_tokens as u32,
             max_batch_prefill_tokens,
             max_total_tokens as u32,
             max_batch_size,

@@ -374,7 +374,7 @@ async fn main() -> Result<(), RouterError> {
         max_best_of,
         max_stop_sequences,
         max_top_n_tokens,
-        max_input_length,
+        max_input_tokens,
         max_total_tokens,
         waiting_served_ratio,
         max_batch_prefill_tokens,

@@ -402,12 +402,15 @@ async fn main() -> Result<(), RouterError> {
 /// - otlp_endpoint is an optional URL to an Open Telemetry collector
 /// - LOG_LEVEL may be TRACE, DEBUG, INFO, WARN or ERROR (default to INFO)
 /// - LOG_FORMAT may be TEXT or JSON (default to TEXT)
+/// - LOG_COLORIZE may be "false" or "true" (default to "true" or ansi supported platforms)
 fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
     let mut layers = Vec::new();

     // STDOUT/STDERR layer
+    let ansi = std::env::var("LOG_COLORIZE") != Ok("1".to_string());
     let fmt_layer = tracing_subscriber::fmt::layer()
         .with_file(true)
+        .with_ansi(ansi)
         .with_line_number(true);

     let fmt_layer = match json_output {
@@ -190,12 +190,14 @@ impl State {
         token_budget: u32,
     ) -> Option<NextBatch> {
         if self.entries.is_empty() {
+            tracing::debug!("No queue");
             return None;
         }

         // Check if we have enough entries
         if let Some(min_size) = min_size {
             if self.entries.len() < min_size {
+                tracing::debug!("Not enough entries");
                 return None;
             }
         }

@@ -222,6 +224,7 @@ impl State {
             // was dropped by the client)
             if entry.response_tx.is_closed() {
                 metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
+                tracing::debug!("Dropping entry");
                 continue;
             }

@@ -258,10 +261,12 @@ impl State {
             {
                 // Entry is over budget
                 // Add it back to the front
+                tracing::debug!("Over budget: prefill_tokens={prefill_tokens} > {prefill_token_budget} || {prefill_tokens} + {decode_tokens} + {} > {token_budget}", self.speculate);
                 self.entries.push_front((id, entry));
                 break;
             }

+            tracing::debug!("Accepting entry");
             // Create a new span to link the batch back to this entry
             let entry_batch_span = info_span!(parent: &entry.span, "infer");
             // Add relationships

@@ -292,6 +297,7 @@ impl State {

         // Empty batch
         if batch_requests.is_empty() {
+            tracing::debug!("Filterered out all entries");
             return None;
         }

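The new "Over budget" debug line above spells out the condition the queue uses when pushing an entry back to the front; in isolation (an illustrative restatement, not the queue's actual method):

```rust
/// Restatement of the condition logged by the new "Over budget" debug line:
/// an entry is deferred when its prefill alone exceeds the prefill budget, or
/// when prefill + decode + speculative tokens exceed the overall token budget.
fn over_budget(
    prefill_tokens: u32,
    decode_tokens: u32,
    speculate: u32,
    prefill_token_budget: u32,
    token_budget: u32,
) -> bool {
    prefill_tokens > prefill_token_budget
        || prefill_tokens + decode_tokens + speculate > token_budget
}
```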
@@ -161,14 +161,17 @@ impl Validation {
         } else {
             return Err(ValidationError::UnsetMaxNewTokens);
         };
-        let input_length = truncate.unwrap_or(self.max_input_length);
+        let mut input_length = truncate.unwrap_or(self.max_input_length);

+        // We don't have a tokenizer, therefore we have no idea how long is the query, let
+        // them through and hope for the best.
         // Validate MaxNewTokens
         if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
-            return Err(ValidationError::MaxNewTokens(
-                self.max_total_tokens - self.max_input_length,
-                max_new_tokens,
-            ));
+            input_length = input_length.saturating_sub(max_new_tokens as usize);
+            // return Err(ValidationError::MaxNewTokens(
+            //     self.max_total_tokens - self.max_input_length,
+            //     max_new_tokens,
+            // ));
         }

         Ok((inputs, input_length, max_new_tokens))
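The behaviour change above (truncate the assumed input length instead of rejecting when no tokenizer is available) is easiest to see in isolation; a sketch of the new fallback, with illustrative numbers mirroring the updated test below:

```rust
/// Sketch of the no-tokenizer fallback after this change: instead of rejecting
/// requests whose assumed input length plus `max_new_tokens` exceeds
/// `max_total_tokens`, the assumed input length is shrunk to fit.
fn adjust_input_length(
    assumed_input_length: usize,
    max_new_tokens: u32,
    max_total_tokens: usize,
) -> usize {
    if assumed_input_length as u32 + max_new_tokens > max_total_tokens as u32 {
        assumed_input_length.saturating_sub(max_new_tokens as usize)
    } else {
        assumed_input_length
    }
}

fn main() {
    // With an assumed input length of 5, a budget of 6 total tokens and 10
    // requested new tokens, the assumed input length collapses to 0 rather
    // than erroring out, matching the updated `Ok((_s, 0, 10))` expectation.
    assert_eq!(adjust_input_length(5, 10, 6), 0);
}
```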
@@ -664,8 +667,9 @@ mod tests {
             .validate_input("Hello".to_string(), None, Some(max_new_tokens))
             .await
         {
-            Err(ValidationError::MaxNewTokens(1, 10)) => (),
-            _ => panic!("Unexpected not max new tokens"),
+            // Err(ValidationError::MaxNewTokens(1, 10)) => (),
+            Ok((_s, 0, 10)) => (),
+            r => panic!("Unexpected not max new tokens: {r:?}"),
         }
     }