feat(benchmark): tui based benchmarking tool (#149)
This commit is contained in:
parent
55106ec476
commit
610bb1f978
|
@ -5,6 +5,9 @@ members = [
|
||||||
"router/grpc-metadata",
|
"router/grpc-metadata",
|
||||||
"launcher"
|
"launcher"
|
||||||
]
|
]
|
||||||
|
exclude = [
|
||||||
|
"benchmark"
|
||||||
|
]
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
debug = 1
|
debug = 1
|
||||||
|
|
3
Makefile
3
Makefile
|
@ -7,6 +7,9 @@ install-router:
|
||||||
install-launcher:
|
install-launcher:
|
||||||
cd launcher && cargo install --path .
|
cd launcher && cargo install --path .
|
||||||
|
|
||||||
|
install-benchmark:
|
||||||
|
cd benchmark && cargo install --path .
|
||||||
|
|
||||||
install: install-server install-router install-launcher
|
install: install-server install-router install-launcher
|
||||||
|
|
||||||
server-dev:
|
server-dev:
|
||||||
|
|
Binary file not shown.
After Width: | Height: | Size: 102 KiB |
|
@ -0,0 +1 @@
|
||||||
|
target
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,35 @@
|
||||||
|
[package]
|
||||||
|
name = "text-generation-benchmark"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
authors = ["Olivier Dehaene"]
|
||||||
|
description = "Text Generation Benchmarking tool"
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
debug = 1
|
||||||
|
incremental = true
|
||||||
|
lto = "off"
|
||||||
|
panic = "abort"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "src/lib.rs"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "text-generation-benchmark"
|
||||||
|
path = "src/main.rs"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
average = "0.13"
|
||||||
|
clap = { version = "4.1.4", features = ["derive", "env"] }
|
||||||
|
crossterm = "0.26"
|
||||||
|
float-ord = "0.3.2"
|
||||||
|
serde = {version = "1.0.142", features = ["derive"]}
|
||||||
|
serde_json = "1.0"
|
||||||
|
text-generation-client = { path = "../router/client" }
|
||||||
|
thiserror = "1.0.38"
|
||||||
|
tokenizers = "0.13.2"
|
||||||
|
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
|
||||||
|
tui = {package = "ratatui", version = "0.20", default-features = false, features = ["crossterm"]}
|
||||||
|
tracing = "0.1.37"
|
||||||
|
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
|
# Text Generation Inference benchmarking tool
|
||||||
|
|
||||||
|
![benchmark](../assets/benchmark.png)
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
A lightweight benchmarking tool based inspired by [oha](https://github.com/hatoo/oha)
|
||||||
|
and powered by [tui](https://github.com/tui-rs-revival/ratatui).
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
```shell
|
||||||
|
make install-benchmark
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run
|
||||||
|
|
||||||
|
First, start `text-generation-inference`:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
text-generation-launcher --model-id bigscience/bloom-560m
|
||||||
|
```
|
||||||
|
|
||||||
|
Then run the benchmarking tool:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
text-generation-benchmark --tokenizer-name bigscience/bloom-560m
|
||||||
|
```
|
|
@ -0,0 +1,3 @@
|
||||||
|
[toolchain]
|
||||||
|
channel = "1.67.0"
|
||||||
|
components = ["rustfmt", "clippy"]
|
|
@ -0,0 +1,688 @@
|
||||||
|
/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs
|
||||||
|
use crate::generation::{Decode, Message, Prefill};
|
||||||
|
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||||
|
use text_generation_client::ClientError;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tui::backend::Backend;
|
||||||
|
use tui::layout::{Alignment, Constraint, Direction, Layout};
|
||||||
|
use tui::style::{Color, Modifier, Style};
|
||||||
|
use tui::text::{Span, Spans};
|
||||||
|
use tui::widgets::{
|
||||||
|
Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs,
|
||||||
|
};
|
||||||
|
use tui::{symbols, Frame};
|
||||||
|
|
||||||
|
/// TUI powered App
|
||||||
|
pub(crate) struct App {
|
||||||
|
pub(crate) running: bool,
|
||||||
|
completed_runs: Vec<usize>,
|
||||||
|
completed_batch: usize,
|
||||||
|
current_batch: usize,
|
||||||
|
current_tab: usize,
|
||||||
|
touched_tab: bool,
|
||||||
|
zoom: bool,
|
||||||
|
is_error: bool,
|
||||||
|
data: Data,
|
||||||
|
tokenizer_name: String,
|
||||||
|
sequence_length: u32,
|
||||||
|
decode_length: u32,
|
||||||
|
n_run: usize,
|
||||||
|
batch_size: Vec<u32>,
|
||||||
|
receiver: mpsc::Receiver<Result<Message, ClientError>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl App {
|
||||||
|
pub(crate) fn new(
|
||||||
|
receiver: mpsc::Receiver<Result<Message, ClientError>>,
|
||||||
|
tokenizer_name: String,
|
||||||
|
sequence_length: u32,
|
||||||
|
decode_length: u32,
|
||||||
|
n_run: usize,
|
||||||
|
batch_size: Vec<u32>,
|
||||||
|
) -> Self {
|
||||||
|
let data = Data::new(n_run, batch_size.len());
|
||||||
|
let current_tab = 0;
|
||||||
|
|
||||||
|
let completed_runs: Vec<usize> = (0..batch_size.len()).map(|_| 0).collect();
|
||||||
|
let completed_batch = 0;
|
||||||
|
let current_batch = 0;
|
||||||
|
let is_error = false;
|
||||||
|
|
||||||
|
Self {
|
||||||
|
running: true,
|
||||||
|
completed_runs,
|
||||||
|
completed_batch,
|
||||||
|
current_batch,
|
||||||
|
current_tab,
|
||||||
|
touched_tab: false,
|
||||||
|
zoom: false,
|
||||||
|
is_error,
|
||||||
|
data,
|
||||||
|
tokenizer_name,
|
||||||
|
sequence_length,
|
||||||
|
decode_length,
|
||||||
|
n_run,
|
||||||
|
batch_size,
|
||||||
|
receiver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle crossterm key events
|
||||||
|
pub(crate) fn handle_key_event(&mut self, key_event: KeyEvent) {
|
||||||
|
match key_event {
|
||||||
|
// Increase and wrap tab
|
||||||
|
KeyEvent {
|
||||||
|
code: KeyCode::Right,
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| KeyEvent {
|
||||||
|
code: KeyCode::Tab, ..
|
||||||
|
} => {
|
||||||
|
self.touched_tab = true;
|
||||||
|
self.current_tab = (self.current_tab + 1) % self.batch_size.len();
|
||||||
|
}
|
||||||
|
// Decrease and wrap tab
|
||||||
|
KeyEvent {
|
||||||
|
code: KeyCode::Left,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
self.touched_tab = true;
|
||||||
|
if self.current_tab > 0 {
|
||||||
|
self.current_tab -= 1;
|
||||||
|
} else {
|
||||||
|
self.current_tab = self.batch_size.len() - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Zoom on throughput/latency fig
|
||||||
|
KeyEvent {
|
||||||
|
code: KeyCode::Char('+'),
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
self.zoom = true;
|
||||||
|
}
|
||||||
|
// Unzoom on throughput/latency fig
|
||||||
|
KeyEvent {
|
||||||
|
code: KeyCode::Char('-'),
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
self.zoom = false;
|
||||||
|
}
|
||||||
|
// Quit
|
||||||
|
KeyEvent {
|
||||||
|
code: KeyCode::Char('q'),
|
||||||
|
..
|
||||||
|
}
|
||||||
|
| KeyEvent {
|
||||||
|
code: KeyCode::Char('c'),
|
||||||
|
modifiers: KeyModifiers::CONTROL,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
self.running = false;
|
||||||
|
}
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get all pending messages from generation task
|
||||||
|
pub(crate) fn tick(&mut self) {
|
||||||
|
while let Ok(message) = self.receiver.try_recv() {
|
||||||
|
match message {
|
||||||
|
Ok(message) => match message {
|
||||||
|
Message::Prefill(step) => self.data.push_prefill(step, self.current_batch),
|
||||||
|
Message::Decode(step) => self.data.push_decode(step, self.current_batch),
|
||||||
|
Message::EndRun => {
|
||||||
|
self.completed_runs[self.current_batch] += 1;
|
||||||
|
}
|
||||||
|
Message::EndBatch => {
|
||||||
|
self.data.end_batch(self.current_batch);
|
||||||
|
self.completed_batch += 1;
|
||||||
|
|
||||||
|
if self.current_batch < self.batch_size.len() - 1 {
|
||||||
|
// Only go to next tab if the user never touched the tab keys
|
||||||
|
if !self.touched_tab {
|
||||||
|
self.current_tab += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.current_batch += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Message::Warmup => {}
|
||||||
|
},
|
||||||
|
Err(_) => self.is_error = true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Render frame
|
||||||
|
pub fn render<B: Backend>(&mut self, f: &mut Frame<'_, B>) {
|
||||||
|
let batch_progress =
|
||||||
|
(self.completed_batch as f64 / self.batch_size.len() as f64).clamp(0.0, 1.0);
|
||||||
|
let run_progress =
|
||||||
|
(self.completed_runs[self.current_batch] as f64 / self.n_run as f64).clamp(0.0, 1.0);
|
||||||
|
|
||||||
|
// Vertical layout
|
||||||
|
let row5 = Layout::default()
|
||||||
|
.direction(Direction::Vertical)
|
||||||
|
.constraints(
|
||||||
|
[
|
||||||
|
Constraint::Length(1),
|
||||||
|
Constraint::Length(3),
|
||||||
|
Constraint::Length(3),
|
||||||
|
Constraint::Length(13),
|
||||||
|
Constraint::Min(10),
|
||||||
|
]
|
||||||
|
.as_ref(),
|
||||||
|
)
|
||||||
|
.split(f.size());
|
||||||
|
|
||||||
|
// Top row horizontal layout
|
||||||
|
let top = Layout::default()
|
||||||
|
.direction(Direction::Horizontal)
|
||||||
|
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
|
||||||
|
.split(row5[2]);
|
||||||
|
|
||||||
|
// Mid row horizontal layout
|
||||||
|
let mid = Layout::default()
|
||||||
|
.direction(Direction::Horizontal)
|
||||||
|
.constraints(
|
||||||
|
[
|
||||||
|
Constraint::Percentage(25),
|
||||||
|
Constraint::Percentage(25),
|
||||||
|
Constraint::Percentage(25),
|
||||||
|
Constraint::Percentage(25),
|
||||||
|
]
|
||||||
|
.as_ref(),
|
||||||
|
)
|
||||||
|
.split(row5[3]);
|
||||||
|
|
||||||
|
// Left mid row vertical layout
|
||||||
|
let prefill_text = Layout::default()
|
||||||
|
.direction(Direction::Vertical)
|
||||||
|
.constraints([Constraint::Length(8), Constraint::Length(5)].as_ref())
|
||||||
|
.split(mid[0]);
|
||||||
|
|
||||||
|
// Right mid row vertical layout
|
||||||
|
let decode_text = Layout::default()
|
||||||
|
.direction(Direction::Vertical)
|
||||||
|
.constraints([Constraint::Length(8), Constraint::Length(5)].as_ref())
|
||||||
|
.split(mid[2]);
|
||||||
|
let decode_text_latency = Layout::default()
|
||||||
|
.direction(Direction::Horizontal)
|
||||||
|
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
|
||||||
|
.split(decode_text[0]);
|
||||||
|
|
||||||
|
// Bottom row horizontal layout
|
||||||
|
let bottom = Layout::default()
|
||||||
|
.direction(Direction::Horizontal)
|
||||||
|
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
|
||||||
|
.split(row5[4]);
|
||||||
|
|
||||||
|
// Title
|
||||||
|
let title = Block::default()
|
||||||
|
.borders(Borders::NONE)
|
||||||
|
.title(format!(
|
||||||
|
"Model: {} | Sequence Length: {} | Decode Length: {}",
|
||||||
|
self.tokenizer_name, self.sequence_length, self.decode_length
|
||||||
|
))
|
||||||
|
.style(
|
||||||
|
Style::default()
|
||||||
|
.add_modifier(Modifier::BOLD)
|
||||||
|
.fg(Color::White),
|
||||||
|
);
|
||||||
|
f.render_widget(title, row5[0]);
|
||||||
|
|
||||||
|
// Helper
|
||||||
|
let helper = Block::default()
|
||||||
|
.borders(Borders::NONE)
|
||||||
|
.title("<- | tab | ->: change batch tab | q / CTRL + c: quit | +/-: zoom")
|
||||||
|
.title_alignment(Alignment::Right)
|
||||||
|
.style(Style::default().fg(Color::White));
|
||||||
|
f.render_widget(helper, row5[0]);
|
||||||
|
|
||||||
|
// Batch tabs
|
||||||
|
let titles = self
|
||||||
|
.batch_size
|
||||||
|
.iter()
|
||||||
|
.map(|b| {
|
||||||
|
Spans::from(vec![Span::styled(
|
||||||
|
format!("Batch: {b}"),
|
||||||
|
Style::default().fg(Color::White),
|
||||||
|
)])
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let tabs = Tabs::new(titles)
|
||||||
|
.block(Block::default().borders(Borders::ALL).title("Tabs"))
|
||||||
|
.select(self.current_tab)
|
||||||
|
.style(Style::default().fg(Color::LightCyan))
|
||||||
|
.highlight_style(
|
||||||
|
Style::default()
|
||||||
|
.add_modifier(Modifier::BOLD)
|
||||||
|
.bg(Color::Black),
|
||||||
|
);
|
||||||
|
f.render_widget(tabs, row5[1]);
|
||||||
|
|
||||||
|
// Total progress bar
|
||||||
|
let color = if self.is_error {
|
||||||
|
Color::Red
|
||||||
|
} else {
|
||||||
|
Color::LightGreen
|
||||||
|
};
|
||||||
|
let batch_gauge = progress_gauge(
|
||||||
|
"Total Progress",
|
||||||
|
format!("{} / {}", self.completed_batch, self.batch_size.len()),
|
||||||
|
batch_progress,
|
||||||
|
color,
|
||||||
|
);
|
||||||
|
f.render_widget(batch_gauge, top[0]);
|
||||||
|
|
||||||
|
// Batch progress Bar
|
||||||
|
let color = if self.is_error {
|
||||||
|
Color::Red
|
||||||
|
} else {
|
||||||
|
Color::LightBlue
|
||||||
|
};
|
||||||
|
let run_gauge = progress_gauge(
|
||||||
|
"Batch Progress",
|
||||||
|
format!(
|
||||||
|
"{} / {}",
|
||||||
|
self.completed_runs[self.current_batch], self.n_run
|
||||||
|
),
|
||||||
|
run_progress,
|
||||||
|
color,
|
||||||
|
);
|
||||||
|
f.render_widget(run_gauge, top[1]);
|
||||||
|
|
||||||
|
// Prefill text infos
|
||||||
|
let prefill_latency_block = latency_paragraph(
|
||||||
|
&mut self.data.prefill_latencies[self.current_tab],
|
||||||
|
"Prefill",
|
||||||
|
);
|
||||||
|
let prefill_throughput_block =
|
||||||
|
throughput_paragraph(&self.data.prefill_throughputs[self.current_tab], "Prefill");
|
||||||
|
|
||||||
|
f.render_widget(prefill_latency_block, prefill_text[0]);
|
||||||
|
f.render_widget(prefill_throughput_block, prefill_text[1]);
|
||||||
|
|
||||||
|
// Prefill latency histogram
|
||||||
|
let histo_width = 7;
|
||||||
|
let bins = if mid[1].width < 2 {
|
||||||
|
0
|
||||||
|
} else {
|
||||||
|
(mid[1].width as usize - 2) / (histo_width + 1)
|
||||||
|
}
|
||||||
|
.max(2);
|
||||||
|
|
||||||
|
let histo_data =
|
||||||
|
latency_histogram_data(&self.data.prefill_latencies[self.current_tab], bins);
|
||||||
|
let histo_data_str: Vec<(&str, u64)> =
|
||||||
|
histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect();
|
||||||
|
let prefill_histogram =
|
||||||
|
latency_histogram(&histo_data_str, "Prefill").bar_width(histo_width as u16);
|
||||||
|
f.render_widget(prefill_histogram, mid[1]);
|
||||||
|
|
||||||
|
// Decode text info
|
||||||
|
let decode_latency_block = latency_paragraph(
|
||||||
|
&mut self.data.decode_latencies[self.current_tab],
|
||||||
|
"Decode Total",
|
||||||
|
);
|
||||||
|
let decode_token_latency_block = latency_paragraph(
|
||||||
|
&mut self.data.decode_token_latencies[self.current_tab],
|
||||||
|
"Decode Token",
|
||||||
|
);
|
||||||
|
let decode_throughput_block =
|
||||||
|
throughput_paragraph(&self.data.decode_throughputs[self.current_tab], "Decode");
|
||||||
|
f.render_widget(decode_latency_block, decode_text_latency[0]);
|
||||||
|
f.render_widget(decode_token_latency_block, decode_text_latency[1]);
|
||||||
|
f.render_widget(decode_throughput_block, decode_text[1]);
|
||||||
|
|
||||||
|
// Decode latency histogram
|
||||||
|
let histo_data =
|
||||||
|
latency_histogram_data(&self.data.decode_latencies[self.current_tab], bins);
|
||||||
|
let histo_data_str: Vec<(&str, u64)> =
|
||||||
|
histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect();
|
||||||
|
let decode_histogram =
|
||||||
|
latency_histogram(&histo_data_str, "Decode").bar_width(histo_width as u16);
|
||||||
|
f.render_widget(decode_histogram, mid[3]);
|
||||||
|
|
||||||
|
// Prefill latency/throughput chart
|
||||||
|
let prefill_latency_throughput_chart = latency_throughput_chart(
|
||||||
|
&self.data.prefill_batch_latency_throughput,
|
||||||
|
&self.batch_size,
|
||||||
|
self.zoom,
|
||||||
|
"Prefill",
|
||||||
|
);
|
||||||
|
f.render_widget(prefill_latency_throughput_chart, bottom[0]);
|
||||||
|
|
||||||
|
// Decode latency/throughput chart
|
||||||
|
let decode_latency_throughput_chart = latency_throughput_chart(
|
||||||
|
&self.data.decode_batch_latency_throughput,
|
||||||
|
&self.batch_size,
|
||||||
|
self.zoom,
|
||||||
|
"Decode",
|
||||||
|
);
|
||||||
|
f.render_widget(decode_latency_throughput_chart, bottom[1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// App internal data struct
|
||||||
|
struct Data {
|
||||||
|
prefill_latencies: Vec<Vec<f64>>,
|
||||||
|
prefill_throughputs: Vec<Vec<f64>>,
|
||||||
|
decode_latencies: Vec<Vec<f64>>,
|
||||||
|
decode_token_latencies: Vec<Vec<f64>>,
|
||||||
|
decode_throughputs: Vec<Vec<f64>>,
|
||||||
|
prefill_batch_latency_throughput: Vec<(f64, f64)>,
|
||||||
|
decode_batch_latency_throughput: Vec<(f64, f64)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Data {
|
||||||
|
fn new(n_run: usize, n_batch: usize) -> Self {
|
||||||
|
let prefill_latencies: Vec<Vec<f64>> =
|
||||||
|
(0..n_batch).map(|_| Vec::with_capacity(n_run)).collect();
|
||||||
|
let prefill_throughputs: Vec<Vec<f64>> = prefill_latencies.clone();
|
||||||
|
|
||||||
|
let decode_latencies: Vec<Vec<f64>> = prefill_latencies.clone();
|
||||||
|
let decode_token_latencies: Vec<Vec<f64>> = decode_latencies.clone();
|
||||||
|
let decode_throughputs: Vec<Vec<f64>> = prefill_throughputs.clone();
|
||||||
|
|
||||||
|
let prefill_batch_latency_throughput: Vec<(f64, f64)> = Vec::with_capacity(n_batch);
|
||||||
|
let decode_batch_latency_throughput: Vec<(f64, f64)> =
|
||||||
|
prefill_batch_latency_throughput.clone();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
prefill_latencies,
|
||||||
|
prefill_throughputs,
|
||||||
|
decode_latencies,
|
||||||
|
decode_token_latencies,
|
||||||
|
decode_throughputs,
|
||||||
|
prefill_batch_latency_throughput,
|
||||||
|
decode_batch_latency_throughput,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn push_prefill(&mut self, prefill: Prefill, batch_idx: usize) {
|
||||||
|
let latency = prefill.latency.as_millis() as f64;
|
||||||
|
self.prefill_latencies[batch_idx].push(latency);
|
||||||
|
self.prefill_throughputs[batch_idx].push(prefill.throughput);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn push_decode(&mut self, decode: Decode, batch_idx: usize) {
|
||||||
|
let latency = decode.latency.as_millis() as f64;
|
||||||
|
let token_latency = decode.token_latency.as_millis() as f64;
|
||||||
|
self.decode_latencies[batch_idx].push(latency);
|
||||||
|
self.decode_token_latencies[batch_idx].push(token_latency);
|
||||||
|
self.decode_throughputs[batch_idx].push(decode.throughput);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn end_batch(&mut self, batch_idx: usize) {
|
||||||
|
self.prefill_batch_latency_throughput.push((
|
||||||
|
self.prefill_latencies[batch_idx].iter().sum::<f64>()
|
||||||
|
/ self.prefill_latencies[batch_idx].len() as f64,
|
||||||
|
self.prefill_throughputs[batch_idx].iter().sum::<f64>()
|
||||||
|
/ self.prefill_throughputs[batch_idx].len() as f64,
|
||||||
|
));
|
||||||
|
self.decode_batch_latency_throughput.push((
|
||||||
|
self.decode_latencies[batch_idx].iter().sum::<f64>()
|
||||||
|
/ self.decode_latencies[batch_idx].len() as f64,
|
||||||
|
self.decode_throughputs[batch_idx].iter().sum::<f64>()
|
||||||
|
/ self.decode_throughputs[batch_idx].len() as f64,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Progress bar
|
||||||
|
fn progress_gauge(title: &str, label: String, progress: f64, color: Color) -> Gauge {
|
||||||
|
Gauge::default()
|
||||||
|
.block(Block::default().title(title).borders(Borders::ALL))
|
||||||
|
.gauge_style(Style::default().fg(color))
|
||||||
|
.label(Span::raw(label))
|
||||||
|
.ratio(progress)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Throughput paragraph
|
||||||
|
fn throughput_paragraph<'a>(throughput: &Vec<f64>, name: &'static str) -> Paragraph<'a> {
|
||||||
|
// Throughput average/high/low texts
|
||||||
|
let throughput_texts = statis_spans(throughput, "tokens/secs");
|
||||||
|
|
||||||
|
// Throughput block
|
||||||
|
Paragraph::new(throughput_texts).block(
|
||||||
|
Block::default()
|
||||||
|
.title(Span::raw(format!("{name} Throughput")))
|
||||||
|
.borders(Borders::ALL),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Latency paragraph
|
||||||
|
fn latency_paragraph<'a>(latency: &mut Vec<f64>, name: &'static str) -> Paragraph<'a> {
|
||||||
|
// Latency average/high/low texts
|
||||||
|
let mut latency_texts = statis_spans(latency, "ms");
|
||||||
|
|
||||||
|
// Sort latency for percentiles
|
||||||
|
float_ord::sort(latency);
|
||||||
|
let latency_percentiles = crate::utils::percentiles(latency, &[50, 90, 99]);
|
||||||
|
|
||||||
|
// Latency p50/p90/p99 texts
|
||||||
|
let colors = vec![Color::LightGreen, Color::LightYellow, Color::LightRed];
|
||||||
|
for (i, (name, value)) in latency_percentiles.iter().enumerate() {
|
||||||
|
let span = Spans::from(vec![Span::styled(
|
||||||
|
format!("{name}: {value:.2} ms"),
|
||||||
|
Style::default().fg(colors[i]),
|
||||||
|
)]);
|
||||||
|
latency_texts.push(span);
|
||||||
|
}
|
||||||
|
|
||||||
|
Paragraph::new(latency_texts).block(
|
||||||
|
Block::default()
|
||||||
|
.title(Span::raw(format!("{name} Latency")))
|
||||||
|
.borders(Borders::ALL),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Average/High/Low spans
|
||||||
|
fn statis_spans<'a>(data: &Vec<f64>, unit: &'static str) -> Vec<Spans<'a>> {
|
||||||
|
vec![
|
||||||
|
Spans::from(vec![Span::styled(
|
||||||
|
format!(
|
||||||
|
"Average: {:.2} {unit}",
|
||||||
|
data.iter().sum::<f64>() / data.len() as f64
|
||||||
|
),
|
||||||
|
Style::default().fg(Color::LightBlue),
|
||||||
|
)]),
|
||||||
|
Spans::from(vec![Span::styled(
|
||||||
|
format!(
|
||||||
|
"Lowest: {:.2} {unit}",
|
||||||
|
data.iter()
|
||||||
|
.min_by(|a, b| a.total_cmp(b))
|
||||||
|
.unwrap_or(&std::f64::NAN)
|
||||||
|
),
|
||||||
|
Style::default().fg(Color::Reset),
|
||||||
|
)]),
|
||||||
|
Spans::from(vec![Span::styled(
|
||||||
|
format!(
|
||||||
|
"Highest: {:.2} {unit}",
|
||||||
|
data.iter()
|
||||||
|
.max_by(|a, b| a.total_cmp(b))
|
||||||
|
.unwrap_or(&std::f64::NAN)
|
||||||
|
),
|
||||||
|
Style::default().fg(Color::Reset),
|
||||||
|
)]),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Latency histogram data
|
||||||
|
fn latency_histogram_data(latency: &[f64], bins: usize) -> Vec<(String, u64)> {
|
||||||
|
let histo_data: Vec<(String, u64)> = {
|
||||||
|
let histo = crate::utils::histogram(latency, bins);
|
||||||
|
histo
|
||||||
|
.into_iter()
|
||||||
|
.map(|(label, v)| (format!("{label:.2}"), v as u64))
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
histo_data
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Latency Histogram
|
||||||
|
fn latency_histogram<'a>(
|
||||||
|
histo_data_str: &'a Vec<(&'a str, u64)>,
|
||||||
|
name: &'static str,
|
||||||
|
) -> BarChart<'a> {
|
||||||
|
BarChart::default()
|
||||||
|
.block(
|
||||||
|
Block::default()
|
||||||
|
.title(format!("{name} latency histogram"))
|
||||||
|
.style(Style::default().fg(Color::LightYellow).bg(Color::Reset))
|
||||||
|
.borders(Borders::ALL),
|
||||||
|
)
|
||||||
|
.data(histo_data_str.as_slice())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Latency/Throughput chart
|
||||||
|
fn latency_throughput_chart<'a>(
|
||||||
|
latency_throughput: &'a Vec<(f64, f64)>,
|
||||||
|
batch_sizes: &'a [u32],
|
||||||
|
zoom: bool,
|
||||||
|
name: &'static str,
|
||||||
|
) -> Chart<'a> {
|
||||||
|
let latency_iter = latency_throughput.iter().map(|(l, _)| l);
|
||||||
|
let throughput_iter = latency_throughput.iter().map(|(_, t)| t);
|
||||||
|
|
||||||
|
// Get extreme values
|
||||||
|
let min_latency: f64 = *latency_iter
|
||||||
|
.clone()
|
||||||
|
.min_by(|a, b| a.total_cmp(b))
|
||||||
|
.unwrap_or(&std::f64::NAN);
|
||||||
|
let max_latency: f64 = *latency_iter
|
||||||
|
.max_by(|a, b| a.total_cmp(b))
|
||||||
|
.unwrap_or(&std::f64::NAN);
|
||||||
|
let min_throughput: f64 = *throughput_iter
|
||||||
|
.clone()
|
||||||
|
.min_by(|a, b| a.total_cmp(b))
|
||||||
|
.unwrap_or(&std::f64::NAN);
|
||||||
|
let max_throughput: f64 = *throughput_iter
|
||||||
|
.max_by(|a, b| a.total_cmp(b))
|
||||||
|
.unwrap_or(&std::f64::NAN);
|
||||||
|
|
||||||
|
// Char min max values
|
||||||
|
let min_x = if zoom {
|
||||||
|
((min_latency - 0.05 * min_latency) / 100.0).floor() * 100.0
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
let max_x = ((max_latency + 0.05 * max_latency) / 100.0).ceil() * 100.0;
|
||||||
|
let step_x = (max_x - min_x) / 4.0;
|
||||||
|
|
||||||
|
// Chart min max values
|
||||||
|
let min_y = if zoom {
|
||||||
|
((min_throughput - 0.05 * min_throughput) / 100.0).floor() * 100.0
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
let max_y = ((max_throughput + 0.05 * max_throughput) / 100.0).ceil() * 100.0;
|
||||||
|
let step_y = (max_y - min_y) / 4.0;
|
||||||
|
|
||||||
|
// Labels
|
||||||
|
let mut x_labels = vec![Span::styled(
|
||||||
|
format!("{min_x:.2}"),
|
||||||
|
Style::default()
|
||||||
|
.add_modifier(Modifier::BOLD)
|
||||||
|
.fg(Color::Gray)
|
||||||
|
.bg(Color::Reset),
|
||||||
|
)];
|
||||||
|
for i in 0..3 {
|
||||||
|
x_labels.push(Span::styled(
|
||||||
|
format!("{:.2}", min_x + ((i + 1) as f64 * step_x)),
|
||||||
|
Style::default().fg(Color::Gray).bg(Color::Reset),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
x_labels.push(Span::styled(
|
||||||
|
format!("{max_x:.2}"),
|
||||||
|
Style::default()
|
||||||
|
.add_modifier(Modifier::BOLD)
|
||||||
|
.fg(Color::Gray)
|
||||||
|
.bg(Color::Reset),
|
||||||
|
));
|
||||||
|
|
||||||
|
// Labels
|
||||||
|
let mut y_labels = vec![Span::styled(
|
||||||
|
format!("{min_y:.2}"),
|
||||||
|
Style::default()
|
||||||
|
.add_modifier(Modifier::BOLD)
|
||||||
|
.fg(Color::Gray)
|
||||||
|
.bg(Color::Reset),
|
||||||
|
)];
|
||||||
|
for i in 0..3 {
|
||||||
|
y_labels.push(Span::styled(
|
||||||
|
format!("{:.2}", min_y + ((i + 1) as f64 * step_y)),
|
||||||
|
Style::default().fg(Color::Gray).bg(Color::Reset),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
y_labels.push(Span::styled(
|
||||||
|
format!("{max_y:.2}"),
|
||||||
|
Style::default()
|
||||||
|
.add_modifier(Modifier::BOLD)
|
||||||
|
.fg(Color::Gray)
|
||||||
|
.bg(Color::Reset),
|
||||||
|
));
|
||||||
|
|
||||||
|
// Chart dataset
|
||||||
|
let colors = color_vec();
|
||||||
|
let datasets: Vec<Dataset> = (0..latency_throughput.len())
|
||||||
|
.map(|i| {
|
||||||
|
let color_idx = i % colors.len();
|
||||||
|
|
||||||
|
Dataset::default()
|
||||||
|
.name(batch_sizes[i].to_string())
|
||||||
|
.marker(symbols::Marker::Block)
|
||||||
|
.style(Style::default().fg(colors[color_idx]))
|
||||||
|
.graph_type(GraphType::Scatter)
|
||||||
|
.data(&latency_throughput[i..(i + 1)])
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Chart
|
||||||
|
Chart::new(datasets)
|
||||||
|
.style(Style::default().fg(Color::Cyan).bg(Color::Reset))
|
||||||
|
.block(
|
||||||
|
Block::default()
|
||||||
|
.title(Span::styled(
|
||||||
|
format!("{name} throughput over latency"),
|
||||||
|
Style::default().fg(Color::Gray).bg(Color::Reset),
|
||||||
|
))
|
||||||
|
.borders(Borders::ALL),
|
||||||
|
)
|
||||||
|
.x_axis(
|
||||||
|
Axis::default()
|
||||||
|
.title("ms")
|
||||||
|
.style(Style::default().fg(Color::Gray).bg(Color::Reset))
|
||||||
|
.labels(x_labels)
|
||||||
|
.bounds([min_x, max_x]),
|
||||||
|
)
|
||||||
|
.y_axis(
|
||||||
|
Axis::default()
|
||||||
|
.title("tokens/secs")
|
||||||
|
.style(Style::default().fg(Color::Gray).bg(Color::Reset))
|
||||||
|
.labels(y_labels)
|
||||||
|
.bounds([min_y, max_y]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Colors for latency/throughput chart
|
||||||
|
fn color_vec() -> Vec<Color> {
|
||||||
|
vec![
|
||||||
|
Color::Red,
|
||||||
|
Color::Green,
|
||||||
|
Color::Yellow,
|
||||||
|
Color::Blue,
|
||||||
|
Color::Magenta,
|
||||||
|
Color::Cyan,
|
||||||
|
Color::Gray,
|
||||||
|
Color::DarkGray,
|
||||||
|
Color::LightRed,
|
||||||
|
Color::LightGreen,
|
||||||
|
Color::LightYellow,
|
||||||
|
Color::LightBlue,
|
||||||
|
Color::LightMagenta,
|
||||||
|
Color::LightCyan,
|
||||||
|
]
|
||||||
|
}
|
|
@ -0,0 +1,65 @@
|
||||||
|
/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs
|
||||||
|
use crossterm::event;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::sync::{broadcast, mpsc};
|
||||||
|
|
||||||
|
/// Events
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) enum Event {
|
||||||
|
/// Terminal tick.
|
||||||
|
Tick,
|
||||||
|
/// Key press.
|
||||||
|
Key(event::KeyEvent),
|
||||||
|
/// Terminal resize.
|
||||||
|
Resize(u16, u16),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn terminal_event_task(
|
||||||
|
fps: u32,
|
||||||
|
event_sender: mpsc::Sender<Event>,
|
||||||
|
mut shutdown_receiver: broadcast::Receiver<()>,
|
||||||
|
_shutdown_guard_sender: mpsc::Sender<()>,
|
||||||
|
) {
|
||||||
|
// End task if a message is received on shutdown_receiver
|
||||||
|
// _shutdown_guard_sender will be dropped once the task is finished
|
||||||
|
tokio::select! {
|
||||||
|
_ = event_loop(fps, event_sender) => {
|
||||||
|
},
|
||||||
|
_ = shutdown_receiver.recv() => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Main event loop
|
||||||
|
async fn event_loop(fps: u32, event_sender: mpsc::Sender<Event>) {
|
||||||
|
// Frame budget
|
||||||
|
let per_frame = Duration::from_secs(1) / fps;
|
||||||
|
|
||||||
|
// When was last frame executed
|
||||||
|
let mut last_frame = Instant::now();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// Sleep to avoid blocking the thread for too long
|
||||||
|
if let Some(sleep) = per_frame.checked_sub(last_frame.elapsed()) {
|
||||||
|
tokio::time::sleep(sleep).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get crossterm event and send a new one over the channel
|
||||||
|
if event::poll(Duration::from_secs(0)).expect("no events available") {
|
||||||
|
match event::read().expect("unable to read event") {
|
||||||
|
event::Event::Key(e) => event_sender.send(Event::Key(e)).await.unwrap_or(()),
|
||||||
|
event::Event::Resize(w, h) => {
|
||||||
|
event_sender.send(Event::Resize(w, h)).await.unwrap_or(())
|
||||||
|
}
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frame budget exceeded
|
||||||
|
if last_frame.elapsed() >= per_frame {
|
||||||
|
// Send tick
|
||||||
|
event_sender.send(Event::Tick).await.unwrap_or(());
|
||||||
|
// Rest last_frame time
|
||||||
|
last_frame = Instant::now();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,211 @@
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use text_generation_client::{
|
||||||
|
Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
|
||||||
|
StoppingCriteriaParameters,
|
||||||
|
};
|
||||||
|
use tokenizers::{Tokenizer, TruncationDirection};
|
||||||
|
use tokio::sync::{broadcast, mpsc};
|
||||||
|
|
||||||
|
const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub(crate) struct Prefill {
|
||||||
|
pub(crate) latency: Duration,
|
||||||
|
pub(crate) throughput: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub(crate) struct Decode {
|
||||||
|
pub(crate) latency: Duration,
|
||||||
|
pub(crate) token_latency: Duration,
|
||||||
|
pub(crate) throughput: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) enum Message {
|
||||||
|
Warmup,
|
||||||
|
Prefill(Prefill),
|
||||||
|
Decode(Decode),
|
||||||
|
EndRun,
|
||||||
|
EndBatch,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Benchmarking task
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub(crate) async fn generation_task(
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
batch_size: Vec<u32>,
|
||||||
|
sequence_length: u32,
|
||||||
|
decode_length: u32,
|
||||||
|
n_runs: usize,
|
||||||
|
warmups: usize,
|
||||||
|
client: ShardedClient,
|
||||||
|
run_sender: mpsc::Sender<Result<Message, ClientError>>,
|
||||||
|
mut shutdown_receiver: broadcast::Receiver<()>,
|
||||||
|
_shutdown_guard_sender: mpsc::Sender<()>,
|
||||||
|
) {
|
||||||
|
// End task if a message is received on shutdown_receiver
|
||||||
|
// _shutdown_guard_sender will be dropped once the task is finished
|
||||||
|
tokio::select! {
|
||||||
|
res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, client, run_sender.clone()) => {
|
||||||
|
if let Err(err) = res {
|
||||||
|
run_sender.send(Err(err)).await.unwrap_or(());
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ = shutdown_receiver.recv() => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Benchmark prefill/decode
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
async fn generate_runs(
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
batch_size: Vec<u32>,
|
||||||
|
sequence_length: u32,
|
||||||
|
decode_length: u32,
|
||||||
|
n_runs: usize,
|
||||||
|
warmups: usize,
|
||||||
|
mut client: ShardedClient,
|
||||||
|
run_sender: mpsc::Sender<Result<Message, ClientError>>,
|
||||||
|
) -> Result<(), ClientError> {
|
||||||
|
// Create a dummy sequence
|
||||||
|
let sequence = create_sequence(sequence_length, tokenizer);
|
||||||
|
|
||||||
|
for b in batch_size {
|
||||||
|
// Warmups on batch size
|
||||||
|
for _ in 0..warmups {
|
||||||
|
let (_, decode_batch) =
|
||||||
|
prefill(sequence.clone(), b, decode_length, &mut client).await?;
|
||||||
|
let _ = decode(decode_batch, &mut client).await?;
|
||||||
|
// Send warmup message
|
||||||
|
run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());
|
||||||
|
}
|
||||||
|
|
||||||
|
for _ in 0..n_runs {
|
||||||
|
let (prefill, decode_batch) =
|
||||||
|
prefill(sequence.clone(), b, decode_length, &mut client).await?;
|
||||||
|
// Send prefill message
|
||||||
|
run_sender
|
||||||
|
.send(Ok(Message::Prefill(prefill)))
|
||||||
|
.await
|
||||||
|
.unwrap_or(());
|
||||||
|
|
||||||
|
let decode = decode(decode_batch, &mut client).await?;
|
||||||
|
|
||||||
|
// Send decode message
|
||||||
|
run_sender
|
||||||
|
.send(Ok(Message::Decode(decode)))
|
||||||
|
.await
|
||||||
|
.unwrap_or(());
|
||||||
|
|
||||||
|
// Send run ended message
|
||||||
|
run_sender.send(Ok(Message::EndRun)).await.unwrap_or(());
|
||||||
|
}
|
||||||
|
// Batch ended
|
||||||
|
run_sender.send(Ok(Message::EndBatch)).await.unwrap_or(());
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run a prefill step
|
||||||
|
async fn prefill(
|
||||||
|
sequence: String,
|
||||||
|
batch_size: u32,
|
||||||
|
decode_length: u32,
|
||||||
|
client: &mut ShardedClient,
|
||||||
|
) -> Result<(Prefill, Batch), ClientError> {
|
||||||
|
// Create requests
|
||||||
|
let requests = (0..batch_size)
|
||||||
|
.map(|id| Request {
|
||||||
|
id: id.into(),
|
||||||
|
inputs: sequence.clone(),
|
||||||
|
parameters: Some(NextTokenChooserParameters {
|
||||||
|
temperature: 1.0,
|
||||||
|
top_k: 0,
|
||||||
|
top_p: 1.0,
|
||||||
|
typical_p: 1.0,
|
||||||
|
do_sample: false,
|
||||||
|
seed: 0,
|
||||||
|
repetition_penalty: 1.0,
|
||||||
|
watermark: false,
|
||||||
|
}),
|
||||||
|
stopping_parameters: Some(StoppingCriteriaParameters {
|
||||||
|
max_new_tokens: decode_length,
|
||||||
|
stop_sequences: vec![],
|
||||||
|
ignore_eos_token: true, // Will not stop even if a eos token is generated
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let batch = Batch {
|
||||||
|
id: 0,
|
||||||
|
requests,
|
||||||
|
size: batch_size,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Run prefill
|
||||||
|
let start_time = Instant::now();
|
||||||
|
let (_, decode_batch) = client.prefill(batch.clone()).await?;
|
||||||
|
|
||||||
|
// Get latency
|
||||||
|
let latency = start_time.elapsed();
|
||||||
|
|
||||||
|
// Compute throughput from latency and batch size
|
||||||
|
let throughput = batch_size as f64 / latency.as_secs_f64();
|
||||||
|
|
||||||
|
// Decode batch cannot be empty
|
||||||
|
let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");
|
||||||
|
|
||||||
|
let step = Prefill {
|
||||||
|
latency,
|
||||||
|
throughput,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok((step, decode_batch))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run a full decode
|
||||||
|
async fn decode(batch: Batch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
|
||||||
|
let mut decode_length = 0;
|
||||||
|
let batch_size = batch.size;
|
||||||
|
|
||||||
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
// Full decode over decode length
|
||||||
|
let mut next_batch = Some(batch);
|
||||||
|
while let Some(batch) = next_batch {
|
||||||
|
let result = client.decode(vec![batch]).await?;
|
||||||
|
next_batch = result.1;
|
||||||
|
decode_length += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get latency
|
||||||
|
let latency = start_time.elapsed();
|
||||||
|
let token_latency = latency / decode_length;
|
||||||
|
|
||||||
|
// Compute throughput from latency, batch size and decode length
|
||||||
|
let throughput = (batch_size * decode_length) as f64 / latency.as_secs_f64();
|
||||||
|
|
||||||
|
let step = Decode {
|
||||||
|
latency,
|
||||||
|
token_latency,
|
||||||
|
throughput,
|
||||||
|
};
|
||||||
|
Ok(step)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a dummy sequence of the correct length
|
||||||
|
fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
|
||||||
|
let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
|
||||||
|
// Repeat lorem ipsum to cover sequence length
|
||||||
|
let string_sequence =
|
||||||
|
LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len());
|
||||||
|
// Encode sequence
|
||||||
|
let mut encoding = tokenizer.encode(string_sequence, true).unwrap();
|
||||||
|
// Truncate to sequence_length
|
||||||
|
encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
|
||||||
|
// Decode
|
||||||
|
tokenizer
|
||||||
|
.decode(Vec::from(encoding.get_ids()), false)
|
||||||
|
.unwrap()
|
||||||
|
}
|
|
@ -0,0 +1,110 @@
|
||||||
|
mod app;
|
||||||
|
mod event;
|
||||||
|
mod generation;
|
||||||
|
mod utils;
|
||||||
|
|
||||||
|
use crate::app::App;
|
||||||
|
use crate::event::Event;
|
||||||
|
use crossterm::ExecutableCommand;
|
||||||
|
use std::io;
|
||||||
|
use text_generation_client::ShardedClient;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
use tokio::sync::{broadcast, mpsc};
|
||||||
|
use tui::backend::CrosstermBackend;
|
||||||
|
use tui::Terminal;
|
||||||
|
|
||||||
|
/// Run benchmarking app
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub async fn run(
|
||||||
|
tokenizer_name: String,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
batch_size: Vec<u32>,
|
||||||
|
sequence_length: u32,
|
||||||
|
decode_length: u32,
|
||||||
|
n_runs: usize,
|
||||||
|
warmups: usize,
|
||||||
|
client: ShardedClient,
|
||||||
|
) -> Result<(), crossterm::ErrorKind> {
|
||||||
|
// Initialize terminal properties
|
||||||
|
crossterm::terminal::enable_raw_mode()?;
|
||||||
|
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
|
||||||
|
io::stdout().execute(crossterm::cursor::Hide)?;
|
||||||
|
|
||||||
|
// Initialize terminal
|
||||||
|
let mut terminal = {
|
||||||
|
let backend = CrosstermBackend::new(io::stdout());
|
||||||
|
Terminal::new(backend)?
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create message channel between generation_task and app
|
||||||
|
let (run_sender, run_receiver) = mpsc::channel(8);
|
||||||
|
// Crossterm event channel
|
||||||
|
let (event_sender, mut event_receiver) = mpsc::channel(8);
|
||||||
|
// Shutdown channel to terminate tasks
|
||||||
|
let (shutdown_sender, _) = broadcast::channel(1);
|
||||||
|
// Channel to check if tasks terminated
|
||||||
|
let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1);
|
||||||
|
|
||||||
|
// Create generation task
|
||||||
|
tokio::spawn(generation::generation_task(
|
||||||
|
tokenizer,
|
||||||
|
batch_size.clone(),
|
||||||
|
sequence_length,
|
||||||
|
decode_length,
|
||||||
|
n_runs,
|
||||||
|
warmups,
|
||||||
|
client,
|
||||||
|
run_sender,
|
||||||
|
shutdown_sender.subscribe(),
|
||||||
|
shutdown_guard_sender.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
// Create event task
|
||||||
|
tokio::spawn(event::terminal_event_task(
|
||||||
|
250,
|
||||||
|
event_sender,
|
||||||
|
shutdown_sender.subscribe(),
|
||||||
|
shutdown_guard_sender.clone(),
|
||||||
|
));
|
||||||
|
|
||||||
|
// Drop our end of shutdown sender
|
||||||
|
drop(shutdown_guard_sender);
|
||||||
|
|
||||||
|
// Create App
|
||||||
|
let mut app = App::new(
|
||||||
|
run_receiver,
|
||||||
|
tokenizer_name,
|
||||||
|
sequence_length,
|
||||||
|
decode_length,
|
||||||
|
n_runs,
|
||||||
|
batch_size,
|
||||||
|
);
|
||||||
|
|
||||||
|
while app.running {
|
||||||
|
// Draw frame
|
||||||
|
terminal.draw(|frame| app.render(frame))?;
|
||||||
|
|
||||||
|
// Await a new event from event handling task
|
||||||
|
match event_receiver.recv().await {
|
||||||
|
None => break,
|
||||||
|
// Update app state
|
||||||
|
Some(event) => match event {
|
||||||
|
Event::Tick => app.tick(),
|
||||||
|
Event::Key(key_event) => app.handle_key_event(key_event),
|
||||||
|
_ => {}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ask tasks to shutdown
|
||||||
|
let _ = shutdown_sender.send(());
|
||||||
|
// Wait for tasks to shutdown
|
||||||
|
let _ = shutdown_guard_receiver.recv().await;
|
||||||
|
|
||||||
|
// Revert terminal to original view
|
||||||
|
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
|
||||||
|
crossterm::terminal::disable_raw_mode()?;
|
||||||
|
io::stdout().execute(crossterm::cursor::Show)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
|
@ -0,0 +1,119 @@
|
||||||
|
/// Text Generation Inference benchmarking tool
|
||||||
|
///
|
||||||
|
/// Inspired by the great Oha app: https://github.com/hatoo/oha
|
||||||
|
/// and: https://github.com/orhun/rust-tui-template
|
||||||
|
use clap::Parser;
|
||||||
|
use std::path::Path;
|
||||||
|
use text_generation_client::ShardedClient;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
use tracing_subscriber::layer::SubscriberExt;
|
||||||
|
use tracing_subscriber::util::SubscriberInitExt;
|
||||||
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
|
/// App Configuration
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[clap(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
#[clap(short, long, env)]
|
||||||
|
tokenizer_name: String,
|
||||||
|
#[clap(short, long)]
|
||||||
|
batch_size: Option<Vec<u32>>,
|
||||||
|
#[clap(default_value = "10", short, long, env)]
|
||||||
|
sequence_length: u32,
|
||||||
|
#[clap(default_value = "8", short, long, env)]
|
||||||
|
decode_length: u32,
|
||||||
|
#[clap(default_value = "10", short, long, env)]
|
||||||
|
runs: usize,
|
||||||
|
#[clap(default_value = "1", short, long, env)]
|
||||||
|
warmups: usize,
|
||||||
|
#[clap(default_value = "/tmp/text-generation-server-0", short, long, env)]
|
||||||
|
master_shard_uds_path: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
// Get args
|
||||||
|
let args = Args::parse();
|
||||||
|
// Pattern match configuration
|
||||||
|
let Args {
|
||||||
|
tokenizer_name,
|
||||||
|
batch_size,
|
||||||
|
sequence_length,
|
||||||
|
decode_length,
|
||||||
|
runs,
|
||||||
|
warmups,
|
||||||
|
master_shard_uds_path,
|
||||||
|
} = args;
|
||||||
|
|
||||||
|
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
|
||||||
|
|
||||||
|
init_logging();
|
||||||
|
|
||||||
|
// Tokenizer instance
|
||||||
|
// This will only be used to validate payloads
|
||||||
|
tracing::info!("Loading tokenizer");
|
||||||
|
let local_path = Path::new(&tokenizer_name);
|
||||||
|
let tokenizer =
|
||||||
|
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
|
||||||
|
{
|
||||||
|
// Load local tokenizer
|
||||||
|
tracing::info!("Found local tokenizer");
|
||||||
|
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
|
||||||
|
} else {
|
||||||
|
// Download and instantiate tokenizer
|
||||||
|
// We need to download it outside of the Tokio runtime
|
||||||
|
tracing::info!("Downloading tokenizer");
|
||||||
|
Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
|
||||||
|
};
|
||||||
|
tracing::info!("Tokenizer loaded");
|
||||||
|
|
||||||
|
// Launch Tokio runtime
|
||||||
|
tokio::runtime::Builder::new_multi_thread()
|
||||||
|
.enable_all()
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
.block_on(async {
|
||||||
|
// Instantiate sharded client from the master unix socket
|
||||||
|
tracing::info!("Connect to model server");
|
||||||
|
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
||||||
|
.await
|
||||||
|
.expect("Could not connect to server");
|
||||||
|
// Clear the cache; useful if the webserver rebooted
|
||||||
|
sharded_client
|
||||||
|
.clear_cache(None)
|
||||||
|
.await
|
||||||
|
.expect("Unable to clear cache");
|
||||||
|
tracing::info!("Connected");
|
||||||
|
|
||||||
|
// Run app
|
||||||
|
text_generation_benchmark::run(
|
||||||
|
tokenizer_name,
|
||||||
|
tokenizer,
|
||||||
|
batch_size,
|
||||||
|
sequence_length,
|
||||||
|
decode_length,
|
||||||
|
runs,
|
||||||
|
warmups,
|
||||||
|
sharded_client,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
});
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Init logging using LOG_LEVEL
|
||||||
|
fn init_logging() {
|
||||||
|
// STDOUT/STDERR layer
|
||||||
|
let fmt_layer = tracing_subscriber::fmt::layer()
|
||||||
|
.with_file(true)
|
||||||
|
.with_line_number(true);
|
||||||
|
|
||||||
|
// Filter events with LOG_LEVEL
|
||||||
|
let env_filter =
|
||||||
|
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
|
||||||
|
|
||||||
|
tracing_subscriber::registry()
|
||||||
|
.with(env_filter)
|
||||||
|
.with(fmt_layer)
|
||||||
|
.init();
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
/// MIT License
|
||||||
|
//
|
||||||
|
// Copyright (c) 2020 hatoo
|
||||||
|
//
|
||||||
|
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
// of this software and associated documentation files (the "Software"), to deal
|
||||||
|
// in the Software without restriction, including without limitation the rights
|
||||||
|
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
// copies of the Software, and to permit persons to whom the Software is
|
||||||
|
// furnished to do so, subject to the following conditions:
|
||||||
|
//
|
||||||
|
// The above copyright notice and this permission notice shall be included in all
|
||||||
|
// copies or substantial portions of the Software.
|
||||||
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
|
pub(crate) fn histogram(values: &[f64], bins: usize) -> Vec<(f64, usize)> {
|
||||||
|
assert!(bins >= 2);
|
||||||
|
let mut bucket: Vec<usize> = vec![0; bins];
|
||||||
|
let min = values.iter().collect::<average::Min>().min();
|
||||||
|
let max = values.iter().collect::<average::Max>().max();
|
||||||
|
let step = (max - min) / (bins - 1) as f64;
|
||||||
|
|
||||||
|
for &v in values {
|
||||||
|
let i = std::cmp::min(((v - min) / step).ceil() as usize, bins - 1);
|
||||||
|
bucket[i] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bucket
|
||||||
|
.into_iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, v)| (min + step * i as f64, v))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap<String, f64> {
|
||||||
|
pecents
|
||||||
|
.iter()
|
||||||
|
.map(|&p| {
|
||||||
|
let i = (f64::from(p) / 100.0 * values.len() as f64) as usize;
|
||||||
|
(format!("p{p}"), *values.get(i).unwrap_or(&std::f64::NAN))
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
|
@ -53,6 +53,9 @@ message StoppingCriteriaParameters {
|
||||||
uint32 max_new_tokens = 1;
|
uint32 max_new_tokens = 1;
|
||||||
/// Optional stopping sequences
|
/// Optional stopping sequences
|
||||||
repeated string stop_sequences = 2;
|
repeated string stop_sequences = 2;
|
||||||
|
/// Ignore end of sequence token
|
||||||
|
/// used for benchmarking
|
||||||
|
bool ignore_eos_token = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Request {
|
message Request {
|
||||||
|
|
|
@ -37,7 +37,7 @@ struct Args {
|
||||||
max_waiting_tokens: usize,
|
max_waiting_tokens: usize,
|
||||||
#[clap(default_value = "3000", long, short, env)]
|
#[clap(default_value = "3000", long, short, env)]
|
||||||
port: u16,
|
port: u16,
|
||||||
#[clap(default_value = "/tmp/text-generation-0", long, env)]
|
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
|
||||||
master_shard_uds_path: String,
|
master_shard_uds_path: String,
|
||||||
#[clap(default_value = "bigscience/bloom", long, env)]
|
#[clap(default_value = "bigscience/bloom", long, env)]
|
||||||
tokenizer_name: String,
|
tokenizer_name: String,
|
||||||
|
@ -76,6 +76,8 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
panic!("validation_workers must be > 0");
|
panic!("validation_workers must be > 0");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
init_logging(otlp_endpoint, json_output);
|
||||||
|
|
||||||
// CORS allowed origins
|
// CORS allowed origins
|
||||||
// map to go inside the option and then map to parse from String to HeaderValue
|
// map to go inside the option and then map to parse from String to HeaderValue
|
||||||
// Finally, convert to AllowOrigin
|
// Finally, convert to AllowOrigin
|
||||||
|
@ -89,17 +91,21 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
|
|
||||||
// Tokenizer instance
|
// Tokenizer instance
|
||||||
// This will only be used to validate payloads
|
// This will only be used to validate payloads
|
||||||
|
tracing::info!("Loading tokenizer");
|
||||||
let local_path = Path::new(&tokenizer_name);
|
let local_path = Path::new(&tokenizer_name);
|
||||||
let tokenizer =
|
let tokenizer =
|
||||||
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
|
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
|
||||||
{
|
{
|
||||||
// Load local tokenizer
|
// Load local tokenizer
|
||||||
|
tracing::info!("Found local tokenizer");
|
||||||
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
|
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
|
||||||
} else {
|
} else {
|
||||||
// Download and instantiate tokenizer
|
// Download and instantiate tokenizer
|
||||||
// We need to download it outside of the Tokio runtime
|
// We need to download it outside of the Tokio runtime
|
||||||
|
tracing::info!("Downloading tokenizer");
|
||||||
Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
|
Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
|
||||||
};
|
};
|
||||||
|
tracing::info!("Tokenizer loaded");
|
||||||
|
|
||||||
// Launch Tokio runtime
|
// Launch Tokio runtime
|
||||||
tokio::runtime::Builder::new_multi_thread()
|
tokio::runtime::Builder::new_multi_thread()
|
||||||
|
@ -107,8 +113,6 @@ fn main() -> Result<(), std::io::Error> {
|
||||||
.build()
|
.build()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.block_on(async {
|
.block_on(async {
|
||||||
init_logging(otlp_endpoint, json_output);
|
|
||||||
|
|
||||||
// Get pipeline tag
|
// Get pipeline tag
|
||||||
let model_info = reqwest::get(format!(
|
let model_info = reqwest::get(format!(
|
||||||
"https://huggingface.co/api/models/{tokenizer_name}"
|
"https://huggingface.co/api/models/{tokenizer_name}"
|
||||||
|
|
|
@ -237,6 +237,7 @@ mod tests {
|
||||||
watermark: false,
|
watermark: false,
|
||||||
},
|
},
|
||||||
stopping_parameters: StoppingCriteriaParameters {
|
stopping_parameters: StoppingCriteriaParameters {
|
||||||
|
ignore_eos_token: false,
|
||||||
max_new_tokens: 0,
|
max_new_tokens: 0,
|
||||||
stop_sequences: vec![],
|
stop_sequences: vec![],
|
||||||
},
|
},
|
||||||
|
|
|
@ -315,6 +315,7 @@ fn validate(
|
||||||
let stopping_parameters = StoppingCriteriaParameters {
|
let stopping_parameters = StoppingCriteriaParameters {
|
||||||
max_new_tokens,
|
max_new_tokens,
|
||||||
stop_sequences,
|
stop_sequences,
|
||||||
|
ignore_eos_token: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
metrics::histogram!("tgi_request_input_length", input_length as f64);
|
metrics::histogram!("tgi_request_input_length", input_length as f64);
|
||||||
|
|
|
@ -18,7 +18,7 @@ def serve(
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
sharded: bool = False,
|
sharded: bool = False,
|
||||||
quantize: bool = False,
|
quantize: bool = False,
|
||||||
uds_path: Path = "/tmp/text-generation",
|
uds_path: Path = "/tmp/text-generation-server",
|
||||||
logger_level: str = "INFO",
|
logger_level: str = "INFO",
|
||||||
json_output: bool = False,
|
json_output: bool = False,
|
||||||
otlp_endpoint: Optional[str] = None,
|
otlp_endpoint: Optional[str] = None,
|
||||||
|
|
|
@ -123,20 +123,22 @@ class StoppingCriteria:
|
||||||
self,
|
self,
|
||||||
eos_token_id: int,
|
eos_token_id: int,
|
||||||
stop_sequence_criterias: List[StopSequenceCriteria],
|
stop_sequence_criterias: List[StopSequenceCriteria],
|
||||||
max_new_tokens=20,
|
max_new_tokens: int = 20,
|
||||||
|
ignore_eos_token: bool = False,
|
||||||
):
|
):
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
self.stop_sequence_criterias = stop_sequence_criterias
|
self.stop_sequence_criterias = stop_sequence_criterias
|
||||||
self.max_new_tokens = max_new_tokens
|
self.max_new_tokens = max_new_tokens
|
||||||
self.current_tokens = 0
|
self.current_tokens = 0
|
||||||
self.current_output = ""
|
self.current_output = ""
|
||||||
|
self.ignore_eos_token = ignore_eos_token
|
||||||
|
|
||||||
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
|
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
|
||||||
self.current_tokens += 1
|
self.current_tokens += 1
|
||||||
if self.current_tokens >= self.max_new_tokens:
|
if self.current_tokens >= self.max_new_tokens:
|
||||||
return True, FinishReason.FINISH_REASON_LENGTH
|
return True, FinishReason.FINISH_REASON_LENGTH
|
||||||
|
|
||||||
if last_token == self.eos_token_id:
|
if not self.ignore_eos_token and last_token == self.eos_token_id:
|
||||||
return True, FinishReason.FINISH_REASON_EOS_TOKEN
|
return True, FinishReason.FINISH_REASON_EOS_TOKEN
|
||||||
|
|
||||||
self.current_output += last_output
|
self.current_output += last_output
|
||||||
|
@ -156,5 +158,8 @@ class StoppingCriteria:
|
||||||
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
|
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
|
||||||
]
|
]
|
||||||
return StoppingCriteria(
|
return StoppingCriteria(
|
||||||
tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
|
tokenizer.eos_token_id,
|
||||||
|
stop_sequence_criterias,
|
||||||
|
pb.max_new_tokens,
|
||||||
|
pb.ignore_eos_token,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue