added padded blocks and logs everywhere

commit 7ed1044585 (parent abe521204e)
@@ -497,7 +497,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
                     "Lowest: {:.2} {unit}",
                     data.iter()
                         .min_by(|a, b| a.total_cmp(b))
-                        .unwrap_or(&std::f64::NAN)
+                        .unwrap_or(&f64::NAN)
                 ),
                 Style::default().fg(Color::Reset),
             )]),
@@ -506,7 +506,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
                     "Highest: {:.2} {unit}",
                     data.iter()
                         .max_by(|a, b| a.total_cmp(b))
-                        .unwrap_or(&std::f64::NAN)
+                        .unwrap_or(&f64::NAN)
                 ),
                 Style::default().fg(Color::Reset),
             )]),
@@ -555,17 +555,17 @@ fn latency_throughput_chart<'a>(
     let min_latency: f64 = *latency_iter
         .clone()
         .min_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let max_latency: f64 = *latency_iter
         .max_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let min_throughput: f64 = *throughput_iter
         .clone()
         .min_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let max_throughput: f64 = *throughput_iter
         .max_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);

     // Char min max values
     let min_x = if zoom {
@@ -156,17 +156,17 @@ fn avg_min_max(data: &[f64]) -> (f64, f64, f64) {
     let min = data
         .iter()
         .min_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let max = data
         .iter()
         .max_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     (average, *min, *max)
 }

 fn px(data: &[f64], p: u32) -> f64 {
     let i = (f64::from(p) / 100.0 * data.len() as f64) as usize;
-    *data.get(i).unwrap_or(&std::f64::NAN)
+    *data.get(i).unwrap_or(&f64::NAN)
 }

 fn format_value(value: f64, unit: &'static str) -> String {
@@ -37,7 +37,7 @@ pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap<String, f
         .iter()
         .map(|&p| {
             let i = (f64::from(p) / 100.0 * values.len() as f64) as usize;
-            (format!("p{p}"), *values.get(i).unwrap_or(&std::f64::NAN))
+            (format!("p{p}"), *values.get(i).unwrap_or(&f64::NAN))
         })
         .collect()
 }

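A note on the change above: `f64::NAN` is the associated constant stabilized in Rust 1.43; `std::f64::NAN` is the older module-level constant, which newer clippy releases lint against (likely why this cleanup lands alongside the toolchain bump below). `total_cmp` is used because it gives a total order over floats, so `min_by`/`max_by` behave sensibly even when NaN is present. A minimal standalone sketch:

    fn main() {
        let data = [3.0_f64, 1.0, f64::NAN, 2.0];

        // In total_cmp's total order, (positive) NaN sorts above every
        // other value, so min_by still finds the smallest real number.
        let min = *data
            .iter()
            .min_by(|a, b| a.total_cmp(b))
            .unwrap_or(&f64::NAN);
        assert_eq!(min, 1.0);
    }
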
@@ -206,6 +206,8 @@ message KeptRequest {
     uint64 id = 1;
     /// Paged attention blocks
     repeated uint32 blocks = 2;
+    /// Paged attention blocks padded to max blocks for this batch
+    repeated uint32 padded_blocks = 3;
 }

 /// kept_requests + terminated_request_ids might not cover all requests from the

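The Rust code later in this commit constructs `KeptRequest { id, blocks, padded_blocks }`, so the generated struct exposes the new field as another `Vec<u32>`. A hand-written sketch (derives and prost attributes omitted; the `main` below is a hypothetical check, not repository code) of the shape and the padding invariant the comment describes:

    pub struct KeptRequest {
        pub id: u64,
        /// Paged attention blocks
        pub blocks: Vec<u32>,
        /// Paged attention blocks padded to max blocks for this batch
        pub padded_blocks: Vec<u32>,
    }

    fn main() {
        let requests = vec![
            KeptRequest { id: 0, blocks: vec![1, 2, 3], padded_blocks: vec![1, 2, 3] },
            KeptRequest { id: 1, blocks: vec![7], padded_blocks: vec![7, 0, 0] },
        ];
        let max = requests.iter().map(|r| r.blocks.len()).max().unwrap_or(0);
        for r in &requests {
            // Every request is padded to the batch-wide max...
            assert_eq!(r.padded_blocks.len(), max);
            // ...and the real blocks form a prefix of the padded list.
            assert_eq!(&r.padded_blocks[..r.blocks.len()], r.blocks.as_slice());
        }
    }
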
@@ -32,7 +32,7 @@ impl BlockAllocation {
                 self.required_blocks,
             ),
         };
-        let remaining_blocks = required_blocks - self.allocated_blocks.len();
+        let remaining_blocks = required_blocks.saturating_sub(self.allocated_blocks.len());
         let new_blocks = min(remaining_blocks, 16);

         // Try to allocate all remaining blocks

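Context for the `saturating_sub` change: both operands are unsigned, so a plain `-` panics on underflow in debug builds (and wraps around in release). `saturating_sub` clamps at zero for the case where the allocation already covers the requirement. A minimal sketch:

    fn main() {
        let required_blocks: usize = 3;
        let allocated: usize = 5;

        // `required_blocks - allocated` would panic here in a debug build
        // ("attempt to subtract with overflow"); saturating_sub returns 0.
        let remaining_blocks = required_blocks.saturating_sub(allocated);
        assert_eq!(remaining_blocks, 0);
    }
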
@@ -314,6 +314,9 @@ async fn decode(

     // Filter and send finished generations
     let mut filtered_stream_responses = filter_send_ended_generations(generations, entries);
+
+    tracing::info!("filtered_stream: {:?}", start_filtering_time.elapsed());
+
     // Send `StreamResponseInfer::Intermediate` messages for entries that don't need to be
     // re-allocated,
     // Allocated new blocks for entries that go over their allocation
@@ -321,17 +324,21 @@ async fn decode(
     let (force_update, terminated_entries) =
         filter_send_update_allocations(entries, &mut filtered_stream_responses);

+    tracing::info!("filtered_update: {:?}", start_filtering_time.elapsed());
+
     let next_batch = match next_batch {
         // Run Only on re-allocation or if entries were filtered
         Some(batch) if batch.size as usize != entries.len() || force_update => {
             // Filter next batch: remove requests that were stopped and update blocks/slots
             let (filtered_batch, terminated_generations) =
                 filter_batch(client, batch, entries, &terminated_entries).await;
+            tracing::info!("filter_batch: {:?}", start_filtering_time.elapsed());
             send_terminated_generations(
                 terminated_generations,
                 terminated_entries,
                 filtered_stream_responses,
             );
+            tracing::info!("send_terminated: {:?}", start_filtering_time.elapsed());

             filtered_batch
         }

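All of the added logs share one clock: `start_filtering_time` is a single `std::time::Instant`, so every `elapsed()` reading is cumulative from the start of filtering rather than a per-step delta. A standalone sketch of the pattern (assumes the `tracing` and `tracing-subscriber` crates; the sleeps stand in for real work):

    use std::time::{Duration, Instant};

    fn main() {
        // Minimal subscriber so the info! output is visible.
        tracing_subscriber::fmt::init();

        let start_filtering_time = Instant::now();

        std::thread::sleep(Duration::from_millis(5)); // stand-in for stage 1
        tracing::info!("filtered_stream: {:?}", start_filtering_time.elapsed());

        std::thread::sleep(Duration::from_millis(5)); // stand-in for stage 2
        // Same Instant, so this reading includes stage 1 (~10ms, not ~5ms).
        tracing::info!("filtered_update: {:?}", start_filtering_time.elapsed());
    }
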
@@ -379,23 +386,49 @@ async fn filter_batch(
         client.clear_cache(Some(id)).await.unwrap();
         Default::default()
     } else {
         // Collect new blocks/slots
+        let max_blocks = entries
+            .iter()
+            .map(|(_, entry)| {
+                entry
+                    .block_allocation
+                    .as_ref()
+                    .map(|alloc| alloc.blocks().len())
+            })
+            .max()
+            .flatten();
+
+        let start_time = Instant::now();
+
+        // Collect new blocks
         let updated_requests = entries
             .iter()
             .map(|(request_id, entry)| {
-                let blocks = entry
+                let (blocks, padded_blocks) = entry
                     .block_allocation
                     .as_ref()
-                    .map(|alloc| alloc.blocks().to_vec())
+                    .map(|alloc| {
+                        let max_blocks = match max_blocks {
+                            Some(max_blocks) => max_blocks,
+                            _ => unreachable!(),
+                        };
+
+                        let blocks = alloc.blocks().to_vec();
+                        let mut padded_blocks = blocks.clone();
+                        padded_blocks.resize(max_blocks - padded_blocks.len(), 0);
+                        (blocks, padded_blocks)
+                    })
                     .unwrap_or_default();

                 KeptRequest {
                     id: *request_id,
                     blocks,
+                    padded_blocks,
                 }
             })
             .collect();

+        tracing::info!("Collect blocks: {:?}", start_time.elapsed());
+
         // Filter Python shard cache
         // We unwrap here as we need to panic since we cannot recover if this method fails
         client

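One thing to keep in mind when reading the padding above: `Vec::resize(new_len, value)` takes the final target length as its first argument, not the number of elements to append. A minimal sketch of padding per-request block tables to a batch-wide maximum with that signature (hypothetical data):

    fn main() {
        let block_tables: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4]];
        let max_blocks = block_tables.iter().map(Vec::len).max().unwrap_or(0);

        let padded: Vec<Vec<u32>> = block_tables
            .into_iter()
            .map(|mut blocks| {
                // resize takes the final length; new slots are filled with 0.
                blocks.resize(max_blocks, 0);
                blocks
            })
            .collect();

        assert_eq!(padded, vec![vec![1, 2, 3], vec![4, 0, 0]]);
    }
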
@@ -1,5 +1,5 @@
 [toolchain]
-# Released on: 02 May, 2024
-# https://releases.rs/docs/1.78.0/
-channel = "1.78.0"
+# Released on: 13 June, 2024
+# https://releases.rs/docs/1.79.0/
+channel = "1.79.0"
 components = ["rustfmt", "clippy"]

@@ -403,6 +403,8 @@ class FlashCausalLMBatch(Batch):
         kept_requests: List[generate_pb2.KeptRequest],
         terminated_request_ids: List[int],
     ) -> Tuple[Optional["FlashCausalLMBatch"], List[generate_pb2.TerminatedGeneration]]:
+        start = time.time_ns()
+
         terminated_generations = []
         for request_id in terminated_request_ids:
             idx = self.requests_idx_mapping[request_id]
@@ -429,6 +431,11 @@ class FlashCausalLMBatch(Batch):
                 ),
             )
         )
+
+        from loguru import logger
+
+        logger.info(f"terminated generations {(time.time_ns() - start)/1e6}")
+
         if not kept_requests:
             return None, terminated_generations

@@ -445,7 +452,7 @@ class FlashCausalLMBatch(Batch):

         requests = []
-        flat_blocks = []
         block_tables = []
+        padded_blocks = []
         all_input_ids = []

         input_lengths = []
@@ -483,8 +490,8 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens.append(self.top_n_tokens[idx])

             request_block_table = request.blocks
             block_tables.append(request_block_table)
-            flat_blocks.extend(request_block_table)
+            padded_blocks.extend(request.padded_blocks)

             # Index
             slot_indices.append((num_blocks * BLOCK_SIZE) + request_input_length - 1)
@@ -492,6 +499,8 @@ class FlashCausalLMBatch(Batch):
             num_blocks += len(request_block_table)
             max_blocks = max(max_blocks, len(request_block_table))

+        logger.info(f"for loop requests: {(time.time_ns() - start)/1e6}")
+
         # Index into tensors
         input_ids = self.input_ids[indices]
         position_ids = self.position_ids[indices]
@@ -503,12 +512,14 @@ class FlashCausalLMBatch(Batch):
             self.speculative_ids[indices] if self.speculative_ids is not None else None
         )

-        # Create block_tables_tensor on CPU
-        block_tables_tensor = torch.zeros(
-            (len(block_tables), max_blocks), dtype=torch.int32, device="cpu"
-        )
-        for i, request_blocks in enumerate(block_tables):
-            block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
+        logger.info(f"slice objects: {(time.time_ns() - start)/1e6}")
+
+        # Create block_tables_tensor on GPU
+        block_tables_tensor = torch.tensor(
+            padded_blocks, dtype=torch.int32, device=device
+        ).view(len(requests), -1)
+
+        logger.info(f"allocate block table: {(time.time_ns() - start)/1e6}")

         # Allocate on GPU
         slot_indices = torch.tensor(slot_indices, dtype=torch.int64, device=device)

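The GPU path above only works because every request now contributes exactly the batch-wide `max_blocks` entries: the flat `padded_blocks` list is then a valid row-major matrix, which `.view(len(requests), -1)` can reinterpret without copying. A small Rust sketch of the same row-major layout, with hypothetical values:

    fn main() {
        // Two requests, each padded to max_blocks = 3.
        let padded_blocks: Vec<u32> = vec![1, 2, 3, 4, 0, 0];
        let (num_requests, max_blocks) = (2usize, 3usize);
        assert_eq!(padded_blocks.len(), num_requests * max_blocks);

        // Row-major indexing: request `r`, block slot `c`.
        let block = |r: usize, c: usize| padded_blocks[r * max_blocks + c];
        assert_eq!(block(0, 1), 2);
        assert_eq!(block(1, 0), 4);
    }
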
@@ -522,6 +533,8 @@ class FlashCausalLMBatch(Batch):
             + torch.arange(0, BLOCK_SIZE, device=device, dtype=torch.int64)
         ).flatten()

+        logger.info(f"done allocation: {(time.time_ns() - start)/1e6}")
+
         filtered_batch = type(self)(
             batch_id=self.batch_id,
             requests=requests,