From 7de8a377b067af5d9133874b88f5b0a37452a5eb Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Wed, 26 Apr 2023 00:54:27 +0200 Subject: [PATCH] fix(benchmarking): fix benchmarking tool --- benchmark/src/generation.rs | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index dde429a5..4a119e86 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -74,16 +74,28 @@ async fn generate_runs( for b in batch_size { // Warmups on batch size for _ in 0..warmups { - let (_, decode_batch) = - prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?; + let (_, decode_batch) = prefill( + sequence.clone(), + sequence_length, + b, + decode_length, + &mut client, + ) + .await?; let _ = decode(decode_batch, &mut client).await?; // Send warmup message run_sender.send(Ok(Message::Warmup)).await.unwrap_or(()); } for _ in 0..n_runs { - let (prefill, decode_batch) = - prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?; + let (prefill, decode_batch) = prefill( + sequence.clone(), + sequence_length, + b, + decode_length, + &mut client, + ) + .await?; // Send prefill message run_sender .send(Ok(Message::Prefill(prefill))) @@ -143,6 +155,7 @@ async fn prefill( id: 0, requests, size: batch_size, + max_tokens: batch_size * (sequence_length + decode_length), }; // Run prefill