diff --git a/.gitignore b/.gitignore index 4270a1ae..9434d75c 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,8 @@ router/tokenizer.json backends/v2/src/client/pb backends/v3/src/client/pb +backends/client/src/v2/pb +backends/client/src/v3/pb # ROCm auto-generated files *.hip diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index d43f789e..968c1f45 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -107,20 +107,22 @@ impl Client { #[instrument(skip_all)] pub async fn warmup( &mut self, - max_input_length: u32, + max_input_tokens: Option, max_prefill_tokens: u32, - max_total_tokens: u32, + max_total_tokens: Option, max_batch_size: Option, - ) -> Result> { + ) -> Result<(Option, u32, u32)> { let mut n_tokens = 0; let mut requests = Vec::new(); // Create requests while n_tokens < max_prefill_tokens { - let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + let mut truncate = max_prefill_tokens - n_tokens; + if let Some(max_input_tokens) = max_input_tokens { + truncate = min(max_input_tokens, truncate); + } let mut input_chunks = Vec::new(); - input_chunks - .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into()); + input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into()); if n_tokens == 0 { input_chunks.push( Chunk::Image(Image { @@ -136,7 +138,7 @@ impl Client { // been updated to support chunks. let mut inputs = String::new(); - inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + inputs.push_str(&"_test ".to_string().repeat(truncate as usize)); if n_tokens == 0 { // 1 request is enough to test vision heads. // Sending images on other queries messes up easily with truncation. @@ -145,6 +147,12 @@ impl Client { )); } + let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens { + max_total_tokens - truncate + } else { + 1 + }; + requests.push(Request { id: 0, inputs, @@ -175,7 +183,7 @@ impl Client { grammar_type: GrammarType::None as i32, }), stopping_parameters: Some(StoppingCriteriaParameters { - max_new_tokens: max_total_tokens - truncate, + max_new_tokens, stop_sequences: vec![], ignore_eos_token: true, }), @@ -183,7 +191,7 @@ impl Client { top_n_tokens: 20, adapter_id: None, }); - n_tokens += max_input_length; + n_tokens += truncate; // Check max_batch_size if Some(requests.len()) == max_batch_size { @@ -195,19 +203,23 @@ impl Client { id: 0, size: requests.len() as u32, requests, - max_tokens: max_input_length, + max_tokens: max_input_tokens.unwrap_or(0), max_blocks: 0, }; let request = tonic::Request::new(WarmupRequest { batch: Some(batch), - max_input_length, + max_input_tokens, max_prefill_tokens, max_total_tokens, }) .inject_context(); let response = self.stub.warmup(request).await?.into_inner(); - Ok(response.max_supported_total_tokens) + Ok(( + response.max_supported_total_tokens, + response.max_input_tokens, + response.max_total_tokens, + )) } /// Generate one token for each request in the given batch diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index 854a5895..dc3bcdde 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -101,11 +101,11 @@ impl ShardedClient { #[instrument(skip(self))] pub async fn warmup( &mut self, - max_input_length: u32, + max_input_length: Option, max_prefill_tokens: u32, - max_total_tokens: u32, + max_total_tokens: Option, max_batch_size: Option, - ) -> Result> { + ) -> Result<(Option, u32, u32)> { let futures: Vec<_> = self .clients .iter_mut() @@ -122,8 +122,16 @@ impl ShardedClient { let results = join_all(futures) .await .into_iter() - .collect::>>>()?; - Ok(results.into_iter().flatten().min()) + .collect::, u32, u32)>>>()?; + + // Take the minimum value + // Different shards hold different parts of vocab, might yield + // different available block size. + let min = results + .iter() + .min() + .expect("Expect at least 1 warmup result"); + Ok(*min) } /// Generate one token for each request in the given batch diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index fe810f24..f4942f64 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -108,20 +108,22 @@ impl Client { #[instrument(skip_all)] pub async fn warmup( &mut self, - max_input_length: u32, + max_input_tokens: Option, max_prefill_tokens: u32, - max_total_tokens: u32, + max_total_tokens: Option, max_batch_size: Option, - ) -> Result> { + ) -> Result<(Option, u32, u32)> { let mut n_tokens = 0; let mut requests = Vec::new(); // Create requests while n_tokens < max_prefill_tokens { - let truncate = min(max_input_length, max_prefill_tokens - n_tokens); + let mut truncate = max_prefill_tokens - n_tokens; + if let Some(max_input_tokens) = max_input_tokens { + truncate = min(max_input_tokens, truncate); + } let mut input_chunks = Vec::new(); - input_chunks - .push(Chunk::Text("_test ".to_string().repeat(max_input_length as usize)).into()); + input_chunks.push(Chunk::Text("_test ".to_string().repeat(truncate as usize)).into()); if n_tokens == 0 { input_chunks.push( Chunk::Image(Image { @@ -137,7 +139,7 @@ impl Client { // been updated to support chunks. let mut inputs = String::new(); - inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize)); + inputs.push_str(&"_test ".to_string().repeat(truncate as usize)); if n_tokens == 0 { // 1 request is enough to test vision heads. // Sending images on other queries messes up easily with truncation. @@ -146,6 +148,12 @@ impl Client { )); } + let max_new_tokens = if let Some(max_total_tokens) = max_total_tokens { + max_total_tokens - truncate + } else { + 1 + }; + requests.push(Request { id: 0, inputs, @@ -175,7 +183,7 @@ impl Client { grammar_type: GrammarType::None as i32, }), stopping_parameters: Some(StoppingCriteriaParameters { - max_new_tokens: max_total_tokens - truncate, + max_new_tokens, stop_sequences: vec![], ignore_eos_token: true, }), @@ -183,7 +191,7 @@ impl Client { top_n_tokens: 20, adapter_id: None, }); - n_tokens += max_input_length; + n_tokens += truncate; // Check max_batch_size if Some(requests.len()) == max_batch_size { @@ -195,19 +203,23 @@ impl Client { id: 0, size: requests.len() as u32, requests, - max_tokens: max_input_length, + max_tokens: max_input_tokens.unwrap_or(0), max_blocks: 0, }; let request = tonic::Request::new(WarmupRequest { batch: Some(batch), - max_input_length, + max_input_tokens, max_prefill_tokens, max_total_tokens, }) .inject_context(); let response = self.stub.warmup(request).await?.into_inner(); - Ok(response.max_supported_total_tokens) + Ok(( + response.max_supported_total_tokens, + response.max_input_tokens, + response.max_total_tokens, + )) } /// Generate one token for each request in the given batch diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index e181cd28..6d4e207b 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -102,11 +102,11 @@ impl ShardedClient { #[instrument(skip(self))] pub async fn warmup( &mut self, - max_input_length: u32, + max_input_length: Option, max_prefill_tokens: u32, - max_total_tokens: u32, + max_total_tokens: Option, max_batch_size: Option, - ) -> Result> { + ) -> Result<(Option, u32, u32)> { let futures: Vec<_> = self .clients .iter_mut() @@ -119,12 +119,19 @@ impl ShardedClient { )) }) .collect(); - // Take the minimum value let results = join_all(futures) .await .into_iter() - .collect::>>>()?; - Ok(results.into_iter().flatten().min()) + .collect::, u32, u32)>>>()?; + + // Take the minimum value + // Different shards hold different parts of vocab, might yield + // different available block size. + let min = results + .iter() + .min() + .expect("Expect at least 1 warmup result"); + Ok(*min) } /// Generate one token for each request in the given batch diff --git a/backends/v3/src/lib.rs b/backends/v3/src/lib.rs index 7daf9eae..09137853 100644 --- a/backends/v3/src/lib.rs +++ b/backends/v3/src/lib.rs @@ -37,12 +37,17 @@ pub struct BackendInfo { pub attention_impl: String, #[schema(example = "1")] pub block_size: u32, + + #[schema(example = "30000")] + pub max_input_tokens: usize, + #[schema(example = "32000")] + pub max_total_tokens: usize, } #[allow(clippy::too_many_arguments)] pub async fn connect_backend( - max_input_tokens: usize, - max_total_tokens: usize, + max_input_tokens: Option, + max_total_tokens: Option, master_shard_uds_path: String, waiting_served_ratio: f32, max_batch_prefill_tokens: u32, @@ -51,14 +56,32 @@ pub async fn connect_backend( max_batch_size: Option, ) -> Result<(BackendV3, BackendInfo), V3Error> { // Helper function - let check_max_batch_total_tokens = |max_supported_batch_total_tokens: Option| { + let check_max_batch_total_tokens = |( + max_supported_batch_total_tokens, + shard_max_input_tokens, + shard_max_total_tokens, + ): (Option, u32, u32)| + -> Result<(u32, usize, usize), V3Error> { + if let Some(max_input_tokens) = max_input_tokens { + assert_eq!(max_input_tokens as u32, shard_max_input_tokens); + } + if let Some(max_total_tokens) = max_total_tokens { + assert_eq!(max_total_tokens as u32, shard_max_total_tokens); + } match max_supported_batch_total_tokens { // Older models do not support automatic max-batch-total-tokens None => { - let max_batch_total_tokens = max_batch_total_tokens - .unwrap_or(16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens))); + let max_batch_total_tokens = max_batch_total_tokens.unwrap_or( + 16000 + .max(shard_max_total_tokens) + .max(max_batch_prefill_tokens), + ); tracing::warn!("Model does not support automatic max batch total tokens"); - Ok(max_batch_total_tokens) + Ok(( + max_batch_total_tokens, + shard_max_input_tokens as usize, + shard_max_total_tokens as usize, + )) } // Flash attention models return their max supported total tokens Some(max_supported_batch_total_tokens) => { @@ -72,11 +95,15 @@ pub async fn connect_backend( "Inferred max batch total tokens: {max_supported_batch_total_tokens}" ); } - if max_total_tokens as u32 > max_supported_batch_total_tokens { - return Err(V3Error::NotEnoughMemory(max_total_tokens)); + if shard_max_total_tokens > max_supported_batch_total_tokens { + return Err(V3Error::NotEnoughMemory(shard_max_total_tokens as usize)); } - Ok(max_supported_batch_total_tokens) + Ok(( + max_supported_batch_total_tokens, + shard_max_input_tokens as usize, + shard_max_total_tokens as usize, + )) } } }; @@ -96,23 +123,25 @@ pub async fn connect_backend( // Warmup model tracing::info!("Warming up model"); - let max_batch_total_tokens = check_max_batch_total_tokens( - sharded_client - .warmup( - max_input_tokens as u32, - max_batch_prefill_tokens, - max_total_tokens as u32, - max_batch_size, - ) - .await - .map_err(V3Error::Warmup)?, - )?; + let answer = sharded_client + .warmup( + max_input_tokens.map(|p| p as u32), + max_batch_prefill_tokens, + max_total_tokens.map(|p| p as u32), + max_batch_size, + ) + .await + .map_err(V3Error::Warmup)?; + let (max_batch_total_tokens, max_input_tokens, max_total_tokens) = + check_max_batch_total_tokens(answer)?; tracing::info!("Setting max batch total tokens to {max_batch_total_tokens}"); metrics::gauge!("tgi_batch_max_total_tokens").set(max_batch_total_tokens); let backend_info = BackendInfo { waiting_served_ratio, max_batch_total_tokens, + max_input_tokens, + max_total_tokens, max_waiting_tokens, max_batch_size, model_device_type: shard_info.device_type.clone(), diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index bc4bdb93..279a8252 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -18,10 +18,10 @@ struct Args { max_stop_sequences: usize, #[clap(default_value = "5", long, env)] max_top_n_tokens: u32, - #[clap(default_value = "1024", long, env)] - max_input_tokens: usize, - #[clap(default_value = "2048", long, env)] - max_total_tokens: usize, + #[clap(long, env)] + max_input_tokens: Option, + #[clap(long, env)] + max_total_tokens: Option, #[clap(default_value = "1.2", long, env)] waiting_served_ratio: f32, #[clap(default_value = "4096", long, env)] @@ -126,12 +126,6 @@ async fn main() -> Result<(), RouterError> { text_generation_router::logging::init_logging(otlp_endpoint, otlp_service_name, json_output); // Validate args - if max_input_tokens >= max_total_tokens { - return Err(RouterError::ArgumentValidation( - "`max_input_tokens` must be < `max_total_tokens`".to_string(), - )); - } - if validation_workers == 0 { return Err(RouterError::ArgumentValidation( "`validation_workers` must be > 0".to_string(), @@ -160,6 +154,28 @@ async fn main() -> Result<(), RouterError> { // Validate remaining args now that the backend is known let support_chunking = backend_info.support_chunking; let max_batch_total_tokens = backend_info.max_batch_total_tokens; + + if max_input_tokens.is_none() { + tracing::info!( + "Maximum input tokens defaulted to {}", + backend_info.max_input_tokens + ); + } + if max_total_tokens.is_none() { + tracing::info!( + "Maximum total tokens defaulted to {}", + backend_info.max_total_tokens + ); + } + + let max_input_tokens = backend_info.max_input_tokens; + let max_total_tokens = backend_info.max_total_tokens; + if max_input_tokens >= max_total_tokens { + return Err(RouterError::ArgumentValidation( + "`max_input_tokens` must be < `max_total_tokens`".to_string(), + )); + } + if max_input_tokens as u32 > max_batch_prefill_tokens && !support_chunking { return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens}"))); } diff --git a/docs/source/reference/launcher.md b/docs/source/reference/launcher.md index 68e487d0..da0c8717 100644 --- a/docs/source/reference/launcher.md +++ b/docs/source/reference/launcher.md @@ -146,7 +146,7 @@ Options: ## MAX_INPUT_TOKENS ```shell --max-input-tokens - This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_position_embeddings - 1, 4095) + This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer prompt users can send which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequence they can handle. Default to min(max_allocatable, max_position_embeddings) - 1 [env: MAX_INPUT_TOKENS=] @@ -162,7 +162,7 @@ Options: ## MAX_TOTAL_TOKENS ```shell --max-total-tokens - This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_position_embeddings, 4096) + This is the most important value to set as it defines the "memory budget" of running clients requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. with a value of `1512` users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` max_new_tokens. The larger this value, the larger amount each request will be in your RAM and the less effective batching can be. Default to min(max_allocatable, max_position_embeddings) [env: MAX_TOTAL_TOKENS=] diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 71bbcbd8..19a79115 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -472,7 +472,7 @@ struct Args { /// for users. The larger this value, the longer prompt users can send which /// can impact the overall memory required to handle the load. /// Please note that some models have a finite range of sequence they can handle. - /// Default to min(max_position_embeddings - 1, 4095) + /// Default to min(max_allocatable, max_position_embeddings) - 1 #[clap(long, env)] max_input_tokens: Option, @@ -488,7 +488,7 @@ struct Args { /// `1511` max_new_tokens. /// The larger this value, the larger amount each request will be in your RAM /// and the less effective batching can be. - /// Default to min(max_position_embeddings, 4096) + /// Default to min(max_allocatable, max_position_embeddings) #[clap(long, env)] max_total_tokens: Option, @@ -718,9 +718,9 @@ fn shard_manager( cuda_memory_fraction: f32, rope_scaling: Option, rope_factor: Option, - max_total_tokens: usize, + max_total_tokens: Option, max_batch_size: Option, - max_input_tokens: usize, + max_input_tokens: Option, lora_adapters: Option, otlp_endpoint: Option, otlp_service_name: String, @@ -805,8 +805,10 @@ fn shard_manager( shard_args.push(otlp_service_name); // In case we use sliding window, we may ignore the sliding in flash for some backends depending on the parameter. - shard_args.push("--max-input-tokens".to_string()); - shard_args.push(max_input_tokens.to_string()); + if let Some(max_input_tokens) = max_input_tokens { + shard_args.push("--max-input-tokens".to_string()); + shard_args.push(max_input_tokens.to_string()); + } // Copy current process env let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); @@ -854,10 +856,12 @@ fn shard_manager( envs.push(("ROPE_FACTOR".into(), factor.to_string().into())); } - envs.push(( - "MAX_TOTAL_TOKENS".into(), - max_total_tokens.to_string().into(), - )); + if let Some(max_total_tokens) = max_total_tokens { + envs.push(( + "MAX_TOTAL_TOKENS".into(), + max_total_tokens.to_string().into(), + )); + } if let Some(max_batch_size) = max_batch_size { envs.push(("MAX_BATCH_SIZE".into(), max_batch_size.to_string().into())); } @@ -1315,8 +1319,8 @@ fn spawn_shards( num_shard: usize, args: &Args, cuda_graphs: Vec, - max_total_tokens: usize, - max_input_tokens: usize, + max_total_tokens: Option, + max_input_tokens: Option, quantize: Option, max_log_level: LevelFilter, shutdown: Arc, @@ -1434,8 +1438,8 @@ fn compute_type(num_shard: usize) -> Option { fn spawn_webserver( num_shard: usize, args: Args, - max_input_tokens: usize, - max_total_tokens: usize, + max_input_tokens: Option, + max_total_tokens: Option, max_batch_prefill_tokens: u32, shutdown: Arc, shutdown_receiver: &mpsc::Receiver<()>, @@ -1454,10 +1458,6 @@ fn spawn_webserver( args.max_stop_sequences.to_string(), "--max-top-n-tokens".to_string(), args.max_top_n_tokens.to_string(), - "--max-input-tokens".to_string(), - max_input_tokens.to_string(), - "--max-total-tokens".to_string(), - max_total_tokens.to_string(), "--max-batch-prefill-tokens".to_string(), max_batch_prefill_tokens.to_string(), "--waiting-served-ratio".to_string(), @@ -1475,6 +1475,18 @@ fn spawn_webserver( "--tokenizer-name".to_string(), args.model_id, ]; + if let Some(max_input_tokens) = max_input_tokens { + router_args.extend_from_slice(&[ + "--max-input-tokens".to_string(), + max_input_tokens.to_string(), + ]); + } + if let Some(max_total_tokens) = max_total_tokens { + router_args.extend_from_slice(&[ + "--max-total-tokens".to_string(), + max_total_tokens.to_string(), + ]); + } // Pass usage stats flags to router router_args.push("--usage-stats".to_string()); @@ -1704,35 +1716,19 @@ fn main() -> Result<(), LauncherError> { format!("Both `max_input_tokens` ({max_input_tokens}) and `max_input_length` ({max_input_length}) are set. Please define only `max_input_tokens` as `max_input_length is deprecated for naming consistency.", ))); } - (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => max_input_tokens, - (None, None) => { - let value = max_position_embeddings - 1; - tracing::info!("Default `max_input_tokens` to {value}"); - value - } - } - }; - let max_total_tokens = { - match args.max_total_tokens { - Some(max_total_tokens) => max_total_tokens, - None => { - let value = max_position_embeddings; - tracing::info!("Default `max_total_tokens` to {value}"); - value + (Some(max_input_tokens), None) | (None, Some(max_input_tokens)) => { + Some(max_input_tokens) } + (None, None) => None, } }; + let max_total_tokens = args.max_total_tokens; let max_batch_prefill_tokens = { match args.max_batch_prefill_tokens { Some(max_batch_prefill_tokens) => max_batch_prefill_tokens, None => { - let value: u32 = if let Some(max_batch_size) = args.max_batch_size { - max_batch_size * max_input_tokens - } else { - // Adding some edge in order to account for potential block_size alignement - // issue. - max_input_tokens + 50 - } as u32; + // TODO figure out hardware optimal value + let value = 4096.min(max_position_embeddings as u32); tracing::info!("Default `max_batch_prefill_tokens` to {value}"); value } @@ -1740,10 +1736,12 @@ fn main() -> Result<(), LauncherError> { }; // Validate args - if max_input_tokens >= max_total_tokens { - return Err(LauncherError::ArgumentValidation( - "`max_input_tokens must be < `max_total_tokens`".to_string(), - )); + if let (Some(max_input_tokens), Some(max_total_tokens)) = (max_input_tokens, max_total_tokens) { + if max_input_tokens >= max_total_tokens { + return Err(LauncherError::ArgumentValidation( + format!("`max_input_tokens`({max_input_tokens}) must be < `max_total_tokens`({max_total_tokens})"), + )); + } } if matches!(args.quantize, Some(Quantization::Bitsandbytes)) { @@ -1798,11 +1796,13 @@ fn main() -> Result<(), LauncherError> { } if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { - if max_total_tokens as u32 > *max_batch_total_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - max_total_tokens, max_batch_total_tokens - ))); + if let Some(max_total_tokens) = max_total_tokens { + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", + max_total_tokens, max_batch_total_tokens + ))); + } } } diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index c91e7cc4..02980b6f 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -272,12 +272,18 @@ message DecodeResponse { message WarmupRequest { /// Batch to warmup on Batch batch = 1; - uint32 max_input_length = 2; + optional uint32 max_input_tokens = 2; uint32 max_prefill_tokens = 3; - uint32 max_total_tokens = 4; + optional uint32 max_total_tokens = 4; } message WarmupResponse { /// Maximum number of tokens supported by the model optional uint32 max_supported_total_tokens = 1; + /// Maximum input tokens by clients should be equal to request value if it's set + /// Otherwise warmup automatically allocates a value here + uint32 max_input_tokens = 2; + /// Maximum total tokens by clients should be equal to request value if it's set + /// Otherwise warmup automatically allocates a value here + uint32 max_total_tokens = 3; } diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 87e904f4..8ab1a811 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -86,6 +86,10 @@ tracer = trace.get_tracer(__name__) SLIDING_WINDOW: Optional[int] = None +def small_power_of_2(n: int): + return 1 << ((n - 1).bit_length() - 1) + + def set_sliding_window(sliding_window: int): global SLIDING_WINDOW SLIDING_WINDOW = sliding_window @@ -1495,11 +1499,22 @@ class FlashCausalLM(Model): self.cuda_graphs[bs]["speculative_logits"] = speculative_logits torch.cuda.synchronize() - def warmup(self, batch: FlashCausalLMBatch): + def warmup( + self, + batch: FlashCausalLMBatch, + max_input_tokens: Optional[int], + max_total_tokens: Optional[int], + ): # The warmup batch is the biggest batch we could ever receive self.kv_cache = [] empty_cache() + # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) + # Calculate the number of blocks that can be allocated with the free memory + dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size() + cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size + total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size + try: self.init_kv_cache( batch.num_blocks, @@ -1511,10 +1526,11 @@ class FlashCausalLM(Model): ) max_bt = batch.max_blocks max_s = max_bt * BLOCK_SIZE + batch_num_blocks = batch.num_blocks if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False): torch.cuda.tunable.tuning_enable(False) - _, batch, _ = self.generate_token(batch) + _, _batch, _ = self.generate_token(batch) except torch.cuda.OutOfMemoryError as e: raise RuntimeError( f"Not enough memory to handle {batch.to_pb().current_tokens} prefill tokens. " @@ -1523,14 +1539,7 @@ class FlashCausalLM(Model): synchronize(self.device) - # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) - # Calculate the number of blocks that can be allocated with the free memory - dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size() - cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size - total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size - free_memory = get_free_memory(self.device, MEMORY_FRACTION) - batch_num_blocks = batch.num_blocks if batch is not None else 0 num_blocks = ( # Leave 5% for some wiggle room @@ -1540,8 +1549,27 @@ class FlashCausalLM(Model): ) log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}") + if max_total_tokens is None: + if get_support_chunking(): + model_max_length = self.tokenizer.model_max_length + max_input_tokens = ( + min((num_blocks * BLOCK_SIZE - 1), model_max_length) + if max_input_tokens is None + else max_input_tokens + ) + max_total_tokens = num_blocks * BLOCK_SIZE - del batch + else: + max_total_tokens = sum(batch.cache_lengths) + max_input_tokens = ( + max_total_tokens - 1 + if max_input_tokens is None + else max_input_tokens + ) + + del _batch, batch + self.kv_cache = [] + empty_cache() self.init_kv_cache( num_blocks, @@ -1623,7 +1651,9 @@ class FlashCausalLM(Model): logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})." ) - return int(num_blocks * BLOCK_SIZE) + assert max_input_tokens is not None + assert max_total_tokens is not None + return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens def tunableop_warmup(self, seqlen: int): input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device) diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index dfc61fb8..3bba1cf2 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -1,7 +1,7 @@ import torch import torch.distributed from transformers import AutoTokenizer, PreTrainedTokenizerBase -from typing import Optional +from typing import Optional, Union from text_generation_server.models.custom_modeling.mamba_modeling import ( MambaConfig, ) @@ -475,7 +475,9 @@ class Mamba(Model): def batch_type(self) -> Type[MambaBatch]: return MambaBatch - def warmup(self, batch) -> Optional[int]: + def warmup( + self, batch, max_input_tokens: Optional[int], max_total_tokens: Optional[int] + ) -> Union[Optional[int], Optional[int], Optional[int]]: # TODO: implement warmup for Mamba if needed if CUDA_GRAPHS: if self.speculate is None or self.speculate == 0: @@ -489,7 +491,12 @@ class Mamba(Model): else: logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") - return None + if max_total_tokens is None: + max_total_tokens = min(self.tokenizer.model_max_length, 4096) + + if max_input_tokens is None: + max_input_tokens = max_total_tokens - 1 + return None, max_input_tokens, max_total_tokens def cuda_graph_warmup(self, batch_size: int): input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device) diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index b3630013..c75592c1 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -128,9 +128,17 @@ class Model(ABC): ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]: raise NotImplementedError - def warmup(self, batch: B) -> Optional[int]: + def warmup( + self, batch: B, max_input_tokens: Optional[int], max_total_tokens: Optional[int] + ) -> Tuple[Optional[int], int, int]: self.generate_token(batch) - return None + total = sum(len(i) for i in batch.input_ids) + if max_total_tokens is None: + max_total_tokens = total + + if max_input_tokens is None: + max_input_tokens = max_total_tokens - 1 + return None, max_input_tokens, max_total_tokens def decode_token( self, diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index aef00fb5..45b48df8 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -132,10 +132,22 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) - max_supported_total_tokens = self.model.warmup(batch) + + # Override default values with None for clearer semantics. + max_input_tokens = ( + request.max_input_tokens if request.HasField("max_input_tokens") else None + ) + max_total_tokens = ( + request.max_total_tokens if request.HasField("max_total_tokens") else None + ) + max_supported_total_tokens, max_input_tokens, max_total_tokens = ( + self.model.warmup(batch, max_input_tokens, max_total_tokens) + ) return generate_pb2.WarmupResponse( - max_supported_total_tokens=max_supported_total_tokens + max_supported_total_tokens=max_supported_total_tokens, + max_input_tokens=max_input_tokens, + max_total_tokens=max_total_tokens, ) async def Prefill(self, request, context):