added padded blocks and logs everywhere

This commit is contained in:
OlivierDehaene 2024-06-18 12:18:05 +02:00
parent abe521204e
commit 7ed1044585
8 changed files with 73 additions and 25 deletions

View File

@ -497,7 +497,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
"Lowest: {:.2} {unit}",
data.iter()
.min_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN)
.unwrap_or(&f64::NAN)
),
Style::default().fg(Color::Reset),
)]),
@ -506,7 +506,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
"Highest: {:.2} {unit}",
data.iter()
.max_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN)
.unwrap_or(&f64::NAN)
),
Style::default().fg(Color::Reset),
)]),
@ -555,17 +555,17 @@ fn latency_throughput_chart<'a>(
let min_latency: f64 = *latency_iter
.clone()
.min_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN);
.unwrap_or(&f64::NAN);
let max_latency: f64 = *latency_iter
.max_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN);
.unwrap_or(&f64::NAN);
let min_throughput: f64 = *throughput_iter
.clone()
.min_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN);
.unwrap_or(&f64::NAN);
let max_throughput: f64 = *throughput_iter
.max_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN);
.unwrap_or(&f64::NAN);
// Char min max values
let min_x = if zoom {

View File

@ -156,17 +156,17 @@ fn avg_min_max(data: &[f64]) -> (f64, f64, f64) {
let min = data
.iter()
.min_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN);
.unwrap_or(&f64::NAN);
let max = data
.iter()
.max_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN);
.unwrap_or(&f64::NAN);
(average, *min, *max)
}
/// Returns the p-th percentile of `data` by direct index lookup.
///
/// `data` is assumed to be sorted ascending by the caller; the index is
/// computed as `p/100 * len` (truncated), so `p = 100` — or an empty
/// slice — lands out of bounds and yields `f64::NAN` as a sentinel.
fn px(data: &[f64], p: u32) -> f64 {
    let i = (f64::from(p) / 100.0 * data.len() as f64) as usize;
    // Checked lookup: out-of-range percentile requests become NaN
    // instead of panicking.
    *data.get(i).unwrap_or(&f64::NAN)
}
fn format_value(value: f64, unit: &'static str) -> String {

View File

@ -37,7 +37,7 @@ pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap<String, f
.iter()
.map(|&p| {
let i = (f64::from(p) / 100.0 * values.len() as f64) as usize;
(format!("p{p}"), *values.get(i).unwrap_or(&std::f64::NAN))
(format!("p{p}"), *values.get(i).unwrap_or(&f64::NAN))
})
.collect()
}

View File

@ -206,6 +206,8 @@ message KeptRequest {
uint64 id = 1;
/// Paged attention blocks
repeated uint32 blocks = 2;
/// Paged attention blocks padded to max blocks for this batch
repeated uint32 padded_blocks = 3;
}
/// kept_requests + terminated_request_ids might not cover all requests from the

View File

@ -32,7 +32,7 @@ impl BlockAllocation {
self.required_blocks,
),
};
let remaining_blocks = required_blocks - self.allocated_blocks.len();
let remaining_blocks = required_blocks.saturating_sub(self.allocated_blocks.len());
let new_blocks = min(remaining_blocks, 16);
// Try to allocate all remaining blocks

View File

