Choosing input/total tokens automatically based on available VRAM? (#2673)

* Choosing input/total tokens automatically based on available VRAM?

* Update doc.

* Remove generated files.

* Trying to fix non chunking targets.

* Attempt #2

* fix.

* QuantLinear is rocm compatible.

* Much simpler logic after the overhead.

* Updating logic + non flash.

* Revert doc text.

* Simple updates.

* Fix integration mt0 (transformers update).
Nicolas Patry 2024-10-28 04:59:49 +01:00 committed by GitHub
parent 2e4f4ba1bb
commit 0c9b6cdd76
14 changed files with 285 additions and 136 deletions

.gitignore (vendored): 2 changes

@@ -5,6 +5,8 @@ router/tokenizer.json
 backends/v2/src/client/pb
 backends/v3/src/client/pb
+backends/client/src/v2/pb
+backends/client/src/v3/pb
 # ROCm auto-generated files
 *.hip


@@ -107,20 +107,22 @@ impl Client {
     #[instrument(skip_all)]
     pub async fn warmup(
         &mut self,
-        max_input_length: u32,
+        max_input_tokens: Option<u32>,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
+        max_total_tokens: Option<u32>,
         max_batch_size: Option<usize>,
-    ) -> Result<Option<u32>> {
+    ) -> Result<(Option<u32>, u32, u32)> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
         // Create requests
         while n_tokens < max_prefill_tokens {
-            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
+            let mut truncate = max_prefill_tokens - n_tokens;
+            if let Some(max_input_tokens) = max_input_tokens {
+                truncate = min(max_input_tokens, truncate);
+            }
             let mut input_chunks = Vec::new();
-            input_chunks
-                .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
+            input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
             if n_tokens == 0 {
                 input_chunks.push(
                     Chunk::Image(Image {
@@ -136,7 +138,7 @@ impl Client {
             // been updated to support chunks.
             let mut inputs = String::new();
-            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
+            inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
             if n_tokens == 0 {
                 // 1 request is enough to test vision heads.
                 // Sending images on other queries messes up easily with truncation.
@@ -145,6 +147,12 @@ impl Client {
                 ));
             }
+            let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
+                max_total_tokens - truncate
+            } else {
+                1
+            };
             requests.push(Request {
                 id: 0,
                 inputs,
@@ -175,7 +183,7 @@ impl Client {
                     grammar_type: GrammarType::None as i32,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: max_total_tokens - truncate,
+                    max_new_tokens,
                     stop_sequences: vec![],
                     ignore_eos_token: true,
                 }),
@@ -183,7 +191,7 @@ impl Client {
                 top_n_tokens: 20,
                 adapter_id: None,
             });
-            n_tokens += max_input_length;
+            n_tokens += truncate;
             // Check max_batch_size
             if Some(requests.len()) == max_batch_size {
@@ -195,19 +203,23 @@ impl Client {
             id: 0,
             size: requests.len() as u32,
             requests,
-            max_tokens: max_input_length,
+            max_tokens: max_input_tokens.unwrap_or(0),
             max_blocks: 0,
         };
         let request = tonic::Request::new(WarmupRequest {
             batch: Some(batch),
-            max_input_length,
+            max_input_tokens,
             max_prefill_tokens,
             max_total_tokens,
         })
         .inject_context();
         let response = self.stub.warmup(request).await?.into_inner();
-        Ok(response.max_supported_total_tokens)
+        Ok((
+            response.max_supported_total_tokens,
+            response.max_input_tokens,
+            response.max_total_tokens,
+        ))
     }
     /// Generate one token for each request in the given batch


@@ -101,11 +101,11 @@ impl ShardedClient {
     #[instrument(skip(self))]
     pub async fn warmup(
         &mut self,
-        max_input_length: u32,
+        max_input_length: Option<u32>,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
+        max_total_tokens: Option<u32>,
         max_batch_size: Option<usize>,
-    ) -> Result<Option<u32>> {
+    ) -> Result<(Option<u32>, u32, u32)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
@@ -122,8 +122,16 @@ impl ShardedClient {
         let results = join_all(futures)
             .await
             .into_iter()
-            .collect::<Result<Vec<Option<u32>>>>()?;
-        Ok(results.into_iter().flatten().min())
+            .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
+
+        // Take the minimum value
+        // Different shards hold different parts of vocab, might yield
+        // different available block size.
+        let min = results
+            .iter()
+            .min()
+            .expect("Expect at least 1 warmup result");
+        Ok(*min)
     }
     /// Generate one token for each request in the given batch


@@ -108,20 +108,22 @@ impl Client {
     #[instrument(skip_all)]
     pub async fn warmup(
         &mut self,
-        max_input_length: u32,
+        max_input_tokens: Option<u32>,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
+        max_total_tokens: Option<u32>,
         max_batch_size: Option<usize>,
-    ) -> Result<Option<u32>> {
+    ) -> Result<(Option<u32>, u32, u32)> {
         let mut n_tokens = 0;
         let mut requests = Vec::new();
         // Create requests
         while n_tokens < max_prefill_tokens {
-            let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
+            let mut truncate = max_prefill_tokens - n_tokens;
+            if let Some(max_input_tokens) = max_input_tokens {
+                truncate = min(max_input_tokens, truncate);
+            }
             let mut input_chunks = Vec::new();
-            input_chunks
-                .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into());
+            input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into());
             if n_tokens == 0 {
                 input_chunks.push(
                     Chunk::Image(Image {
@@ -137,7 +139,7 @@ impl Client {
             // been updated to support chunks.
             let mut inputs = String::new();
-            inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
+            inputs.push_str(&"_test ".to_string().repeat(truncate as usize));
             if n_tokens == 0 {
                 // 1 request is enough to test vision heads.
                 // Sending images on other queries messes up easily with truncation.
@@ -146,6 +148,12 @@ impl Client {
                 ));
             }
+            let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens {
+                max_total_tokens - truncate
+            } else {
+                1
+            };
             requests.push(Request {
                 id: 0,
                 inputs,
@@ -175,7 +183,7 @@ impl Client {
                     grammar_type: GrammarType::None as i32,
                 }),
                 stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: max_total_tokens - truncate,
+                    max_new_tokens,
                     stop_sequences: vec![],
                     ignore_eos_token: true,
                 }),
@@ -183,7 +191,7 @@ impl Client {
                 top_n_tokens: 20,
                 adapter_id: None,
             });
-            n_tokens += max_input_length;
+            n_tokens += truncate;
             // Check max_batch_size
             if Some(requests.len()) == max_batch_size {
@@ -195,19 +203,23 @@ impl Client {
             id: 0,
             size: requests.len() as u32,
             requests,
-            max_tokens: max_input_length,
+            max_tokens: max_input_tokens.unwrap_or(0),
             max_blocks: 0,
         };
         let request = tonic::Request::new(WarmupRequest {
             batch: Some(batch),
-            max_input_length,
+            max_input_tokens,
             max_prefill_tokens,
             max_total_tokens,
         })
         .inject_context();
         let response = self.stub.warmup(request).await?.into_inner();
-        Ok(response.max_supported_total_tokens)
+        Ok((
+            response.max_supported_total_tokens,
+            response.max_input_tokens,
+            response.max_total_tokens,
+        ))
     }
     /// Generate one token for each request in the given batch


@@ -102,11 +102,11 @@ impl ShardedClient {
     #[instrument(skip(self))]
     pub async fn warmup(
         &mut self,
-        max_input_length: u32,
+        max_input_length: Option<u32>,
         max_prefill_tokens: u32,
-        max_total_tokens: u32,
+        max_total_tokens: Option<u32>,
         max_batch_size: Option<usize>,
-    ) -> Result<Option<u32>> {
+    ) -> Result<(Option<u32>, u32, u32)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
@@ -119,12 +119,19 @@ impl ShardedClient {
                 ))
             })
             .collect();
-        // Take the minimum value
         let results = join_all(futures)
             .await
             .into_iter()
-            .collect::<Result<Vec<Option<u32>>>>()?;
-        Ok(results.into_iter().flatten().min())
+            .collect::<Result<Vec<(Option<u32>, u32, u32)>>>()?;
+
+        // Take the minimum value
+        // Different shards hold different parts of vocab, might yield
+        // different available block size.
+        let min = results
+            .iter()
+            .min()
+            .expect("Expect at least 1 warmup result");
+        Ok(*min)
     }
     /// Generate one token for each request in the given batch


@@ -37,12 +37,17 @@ pub struct BackendInfo {
     pub attention_impl: String,
     #[schema(example = "1")]
     pub block_size: u32,
+    #[schema(example = "30000")]
+    pub max_input_tokens: usize,
+    #[schema(example = "32000")]
+    pub max_total_tokens: usize,
 }
 #[allow(clippy::too_many_arguments)]
 pub async fn connect_backend(
-    max_input_tokens: usize,
-    max_total_tokens: usize,
+    max_input_tokens: Option<usize>,
+    max_total_tokens: Option<usize>,
     master_shard_uds_path: String,
     waiting_served_ratio: f32,
     max_batch_prefill_tokens: u32,
@@ -51,14 +56,32 @@ pub async fn connect_backend(
     max_batch_size: Option<usize>,
 ) -> Result<(BackendV3, BackendInfo), V3Error> {
     // Helper function
-    let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option<u32>| {
+    let check_max_batch_total_tokens = |(
+        max_supported_batch_total_tokens,
+        shard_max_input_tokens,
+        shard_max_total_tokens,
+    ): (Option<u32>, u32, u32)|
+     -> Result<(u32, usize, usize), V3Error> {
+        if let Some(max_input_tokens) = max_input_tokens {
+            assert_eq!(max_input_tokens as u32, shard_max_input_tokens);
+        }
+        if let Some(max_total_tokens) = max_total_tokens {
+            assert_eq!(max_total_tokens as u32, shard_max_total_tokens);
+        }
         match max_supported_batch_total_tokens {
             // Older models do not support automatic max-batch-total-tokens
             None => {
-                let max_batch_total_tokens = max_batch_total_tokens
-                    .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)));
+                let max_batch_total_tokens = max_batch_total_tokens.unwrap_or(
+                    16000
+                        .max(shard_max_total_tokens)
+                        .max(max_batch_prefill_tokens),
+                );
                 tracing::warn!("Model does not support automatic max batch total tokens");
-                Ok(max_batch_total_tokens)
+                Ok((
+                    max_batch_total_tokens,
+                    shard_max_input_tokens as usize,
+                    shard_max_total_tokens as usize,
+                ))
             }
             // Flash attention models return their max supported total tokens
             Some(max_supported_batch_total_tokens) => {
@@ -72,11 +95,15 @@ pub async fn connect_backend(
                         "Inferred max batch total tokens: {max_supported_batch_total_tokens}"
                     );
                 }
-                if max_total_tokens as u32 > max_supported_batch_total_tokens {
-                    return Err(V3Error::NotEnoughMemory(max_total_tokens));
+                if shard_max_total_tokens > max_supported_batch_total_tokens {
+                    return Err(V3Error::NotEnoughMemory(shard_max_total_tokens as usize));
                 }
-                Ok(max_supported_batch_total_tokens)
+                Ok((
+                    max_supported_batch_total_tokens,
+                    shard_max_input_tokens as usize,
+                    shard_max_total_tokens as usize,
+                ))
             }
         }
     };
@@ -96,23 +123,25 @@ pub async fn connect_backend(
     // Warmup model
     tracing::info!("Warming up model");
-    let max_batch_total_tokens = check_max_batch_total_tokens(
-        sharded_client
-            .warmup(
-                max_input_tokens as u32,
-                max_batch_prefill_tokens,
-                max_total_tokens as u32,
-                max_batch_size,
-            )
-            .await
-            .map_err(V3Error::Warmup)?,
-    )?;
+    let answer = sharded_client
+        .warmup(
+            max_input_tokens.map(|p| p as u32),
+            max_batch_prefill_tokens,
+            max_total_tokens.map(|p| p as u32),
+            max_batch_size,
+        )
+        .await
+        .map_err(V3Error::Warmup)?;
+
+    let (max_batch_total_tokens, max_input_tokens, max_total_tokens) =
+        check_max_batch_total_tokens(answer)?;
     tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}");
     metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens);
     let backend_info = BackendInfo {
         waiting_served_ratio,
         max_batch_total_tokens,
+        max_input_tokens,
+        max_total_tokens,
         max_waiting_tokens,
         max_batch_size,
         model_device_type: shard_info.device_type.clone(),


@@ -18,10 +18,10 @@ struct Args {
     max_stop_sequences: usize,
     #[clap(default_value = "5", long, env)]
     max_top_n_tokens: u32,
-    #[clap(default_value = "1024", long, env)]
-    max_input_tokens: usize,
-    #[clap(default_value = "2048", long, env)]
-    max_total_tokens: usize,
+    #[clap(long, env)]
+    max_input_tokens: Option<usize>,
+    #[clap(long, env)]
+    max_total_tokens: Option<usize>,
     #[clap(default_value = "1.2", long, env)]
     waiting_served_ratio: f32,
     #[clap(default_value = "4096", long, env)]
@@ -126,12 +126,6 @@ async fn main() -> Result<(), RouterError> {
     text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output);
     // Validate args
-    if max_input_tokens >= max_total_tokens {
-        return Err(RouterError::ArgumentValidation(
-            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
-        ));
-    }
     if validation_workers == 0 {
         return Err(RouterError::ArgumentValidation(
             "`validation_workers` must be > 0".to_string(),
@@ -160,6 +154,28 @@ async fn main() -> Result<(), RouterError> {
     // Validate remaining args now that the backend is known
     let support_chunking = backend_info.support_chunking;
     let max_batch_total_tokens = backend_info.max_batch_total_tokens;
+    if max_input_tokens.is_none() {
+        tracing::info!(
+            "Maximum input tokens defaulted to {}",
+            backend_info.max_input_tokens
+        );
+    }
+    if max_total_tokens.is_none() {
+        tracing::info!(
+            "Maximum total tokens defaulted to {}",
+            backend_info.max_total_tokens
+        );
+    }
+
+    let max_input_tokens = backend_info.max_input_tokens;
+    let max_total_tokens = backend_info.max_total_tokens;
+
+    if max_input_tokens >= max_total_tokens {
+        return Err(RouterError::ArgumentValidation(
+            "`max_input_tokens` must be < `max_total_tokens`".to_string(),
+        ));
+    }
     if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking {
         return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}")));
     }


@@ -146,7 +146,7 @@ Options:
 ## MAX_INPUT_TOKENS
 ```shell
       --max-input-tokens <MAX_INPUT_TOKENS>
-          This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_position_embeddings - 1, 4095)
+          This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_allocatable, max_position_embeddings) - 1
 
           [env: MAX_INPUT_TOKENS=]
@@ -162,7 +162,7 @@ Options:
 ## MAX_TOTAL_TOKENS
 ```shell
       --max-total-tokens <MAX_TOTAL_TOKENS>
-          This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_position_embeddings, 4096)
+          This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_allocatable, max_position_embeddings)
 
           [env: MAX_TOTAL_TOKENS=]


@@ -472,7 +472,7 @@ struct Args {
     /// for users. The larger this value, the longer prompt users can send which
     /// can impact the overall memory required to handle the load.
     /// Please note that some models have a finite range of sequence they can handle.
-    /// Default to min(max_position_embeddings - 1, 4095)
+    /// Default to min(max_allocatable, max_position_embeddings) - 1
     #[clap(long, env)]
     max_input_tokens: Option<usize>,
@@ -488,7 +488,7 @@ struct Args {
     /// `1511` max_new_tokens.
     /// The larger this value, the larger amount each request will be in your RAM
     /// and the less effective batching can be.
-    /// Default to min(max_position_embeddings, 4096)
+    /// Default to min(max_allocatable, max_position_embeddings)
     #[clap(long, env)]
     max_total_tokens: Option<usize>,
@@ -718,9 +718,9 @@ fn shard_manager(
     cuda_memory_fraction: f32,
     rope_scaling: Option<RopeScaling>,
     rope_factor: Option<f32>,
-    max_total_tokens: usize,
+    max_total_tokens: Option<usize>,
     max_batch_size: Option<usize>,
-    max_input_tokens: usize,
+    max_input_tokens: Option<usize>,
     lora_adapters: Option<String>,
     otlp_endpoint: Option<String>,
     otlp_service_name: String,
@@ -805,8 +805,10 @@ fn shard_manager(
     shard_args.push(otlp_service_name);
     // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter.
-    shard_args.push("--max-input-tokens".to_string());
-    shard_args.push(max_input_tokens.to_string());
+    if let Some(max_input_tokens) = max_input_tokens {
+        shard_args.push("--max-input-tokens".to_string());
+        shard_args.push(max_input_tokens.to_string());
+    }
     // Copy current process env
     let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect();
@@ -854,10 +856,12 @@ fn shard_manager(
         envs.push(("ROPE_FACTOR".into(), factor.to_string().into()));
     }
-    envs.push((
-        "MAX_TOTAL_TOKENS".into(),
-        max_total_tokens.to_string().into(),
-    ));
+    if let Some(max_total_tokens) = max_total_tokens {
+        envs.push((
+            "MAX_TOTAL_TOKENS".into(),
+            max_total_tokens.to_string().into(),
+        ));
+    }
     if let Some(max_batch_size) = max_batch_size {
         envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into()));
     }
@@ -1315,8 +1319,8 @@ fn spawn_shards(
     num_shard: usize,
     args: &Args,
     cuda_graphs: Vec<usize>,
-    max_total_tokens: usize,
-    max_input_tokens: usize,
+    max_total_tokens: Option<usize>,
+    max_input_tokens: Option<usize>,
     quantize: Option<Quantization>,
     max_log_level: LevelFilter,
     shutdown: Arc<AtomicBool>,
@@ -1434,8 +1438,8 @@ fn compute_type(num_shard: usize) -> Option<String> {
 fn spawn_webserver(
     num_shard: usize,
     args: Args,
-    max_input_tokens: usize,
-    max_total_tokens: usize,
+    max_input_tokens: Option<usize>,
+    max_total_tokens: Option<usize>,
     max_batch_prefill_tokens: u32,
     shutdown: Arc<AtomicBool>,
     shutdown_receiver: &mpsc::Receiver<()>,
@@ -1454,10 +1458,6 @@ fn spawn_webserver(
         args.max_stop_sequences.to_string(),
         "--max-top-n-tokens".to_string(),
         args.max_top_n_tokens.to_string(),
-        "--max-input-tokens".to_string(),
-        max_input_tokens.to_string(),
-        "--max-total-tokens".to_string(),
-        max_total_tokens.to_string(),
         "--max-batch-prefill-tokens".to_string(),
         max_batch_prefill_tokens.to_string(),
         "--waiting-served-ratio".to_string(),
@@ -1475,6 +1475,18 @@ fn spawn_webserver(
         "--tokenizer-name".to_string(),
         args.model_id,
     ];
+    if let Some(max_input_tokens) = max_input_tokens {
+        router_args.extend_from_slice(&[
+            "--max-input-tokens".to_string(),
+            max_input_tokens.to_string(),
+        ]);
+    }
+    if let Some(max_total_tokens) = max_total_tokens {
+        router_args.extend_from_slice(&[
+            "--max-total-tokens".to_string(),
+            max_total_tokens.to_string(),
+        ]);
+    }
     // Pass usage stats flags to router
     router_args.push("--usage-stats".to_string());
@@ -1704,35 +1716,19 @@ fn main() -> Result<(), LauncherError> {
                 format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.",
             )));
             }
-            (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens,
-            (None, None) => {
-                let value = max_position_embeddings - 1;
-                tracing::info!("Default `max_input_tokens` to {value}");
-                value
-            }
-        }
-    };
-    let max_total_tokens = {
-        match args.max_total_tokens {
-            Some(max_total_tokens) => max_total_tokens,
-            None => {
-                let value = max_position_embeddings;
-                tracing::info!("Default `max_total_tokens` to {value}");
-                value
-            }
+            (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => {
+                Some(max_input_tokens)
+            }
+            (None, None) => None,
         }
     };
+    let max_total_tokens = args.max_total_tokens;
     let max_batch_prefill_tokens = {
         match args.max_batch_prefill_tokens {
             Some(max_batch_prefill_tokens) => max_batch_prefill_tokens,
             None => {
-                let value: u32 = if let Some(max_batch_size) = args.max_batch_size {
-                    max_batch_size * max_input_tokens
-                } else {
-                    // Adding some edge in order to account for potential block_size alignement
-                    // issue.
-                    max_input_tokens + 50
-                } as u32;
+                // TODO figure out hardware optimal value
+                let value = 4096.min(max_position_embeddings as u32);
                 tracing::info!("Default `max_batch_prefill_tokens` to {value}");
                 value
             }
@@ -1740,10 +1736,12 @@ fn main() -> Result<(), LauncherError> {
     };
     // Validate args
-    if max_input_tokens >= max_total_tokens {
-        return Err(LauncherError::ArgumentValidation(
-            "`max_input_tokens must be < `max_total_tokens`".to_string(),
-        ));
+    if let (Some(max_input_tokens), Some(max_total_tokens)) = (max_input_tokens, max_total_tokens) {
+        if max_input_tokens >= max_total_tokens {
+            return Err(LauncherError::ArgumentValidation(
+                format!("`max_input_tokens`({max_input_tokens}) must be < `max_total_tokens`({max_total_tokens})"),
+            ));
+        }
     }
     if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
@@ -1798,11 +1796,13 @@ fn main() -> Result<(), LauncherError> {
     }
     if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
-        if max_total_tokens as u32 > *max_batch_total_tokens {
-            return Err(LauncherError::ArgumentValidation(format!(
-                "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
-                max_total_tokens, max_batch_total_tokens
-            )));
+        if let Some(max_total_tokens) = max_total_tokens {
+            if max_total_tokens as u32 > *max_batch_total_tokens {
+                return Err(LauncherError::ArgumentValidation(format!(
+                    "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
+                    max_total_tokens, max_batch_total_tokens
+                )));
+            }
         }
     }


@@ -272,12 +272,18 @@ message DecodeResponse {
 message WarmupRequest {
     /// Batch to warmup on
     Batch batch = 1;
-    uint32 max_input_length = 2;
+    optional uint32 max_input_tokens = 2;
     uint32 max_prefill_tokens = 3;
-    uint32 max_total_tokens = 4;
+    optional uint32 max_total_tokens = 4;
 }
 message WarmupResponse {
     /// Maximum number of tokens supported by the model
     optional uint32 max_supported_total_tokens = 1;
+    /// Maximum input tokens by clients should be equal to request value if it's set
+    /// Otherwise warmup automatically allocates a value here
+    uint32 max_input_tokens = 2;
+    /// Maximum total tokens by clients should be equal to request value if it's set
+    /// Otherwise warmup automatically allocates a value here
+    uint32 max_total_tokens = 3;
 }
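
For context, marking the request fields `optional` gives them explicit field presence in proto3, which is what lets the server tell "the client chose 0" apart from "the client chose nothing". Below is a hedged Python sketch of a caller filling the request; the real caller is the Rust router shown earlier, and the generated-module path and `Batch` construction here are assumptions for illustration.

```python
# Illustrative only: building a WarmupRequest against the proto above.
# Assumes stubs were generated into text_generation_server.pb (path may differ).
from typing import Optional

from text_generation_server.pb import generate_pb2


def build_warmup_request(
    batch: generate_pb2.Batch,
    max_prefill_tokens: int,
    max_input_tokens: Optional[int] = None,
    max_total_tokens: Optional[int] = None,
) -> generate_pb2.WarmupRequest:
    req = generate_pb2.WarmupRequest(
        batch=batch,
        max_prefill_tokens=max_prefill_tokens,
    )
    # Only set the optional fields when the caller picked a value, so that
    # request.HasField("max_input_tokens") is False on the server side and
    # warmup stays free to derive the limits from available VRAM.
    if max_input_tokens is not None:
        req.max_input_tokens = max_input_tokens
    if max_total_tokens is not None:
        req.max_total_tokens = max_total_tokens
    return req
```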


@ -86,6 +86,10 @@ tracer = trace.get_tracer(__name__)
SLIDING_WINDOW: Optional[int] = None SLIDING_WINDOW: Optional[int] = None
def small_power_of_2(n: int):
return 1 << ((n - 1).bit_length() - 1)
def set_sliding_window(sliding_window: int): def set_sliding_window(sliding_window: int):
global SLIDING_WINDOW global SLIDING_WINDOW
SLIDING_WINDOW = sliding_window SLIDING_WINDOW = sliding_window
@ -1495,11 +1499,22 @@ class FlashCausalLM(Model):
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize() torch.cuda.synchronize()
def warmup(self, batch: FlashCausalLMBatch): def warmup(
self,
batch: FlashCausalLMBatch,
max_input_tokens: Optional[int],
max_total_tokens: Optional[int],
):
# The warmup batch is the biggest batch we could ever receive # The warmup batch is the biggest batch we could ever receive
self.kv_cache = [] self.kv_cache = []
empty_cache() empty_cache()
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
try: try:
self.init_kv_cache( self.init_kv_cache(
batch.num_blocks, batch.num_blocks,
@ -1511,10 +1526,11 @@ class FlashCausalLM(Model):
) )
max_bt = batch.max_blocks max_bt = batch.max_blocks
max_s = max_bt * BLOCK_SIZE max_s = max_bt * BLOCK_SIZE
batch_num_blocks = batch.num_blocks
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False): if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
torch.cuda.tunable.tuning_enable(False) torch.cuda.tunable.tuning_enable(False)
_, batch, _ = self.generate_token(batch) _, _batch, _ = self.generate_token(batch)
except torch.cuda.OutOfMemoryError as e: except torch.cuda.OutOfMemoryError as e:
raise RuntimeError( raise RuntimeError(
f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. " f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. "
@ -1523,14 +1539,7 @@ class FlashCausalLM(Model):
synchronize(self.device) synchronize(self.device)
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
free_memory = get_free_memory(self.device, MEMORY_FRACTION) free_memory = get_free_memory(self.device, MEMORY_FRACTION)
batch_num_blocks = batch.num_blocks if batch is not None else 0
num_blocks = ( num_blocks = (
# Leave 5% for some wiggle room # Leave 5% for some wiggle room
@ -1540,8 +1549,27 @@ class FlashCausalLM(Model):
) )
log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}") log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
if max_total_tokens is None:
if get_support_chunking():
model_max_length = self.tokenizer.model_max_length
max_input_tokens = (
min((num_blocks * BLOCK_SIZE - 1), model_max_length)
if max_input_tokens is None
else max_input_tokens
)
max_total_tokens = num_blocks * BLOCK_SIZE
del batch else:
max_total_tokens = sum(batch.cache_lengths)
max_input_tokens = (
max_total_tokens - 1
if max_input_tokens is None
else max_input_tokens
)
del _batch, batch
self.kv_cache = []
empty_cache()
self.init_kv_cache( self.init_kv_cache(
num_blocks, num_blocks,
@ -1623,7 +1651,9 @@ class FlashCausalLM(Model):
logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})." logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
) )
return int(num_blocks * BLOCK_SIZE) assert max_input_tokens is not None
assert max_total_tokens is not None
return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
def tunableop_warmup(self, seqlen: int): def tunableop_warmup(self, seqlen: int):
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
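
The sizing step above is the heart of the change: warmup runs one prefill, measures the VRAM still free, converts it into KV-cache blocks, and uses `num_blocks * BLOCK_SIZE` as the default `max_total_tokens` when the user did not pin one. Here is a minimal, hedged sketch of that arithmetic; the names mirror the hunk (`BLOCK_SIZE`, `num_kv_heads`, `head_size`, `num_layers`, `free_memory`), but the exact rounding lies outside the shown context lines, so treat it as an approximation rather than the repository's literal formula.

```python
# Hedged sketch of the VRAM-based sizing done during warmup (illustrative only).
def estimate_kv_cache_budget(
    free_memory_bytes: int,   # free VRAM measured after the warmup prefill
    batch_num_blocks: int,    # blocks already allocated for the warmup batch
    num_layers: int,
    num_kv_heads: int,
    head_size: int,
    dtype_size: int = 2,      # bytes per element, e.g. float16/bfloat16
    block_size: int = 16,     # BLOCK_SIZE in the hunk above
) -> tuple:
    # One block stores `block_size` tokens of K and V for every layer.
    cache_block_size = block_size * num_kv_heads * head_size
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size
    # "Leave 5% for some wiggle room", then add back the warmup batch's blocks.
    num_blocks = int(free_memory_bytes * 0.95) // total_cache_size + batch_num_blocks
    # With chunking support, max_total_tokens defaults to the full block budget.
    max_total_tokens = num_blocks * block_size
    return num_blocks, max_total_tokens

# Example: ~20 GiB free on a 32-layer model with 8 KV heads of head size 128.
print(estimate_kv_cache_budget(20 * 1024**3, 512, 32, 8, 128))
```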


@@ -1,7 +1,7 @@
 import torch
 import torch.distributed
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
-from typing import Optional
+from typing import Optional, Union
 from text_generation_server.models.custom_modeling.mamba_modeling import (
     MambaConfig,
 )
@@ -475,7 +475,9 @@ class Mamba(Model):
     def batch_type(self) -> Type[MambaBatch]:
         return MambaBatch
 
-    def warmup(self, batch) -> Optional[int]:
+    def warmup(
+        self, batch, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
+    ) -> Union[Optional[int], Optional[int], Optional[int]]:
         # TODO: implement warmup for Mamba if needed
         if CUDA_GRAPHS:
             if self.speculate is None or self.speculate == 0:
@@ -489,7 +491,12 @@ class Mamba(Model):
         else:
             logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
 
-        return None
+        if max_total_tokens is None:
+            max_total_tokens = min(self.tokenizer.model_max_length, 4096)
+        if max_input_tokens is None:
+            max_input_tokens = max_total_tokens - 1
+        return None, max_input_tokens, max_total_tokens
 
     def cuda_graph_warmup(self, batch_size: int):
         input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)


@@ -128,9 +128,17 @@ class Model(ABC):
     ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
         raise NotImplementedError
 
-    def warmup(self, batch: B) -> Optional[int]:
+    def warmup(
+        self, batch: B, max_input_tokens: Optional[int], max_total_tokens: Optional[int]
+    ) -> Tuple[Optional[int], int, int]:
         self.generate_token(batch)
-        return None
+        total = sum(len(i) for i in batch.input_ids)
+        if max_total_tokens is None:
+            max_total_tokens = total
+        if max_input_tokens is None:
+            max_input_tokens = max_total_tokens - 1
+        return None, max_input_tokens, max_total_tokens
 
     def decode_token(
         self,


@@ -132,10 +132,22 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch = self.model.batch_type.from_pb(
                 request.batch, self.model.tokenizer, self.model.dtype, self.model.device
             )
-        max_supported_total_tokens = self.model.warmup(batch)
 
+        # Override default values with None for clearer semantics.
+        max_input_tokens = (
+            request.max_input_tokens if request.HasField("max_input_tokens") else None
+        )
+        max_total_tokens = (
+            request.max_total_tokens if request.HasField("max_total_tokens") else None
+        )
+        max_supported_total_tokens, max_input_tokens, max_total_tokens = (
+            self.model.warmup(batch, max_input_tokens, max_total_tokens)
+        )
+
         return generate_pb2.WarmupResponse(
-            max_supported_total_tokens=max_supported_total_tokens
+            max_supported_total_tokens=max_supported_total_tokens,
+            max_input_tokens=max_input_tokens,
+            max_total_tokens=max_total_tokens,
         )
 
     async def Prefill(self, request, context):
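
Taken together, the warmup contract after this PR is: limits the client sets explicitly are echoed back unchanged (the router asserts as much), and limits left unset are filled in from the measured budget, with the input limit defaulting to one token below the total. A rough sketch of that resolution rule follows; names are illustrative, not from the repository, and the chunking-specific cap on `model_max_length` is omitted.

```python
from typing import Optional, Tuple


def resolve_token_limits(
    requested_input: Optional[int],
    requested_total: Optional[int],
    measured_total: int,  # e.g. num_blocks * BLOCK_SIZE from the KV-cache sizing
) -> Tuple[int, int]:
    """Explicit values win; anything unset falls back to the measured budget."""
    total = requested_total if requested_total is not None else measured_total
    input_ = requested_input if requested_input is not None else total - 1
    return input_, total


# Nothing pinned: both limits come from the VRAM-derived budget.
assert resolve_token_limits(None, None, 16384) == (16383, 16384)
# max_total_tokens pinned: it is honored, input defaults just below it.
assert resolve_token_limits(None, 8192, 16384) == (8191, 8192)
```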