From fe80f5360ca885af4c63a3f1db9f6786df7ebc76 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Wed, 19 Jul 2023 09:31:25 +0200 Subject: [PATCH] feat(server): auto max_batch_total_tokens for flash att models (#630) --- launcher/src/main.rs | 54 +++++++++--------- proto/generate.proto | 7 ++- router/client/src/client.rs | 13 ++--- router/client/src/sharded_client.rs | 7 +-- router/src/infer.rs | 2 +- router/src/main.rs | 57 +++++++++++++------ router/src/queue.rs | 51 +++++++++++------ .../models/flash_causal_lm.py | 45 ++++++++++++--- server/text_generation_server/models/model.py | 3 +- server/text_generation_server/server.py | 14 ++--- 10 files changed, 159 insertions(+), 94 deletions(-) diff --git a/launcher/src/main.rs b/launcher/src/main.rs index e26244e..c03af4b 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -184,8 +184,8 @@ struct Args { /// depends on other parameters like if you're using quantization, flash attention /// or the model implementation, text-generation-inference cannot infer this number /// automatically. - #[clap(default_value = "16000", long, env)] - max_batch_total_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, /// This setting defines how many tokens can be passed before forcing the waiting /// queries to be put on the batch (if the size of the batch allows for it). @@ -369,12 +369,6 @@ fn shard_manager( // Copy current process env let mut envs: Vec<(OsString, OsString)> = env::vars_os().collect(); - // Use cuda allocator. It leads to less memory fragmentation - envs.push(( - "PYTORCH_CUDA_ALLOC_CONF".into(), - "backend:cudaMallocAsync".into(), - )); - // Torch Distributed Env vars envs.push(("RANK".into(), rank.to_string().into())); envs.push(("WORLD_SIZE".into(), world_size.to_string().into())); @@ -428,7 +422,7 @@ fn shard_manager( } // Start process - tracing::info!("Starting shard {rank}"); + tracing::info!("Starting shard"); let mut p = match Command::new("text-generation-server") .args(shard_args) .envs(envs) @@ -493,17 +487,17 @@ fn shard_manager( if shutdown.load(Ordering::SeqCst) { p.kill().unwrap(); let _ = p.wait(); - tracing::info!("Shard {rank} terminated"); + tracing::info!("Shard terminated"); return; } // Shard is ready if uds.exists() && !ready { - tracing::info!("Shard {rank} ready in {:?}", start_time.elapsed()); + tracing::info!("Shard ready in {:?}", start_time.elapsed()); status_sender.send(ShardStatus::Ready).unwrap(); ready = true; } else if !ready && wait_time.elapsed() > Duration::from_secs(10) { - tracing::info!("Waiting for shard {rank} to be ready..."); + tracing::info!("Waiting for shard to be ready..."); wait_time = Instant::now(); } sleep(Duration::from_millis(100)); @@ -860,8 +854,6 @@ fn spawn_webserver( args.max_total_tokens.to_string(), "--max-batch-prefill-tokens".to_string(), args.max_batch_prefill_tokens.to_string(), - "--max-batch-total-tokens".to_string(), - args.max_batch_total_tokens.to_string(), "--waiting-served-ratio".to_string(), args.waiting_served_ratio.to_string(), "--max-waiting-tokens".to_string(), @@ -878,6 +870,12 @@ fn spawn_webserver( args.model_id, ]; + // Model optional max batch total tokens + if let Some(max_batch_total_tokens) = args.max_batch_total_tokens { + router_args.push("--max-batch-total-tokens".to_string()); + router_args.push(max_batch_total_tokens.to_string()); + } + // Model optional revision if let Some(ref revision) = args.revision { router_args.push("--revision".to_string()); @@ -1036,18 +1034,7 @@ fn main() -> Result<(), LauncherError> { args.max_batch_prefill_tokens, args.max_input_length ))); } - if args.max_batch_prefill_tokens > args.max_batch_total_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - args.max_batch_prefill_tokens, args.max_batch_total_tokens - ))); - } - if args.max_total_tokens as u32 > args.max_batch_total_tokens { - return Err(LauncherError::ArgumentValidation(format!( - "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", - args.max_total_tokens, args.max_batch_total_tokens - ))); - } + if args.validation_workers == 0 { return Err(LauncherError::ArgumentValidation( "`validation_workers` must be > 0".to_string(), @@ -1065,6 +1052,21 @@ fn main() -> Result<(), LauncherError> { tracing::info!("Sharding model on {num_shard} processes"); } + if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens { + if args.max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", + args.max_batch_prefill_tokens, max_batch_total_tokens + ))); + } + if args.max_total_tokens as u32 > *max_batch_total_tokens { + return Err(LauncherError::ArgumentValidation(format!( + "`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}", + args.max_total_tokens, max_batch_total_tokens + ))); + } + } + // Signal handler let running = Arc::new(AtomicBool::new(true)); let r = running.clone(); diff --git a/proto/generate.proto b/proto/generate.proto index 5e06194..57d79bc 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -198,9 +198,10 @@ message DecodeResponse { message WarmupRequest { /// Batch to warmup on Batch batch = 1; - /// Maximum number of tokens that the client will send - uint32 max_total_tokens = 2; } /// Empty response -message WarmupResponse {} +message WarmupResponse { + /// Maximum number of tokens supported by the model + optional uint32 max_supported_total_tokens = 1; +} diff --git a/router/client/src/client.rs b/router/client/src/client.rs index b9607a5..7753f30 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -103,8 +103,7 @@ impl Client { &mut self, max_input_length: u32, max_prefill_tokens: u32, - max_total_tokens: u32, - ) -> Result<()> { + ) -> Result> { let mut n_tokens = 0; let mut requests = Vec::new(); @@ -143,13 +142,9 @@ impl Client { max_tokens: 0, }; - let request = tonic::Request::new(WarmupRequest { - batch: Some(batch), - max_total_tokens, - }) - .inject_context(); - self.stub.warmup(request).await?.into_inner(); - Ok(()) + let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context(); + let response = self.stub.warmup(request).await?.into_inner(); + Ok(response.max_supported_total_tokens) } /// Generate one token for each request in the given batch diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 9dd173a..6d146bc 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -95,14 +95,11 @@ impl ShardedClient { &mut self, max_input_length: u32, max_prefill_tokens: u32, - max_total_tokens: u32, - ) -> Result<()> { + ) -> Result> { let futures: Vec<_> = self .clients .iter_mut() - .map(|client| { - Box::pin(client.warmup(max_input_length, max_prefill_tokens, max_total_tokens)) - }) + .map(|client| Box::pin(client.warmup(max_input_length, max_prefill_tokens))) .collect(); // all shards return the same message join_all(futures).await.pop().unwrap() diff --git a/router/src/infer.rs b/router/src/infer.rs index d0d22d3..188ddc6 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -53,7 +53,7 @@ impl Infer { generation_health: Arc, ) -> Self { // Infer shared state - let queue = Queue::new(requires_padding); + let queue = Queue::new(requires_padding, 16); let shared = Arc::new(Shared { batching_task: Notify::new(), }); diff --git a/router/src/main.rs b/router/src/main.rs index 178c249..5aef03d 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -37,8 +37,8 @@ struct Args { waiting_served_ratio: f32, #[clap(default_value = "4096", long, env)] max_batch_prefill_tokens: u32, - #[clap(default_value = "16000", long, env)] - max_batch_total_tokens: u32, + #[clap(long, env)] + max_batch_total_tokens: Option, #[clap(default_value = "20", long, env)] max_waiting_tokens: usize, #[clap(default_value = "0.0.0.0", long, env)] @@ -110,18 +110,22 @@ fn main() -> Result<(), RouterError> { if max_input_length as u32 > max_batch_prefill_tokens { return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be >= `max_input_length`. Given: {max_batch_prefill_tokens} and {max_input_length}"))); } - if max_batch_prefill_tokens > max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); - } - if max_total_tokens as u32 > max_batch_total_tokens { - return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); - } + if validation_workers == 0 { return Err(RouterError::ArgumentValidation( "`validation_workers` must be > 0".to_string(), )); } + if let Some(ref max_batch_total_tokens) = max_batch_total_tokens { + if max_batch_prefill_tokens > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens}"))); + } + if max_total_tokens as u32 > *max_batch_total_tokens { + return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens}"))); + } + } + // CORS allowed origins // map to go inside the option and then map to parse from String to HeaderValue // Finally, convert to AllowOrigin @@ -210,14 +214,35 @@ fn main() -> Result<(), RouterError> { // Warmup model tracing::info!("Warming up model"); - sharded_client - .warmup( - max_input_length as u32, - max_batch_prefill_tokens, - max_batch_total_tokens, - ) + let max_supported_batch_total_tokens = match sharded_client + .warmup(max_input_length as u32, max_batch_prefill_tokens) .await - .map_err(RouterError::Warmup)?; + .map_err(RouterError::Warmup)? + { + // Older models do not support automatic max-batch-total-tokens + None => { + let max_batch_total_tokens = max_batch_total_tokens.unwrap_or( + 16000.max((max_total_tokens as u32).max(max_batch_prefill_tokens)), + ); + tracing::warn!("Model does not support automatic max batch total tokens"); + max_batch_total_tokens + } + // Flash attention models return their max supported total tokens + Some(max_supported_batch_total_tokens) => { + // Warn if user added his own max-batch-total-tokens as we will ignore it + if max_batch_total_tokens.is_some() { + tracing::warn!( + "`--max-batch-total-tokens` is deprecated for Flash \ + Attention models." + ); + tracing::warn!( + "Inferred max batch total tokens: {max_supported_batch_total_tokens}" + ); + } + max_supported_batch_total_tokens + } + }; + tracing::info!("Setting max batch total tokens to {max_supported_batch_total_tokens}"); tracing::info!("Connected"); let addr = match hostname.parse() { @@ -240,7 +265,7 @@ fn main() -> Result<(), RouterError> { max_total_tokens, waiting_served_ratio, max_batch_prefill_tokens, - max_batch_total_tokens, + max_supported_batch_total_tokens, max_waiting_tokens, sharded_client, tokenizer, diff --git a/router/src/queue.rs b/router/src/queue.rs index 48e483a..2d8d6d1 100644 --- a/router/src/queue.rs +++ b/router/src/queue.rs @@ -33,12 +33,12 @@ pub(crate) struct Queue { } impl Queue { - pub(crate) fn new(requires_padding: bool) -> Self { + pub(crate) fn new(requires_padding: bool, block_size: u32) -> Self { // Create channel let (queue_sender, queue_receiver) = flume::unbounded(); // Launch background queue task - tokio::spawn(queue_task(requires_padding, queue_receiver)); + tokio::spawn(queue_task(requires_padding, block_size, queue_receiver)); Self { queue_sender } } @@ -81,8 +81,12 @@ impl Queue { } // Background task responsible of the queue state -async fn queue_task(requires_padding: bool, receiver: flume::Receiver) { - let mut state = State::new(requires_padding); +async fn queue_task( + requires_padding: bool, + block_size: u32, + receiver: flume::Receiver, +) { + let mut state = State::new(requires_padding, block_size); while let Ok(cmd) = receiver.recv_async().await { match cmd { @@ -119,15 +123,19 @@ struct State { /// Whether the model is using padding requires_padding: bool, + + /// Paged Attention block size + block_size: u32, } impl State { - fn new(requires_padding: bool) -> Self { + fn new(requires_padding: bool, block_size: u32) -> Self { Self { entries: VecDeque::with_capacity(128), next_id: 0, next_batch_id: 0, requires_padding, + block_size, } } @@ -187,10 +195,21 @@ impl State { max_input_length = max_input_length.max(entry.request.input_length); prefill_tokens = (batch_requests.len() + 1) as u32 * max_input_length } else { - prefill_tokens += entry.request.input_length; + // pad to block size + prefill_tokens += ((entry.request.input_length + self.block_size - 1) + / self.block_size) + * self.block_size; } - decode_tokens += entry.request.stopping_parameters.max_new_tokens; + if self.requires_padding { + decode_tokens += entry.request.stopping_parameters.max_new_tokens; + } else { + // pad to block size + decode_tokens += + ((entry.request.stopping_parameters.max_new_tokens + self.block_size - 1) + / self.block_size) + * self.block_size; + } if prefill_tokens > prefill_token_budget || (prefill_tokens + decode_tokens) > token_budget @@ -321,7 +340,7 @@ mod tests { #[test] fn test_append() { - let mut state = State::new(false); + let mut state = State::new(false, 1); let (entry, _guard) = default_entry(); assert_eq!(state.next_id, 0); @@ -337,7 +356,7 @@ mod tests { #[test] fn test_next_batch_empty() { - let mut state = State::new(false); + let mut state = State::new(false, 1); assert!(state.next_batch(None, 1, 1).is_none()); assert!(state.next_batch(Some(1), 1, 1).is_none()); @@ -345,7 +364,7 @@ mod tests { #[test] fn test_next_batch_min_size() { - let mut state = State::new(false); + let mut state = State::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -377,7 +396,7 @@ mod tests { #[test] fn test_next_batch_token_budget() { - let mut state = State::new(false); + let mut state = State::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); state.append(entry1); @@ -410,14 +429,14 @@ mod tests { #[tokio::test] async fn test_queue_append() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); let (entry, _guard) = default_entry(); queue.append(entry); } #[tokio::test] async fn test_queue_next_batch_empty() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); assert!(queue.next_batch(None, 1, 1).await.is_none()); assert!(queue.next_batch(Some(1), 1, 1).await.is_none()); @@ -425,7 +444,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_min_size() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -458,7 +477,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_token_budget() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); let (entry1, _guard1) = default_entry(); let (entry2, _guard2) = default_entry(); queue.append(entry1); @@ -483,7 +502,7 @@ mod tests { #[tokio::test] async fn test_queue_next_batch_dropped_receiver() { - let queue = Queue::new(false); + let queue = Queue::new(false, 1); let (entry, _) = default_entry(); queue.append(entry); diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index d034d47..517fba6 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -710,14 +710,14 @@ class FlashCausalLM(Model): def batch_type(self) -> Type[FlashCausalLMBatch]: return FlashCausalLMBatch - def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int): + def warmup(self, batch: FlashCausalLMBatch): global CACHE_MANAGER torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(self.device) try: CACHE_MANAGER = CacheManager( - # Adds some wiggle room - math.ceil(max_total_tokens / BLOCK_SIZE) + 10, + batch.blocks, self.num_layers, self.num_kv_heads, self.head_size, @@ -727,11 +727,43 @@ class FlashCausalLM(Model): _, batch = self.generate_token(batch) except Exception as e: raise RuntimeError( - f"Not enough memory to handle {max_total_tokens} total tokens with {len(batch.input_ids)} " - f"prefill tokens. " - f"You need to decrease `--max-batch-total-tokens` or `--max-batch-prefill-tokens`" + f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. " + f"You need to decrease `--max-batch-prefill-tokens`" ) from e + + # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + torch.cuda.synchronize(self.device) + peak_memory = torch.cuda.max_memory_reserved(self.device) + + dtype_size = torch.tensor([], dtype=self.dtype).element_size() + cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size + total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size + + total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory + + # 0.98 to add some wiggle room + num_blocks = ( + int((total_gpu_memory * 0.98 - peak_memory) // total_cache_size) + # Add batch.blocks as we allocated it above, so it is included in the peak memory. + + batch.blocks + ) + + del CACHE_MANAGER del batch + torch.cuda.empty_cache() + + CACHE_MANAGER = CacheManager( + num_blocks, + self.num_layers, + self.num_kv_heads, + self.head_size, + self.dtype, + self.device, + ) + + return int(num_blocks * BLOCK_SIZE) def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: return self.tokenizer.decode( @@ -991,7 +1023,6 @@ class FlashCausalLM(Model): if stopped: del batch - torch.cuda.empty_cache() # No need to return a batch if we know that all requests stopped return generations, None diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index f8460fc..3827197 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -58,8 +58,9 @@ class Model(ABC): def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: raise NotImplementedError - def warmup(self, batch: B, max_total_tokens: int): + def warmup(self, batch: B) -> Optional[int]: self.generate_token(batch) + return None def decode_token( self, diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 7bc62ce..e0efbcf 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -51,21 +51,17 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): filtered_batch = batch.filter(request.request_ids) self.cache.set(filtered_batch) - if torch.cuda.is_available(): - torch.cuda.empty_cache() - return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): batch = self.model.batch_type.from_pb( request.batch, self.model.tokenizer, self.model.dtype, self.model.device ) - self.model.warmup(batch, request.max_total_tokens) + max_supported_total_tokens = self.model.warmup(batch) - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - return generate_pb2.WarmupResponse() + return generate_pb2.WarmupResponse( + max_supported_total_tokens=max_supported_total_tokens + ) async def Prefill(self, request, context): batch = self.model.batch_type.from_pb( @@ -96,8 +92,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): if len(batches) > 1: batch = self.model.batch_type.concatenate(batches) - if torch.cuda.is_available(): - torch.cuda.empty_cache() else: batch = batches[0]