@ -314,6 +314,9 @@ async fn decode(
// Filter and send finished generations
let mut filtered_stream_responses = filter_send_ended_generations(generations, entries);
tracing::info!("filtered_stream: {:?}", start_filtering_time.elapsed());
// Send `StreamResponseInfer::Intermediate` messages for entries that don't need to be
// re-allocated.
// Allocate new blocks for entries that go over their allocation
@ -321,17 +324,21 @@ async fn decode(
let (force_update, terminated_entries) =
filter_send_update_allocations(entries, &mut filtered_stream_responses);
tracing::info!("filtered_update: {:?}", start_filtering_time.elapsed());
let next_batch = match next_batch {
// Run only on re-allocation or if entries were filtered
Some(batch) if batch.size as usize != entries.len() || force_update => {
// Filter next batch: remove requests that were stopped and update blocks/slots
let (filtered_batch, terminated_generations) =
filter_batch(client, batch, entries, &terminated_entries).await;
tracing::info!("filter_batch: {:?}", start_filtering_time.elapsed());
send_terminated_generations(
terminated_generations,
terminated_entries,
filtered_stream_responses,
);
tracing::info!("send_terminated: {:?}", start_filtering_time.elapsed());
filtered_batch
}
@ -379,23 +386,49 @@ async fn filter_batch(
client.clear_cache(Some(id)).await.unwrap();
Default::default()
} else {
// Collect new blocks/slots
let max_blocks = entries
.iter()
.map(|(_, entry)| {
entry
.block_allocation
.as_ref()
.map(|alloc| alloc.blocks().len())
})
.max()
.flatten();
let start_time = Instant::now();
// Collect new blocks
let updated_requests = entries
.iter()
.map(|(request_id, entry)| {
let blocks = entry
let (blocks, padded_blocks) = entry
.block_allocation
.as_ref()
.map(|alloc| alloc.blocks().to_vec())
.map(|alloc| {
let max_blocks = match max_blocks {
Some(max_blocks) => max_blocks,
_ => unreachable!(),
};
let blocks = alloc.blocks().to_vec();
let mut padded_blocks = blocks.clone();
padded_blocks.resize(max_blocks, 0);
(blocks, padded_blocks)
})
.unwrap_or_default();
KeptRequest {
id: *request_id,
blocks,
padded_blocks,
}
})
.collect();
tracing::info!("Collect blocks: {:?}", start_time.elapsed());
// Filter Python shard cache
// We unwrap here as we need to panic since we cannot recover if this method fails
client

View File

@ -1,5 +1,5 @@
[toolchain]
# Released on: 02 May, 2024
# https://releases.rs/docs/1.78.0/
channel = "1.78.0"
# Released on: 13 June, 2024
# https://releases.rs/docs/1.79.0/
channel = "1.79.0"
components = ["rustfmt", "clippy"]

View File

@ -403,6 +403,8 @@ class FlashCausalLMBatch(Batch):
kept_requests: List[generate_pb2.KeptRequest],
terminated_request_ids: List[int],
) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
start = time.time_ns()
terminated_generations = []
for request_id in terminated_request_ids:
idx = self.requests_idx_mapping[request_id]
@ -429,6 +431,11 @@ class FlashCausalLMBatch(Batch):
),
)
)
from loguru import logger
logger.info(f"terminated generations {(time.time_ns() - start)/1e6}")
if not kept_requests:
return None, terminated_generations
@ -445,7 +452,7 @@ class FlashCausalLMBatch(Batch):
requests = []
flat_blocks = []
block_tables = []
padded_blocks = []
all_input_ids = []
input_lengths = []
@ -483,8 +490,8 @@ class FlashCausalLMBatch(Batch):
top_n_tokens.append(self.top_n_tokens[idx])
request_block_table = request.blocks
block_tables.append(request_block_table)
flat_blocks.extend(request_block_table)
padded_blocks.extend(request.padded_blocks)
# Index
slot_indices.append((num_blocks * BLOCK_SIZE) + request_input_length - 1)
@ -492,6 +499,8 @@ class FlashCausalLMBatch(Batch):
num_blocks += len(request_block_table)
max_blocks = max(max_blocks, len(request_block_table))
logger.info(f"for loop requests: {(time.time_ns() - start)/1e6}")
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
@ -503,12 +512,14 @@ class FlashCausalLMBatch(Batch):
self.speculative_ids[indices] if self.speculative_ids is not None else None
)
# Create block_tables_tensor on CPU
block_tables_tensor = torch.zeros(
(len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
)
for i, request_blocks in enumerate(block_tables):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
logger.info(f"slice objects: {(time.time_ns() - start)/1e6}")
# Create block_tables_tensor on GPU
block_tables_tensor = torch.tensor(
padded_blocks, dtype=torch.int32, device=device
).view(len(requests), -1)
logger.info(f"allocate block table: {(time.time_ns() - start)/1e6}")
# Allocate on GPU
slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)
@ -522,6 +533,8 @@ class FlashCausalLMBatch(Batch):
+ torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64)
).flatten()
logger.info(f"done allocation: {(time.time_ns() - start)/1e6}")
filtered_batch = type(self)(
batch_id=self.batch_id,
requests=requests,