diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index dde429a5..4a119e86 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -74,16 +74,28 @@ async fn generate_runs( for b in batch_size { // Warmups on batch size for _ in 0..warmups { - let (_, decode_batch) = - prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?; + let (_, decode_batch) = prefill( + sequence.clone(), + sequence_length, + b, + decode_length, + &mut client, + ) + .await?; let _ = decode(decode_batch, &mut client).await?; // Send warmup message run_sender.send(Ok(Message::Warmup)).await.unwrap_or(()); } for _ in 0..n_runs { - let (prefill, decode_batch) = - prefill(sequence.clone(), sequence_length, b, decode_length, &mut client).await?; + let (prefill, decode_batch) = prefill( + sequence.clone(), + sequence_length, + b, + decode_length, + &mut client, + ) + .await?; // Send prefill message run_sender .send(Ok(Message::Prefill(prefill))) @@ -143,6 +155,7 @@ async fn prefill( id: 0, requests, size: batch_size, + max_tokens: batch_size * (sequence_length + decode_length), }; // Run prefill