From 7ed1044585b8fce59facf5404b02919d29e935de Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 18 Jun 2024 12:18:05 +0200
Subject: [PATCH] added padded blocks and logs everywhere

---
 benchmark/src/app.rs                   | 12 +++---
 benchmark/src/table.rs                 |  6 +--
 benchmark/src/utils.rs                 |  2 +-
 proto/v3/generate.proto                |  2 +
 router/src/infer/v3/block_allocator.rs |  2 +-
 router/src/infer/v3/scheduler.rs       | 39 +++++++++++++++++--
 rust-toolchain.toml                    |  6 +--
 .../models/flash_causal_lm.py          | 29 ++++++++++----
 8 files changed, 73 insertions(+), 25 deletions(-)

diff --git a/benchmark/src/app.rs b/benchmark/src/app.rs
index 48ac976a..a0a9313a 100644
--- a/benchmark/src/app.rs
+++ b/benchmark/src/app.rs
@@ -497,7 +497,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
             "Lowest: {:.2} {unit}",
             data.iter()
                 .min_by(|a, b| a.total_cmp(b))
-                .unwrap_or(&std::f64::NAN)
+                .unwrap_or(&f64::NAN)
         ),
         Style::default().fg(Color::Reset),
     )]),
@@ -506,7 +506,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
             "Highest: {:.2} {unit}",
             data.iter()
                 .max_by(|a, b| a.total_cmp(b))
-                .unwrap_or(&std::f64::NAN)
+                .unwrap_or(&f64::NAN)
         ),
         Style::default().fg(Color::Reset),
     )]),
@@ -555,17 +555,17 @@ fn latency_throughput_chart<'a>(
     let min_latency: f64 = *latency_iter
         .clone()
         .min_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let max_latency: f64 = *latency_iter
         .max_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let min_throughput: f64 = *throughput_iter
         .clone()
         .min_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let max_throughput: f64 = *throughput_iter
         .max_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
 
     // Char min max values
     let min_x = if zoom {

diff --git a/benchmark/src/table.rs b/benchmark/src/table.rs
index e18d7310..1585a25f 100644
--- a/benchmark/src/table.rs
+++ b/benchmark/src/table.rs
@@ -156,17 +156,17 @@ fn avg_min_max(data: &[f64]) -> (f64, f64, f64) {
     let min = data
         .iter()
        .min_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let max = data
         .iter()
         .max_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
 
     (average, *min, *max)
 }
 
 fn px(data: &[f64], p: u32) -> f64 {
     let i = (f64::from(p) / 100.0 * data.len() as f64) as usize;
-    *data.get(i).unwrap_or(&std::f64::NAN)
+    *data.get(i).unwrap_or(&f64::NAN)
 }
 
 fn format_value(value: f64, unit: &'static str) -> String {

diff --git a/benchmark/src/utils.rs b/benchmark/src/utils.rs
index d096d655..20469991 100644
--- a/benchmark/src/utils.rs
+++ b/benchmark/src/utils.rs
@@ -37,7 +37,7 @@ pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap<String, f64> {
         .map(|&p| {
             let i = (f64::from(p) / 100.0 * values.len() as f64) as usize;
-            (format!("p{p}"), *values.get(i).unwrap_or(&std::f64::NAN))
+            (format!("p{p}"), *values.get(i).unwrap_or(&f64::NAN))
         })
         .collect()
 }

diff --git a/router/src/infer/v3/scheduler.rs b/router/src/infer/v3/scheduler.rs
--- a/router/src/infer/v3/scheduler.rs
+++ b/router/src/infer/v3/scheduler.rs
     // Filter next batch: remove requests that were stopped and update blocks/slots
     let (filtered_batch, terminated_generations) =
         filter_batch(client, batch, entries, &terminated_entries).await;
+    tracing::info!("filter_batch: {:?}", start_filtering_time.elapsed());
 
     send_terminated_generations(
         terminated_generations,
         terminated_entries,
         filtered_stream_responses,
     );
+    tracing::info!("send_terminated: {:?}", start_filtering_time.elapsed());
 
     filtered_batch
 }
@@ -379,23 +386,49 @@ async fn filter_batch(
             client.clear_cache(Some(id)).await.unwrap();
             Default::default()
         } else {
-            // Collect new blocks/slots
+            let max_blocks = entries
+                .iter()
+                .map(|(_, entry)| {
+                    entry
+                        .block_allocation
+                        .as_ref()
+                        .map(|alloc| alloc.blocks().len())
+                })
+                .max()
+                .flatten();
+
+            let start_time = Instant::now();
+
+            // Collect new blocks
             let updated_requests = entries
                 .iter()
                 .map(|(request_id, entry)| {
-                    let blocks = entry
+                    let (blocks, padded_blocks) = entry
                         .block_allocation
                         .as_ref()
-                        .map(|alloc| alloc.blocks().to_vec())
+                        .map(|alloc| {
+                            let max_blocks = match max_blocks {
+                                Some(max_blocks) => max_blocks,
+                                _ => unreachable!(),
+                            };
+
+                            let blocks = alloc.blocks().to_vec();
+                            let mut padded_blocks = blocks.clone();
+                            padded_blocks.resize(max_blocks, 0);
+                            (blocks, padded_blocks)
+                        })
                         .unwrap_or_default();
 
                     KeptRequest {
                         id: *request_id,
                         blocks,
+                        padded_blocks,
                     }
                 })
                 .collect();
 
+            tracing::info!("Collect blocks: {:?}", start_time.elapsed());
+
             // Filter Python shard cache
             // We unwrap here as we need to panic since we cannot recover if this method fails
             client
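The padding contract introduced above is easier to see outside the diff. The sketch below, in Python to match the server side, is illustrative and not part of the patch; `pad_block_table` is a hypothetical helper. It mirrors what `padded_blocks.resize(max_blocks, 0)` guarantees: every kept request ships a block table of the batch-wide maximum length, padded with null block id 0, so the flattened list can later be reshaped into a rectangular tensor.

def pad_block_table(blocks: list[int], max_blocks: int) -> list[int]:
    # Same effect as the Rust `padded_blocks.resize(max_blocks, 0)`:
    # append null block ids until the table reaches max_blocks entries.
    return blocks + [0] * (max_blocks - len(blocks))

block_tables = [[7, 8], [3], [5, 6, 9]]          # per-request block tables
max_blocks = max(len(b) for b in block_tables)   # 3
padded = [pad_block_table(b, max_blocks) for b in block_tables]
assert padded == [[7, 8, 0], [3, 0, 0], [5, 6, 9]]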
diff --git a/rust-toolchain.toml b/rust-toolchain.toml
index 507ee859..83f9a5b0 100644
--- a/rust-toolchain.toml
+++ b/rust-toolchain.toml
@@ -1,5 +1,5 @@
 [toolchain]
-# Released on: 02 May, 2024
-# https://releases.rs/docs/1.78.0/
-channel = "1.78.0"
+# Released on: 13 June, 2024
+# https://releases.rs/docs/1.79.0/
+channel = "1.79.0"
 components = ["rustfmt", "clippy"]

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 47963aba..1182f3d4 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -403,6 +403,8 @@ class FlashCausalLMBatch(Batch):
         kept_requests: List[generate_pb2.KeptRequest],
         terminated_request_ids: List[int],
     ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
+        start = time.time_ns()
+
         terminated_generations = []
         for request_id in terminated_request_ids:
             idx = self.requests_idx_mapping[request_id]
@@ -429,6 +431,11 @@ class FlashCausalLMBatch(Batch):
                 ),
             )
         )
+
+        from loguru import logger
+
+        logger.info(f"terminated generations {(time.time_ns() - start)/1e6}")
+
         if not kept_requests:
             return None, terminated_generations
@@ -445,7 +452,7 @@ class FlashCausalLMBatch(Batch):
         requests = []
         flat_blocks = []
-        block_tables = []
+        padded_blocks = []
         all_input_ids = []
         input_lengths = []
@@ -483,8 +490,8 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens.append(self.top_n_tokens[idx])
 
             request_block_table = request.blocks
-            block_tables.append(request_block_table)
             flat_blocks.extend(request_block_table)
+            padded_blocks.extend(request.padded_blocks)
 
             # Index
             slot_indices.append((num_blocks * BLOCK_SIZE) + request_input_length - 1)
@@ -492,6 +499,8 @@ class FlashCausalLMBatch(Batch):
             num_blocks += len(request_block_table)
             max_blocks = max(max_blocks, len(request_block_table))
 
+        logger.info(f"for loop requests: {(time.time_ns() - start)/1e6}")
+
         # Index into tensors
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
@@ -503,12 +512,14 @@ class FlashCausalLMBatch(Batch):
             self.speculative_ids[indices] if self.speculative_ids is not None else None
         )
 
-        # Create block_tables_tensor on CPU
-        block_tables_tensor = torch.zeros(
-            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
-        )
-        for i, request_blocks in enumerate(block_tables):
-            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
+        logger.info(f"slice objects: {(time.time_ns() - start)/1e6}")
+
+        # Create block_tables_tensor on GPU
+        block_tables_tensor = torch.tensor(
+            padded_blocks, dtype=torch.int32, device=device
+        ).view(len(requests), -1)
+
logger.info(f"allocate block table: {(time.time_ns() - start)/1e6}") # Allocate on GPU slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device) @@ -522,6 +533,8 @@ class FlashCausalLMBatch(Batch): + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64) ).flatten() + logger.info(f"done allocation: {(time.time_ns() - start)/1e6}") + filtered_batch = type(self)( batch_id=self.batch_id, requests=requests,