feat(benchmark): tui based benchmarking tool (#149)

This commit is contained in:
OlivierDehaene 2023-03-30 15:26:27 +02:00 committed by GitHub
parent 55106ec476
commit 610bb1f978
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 4133 additions and 7 deletions

View File

@ -5,6 +5,9 @@ members = [
"router/grpc-metadata",
"launcher"
]
exclude = [
"benchmark"
]
[profile.release]
debug = 1

View File

@ -7,6 +7,9 @@ install-router:
install-launcher:
cd launcher && cargo install --path .
install-benchmark:
cd benchmark && cargo install --path .
install: install-server install-router install-launcher
server-dev:

BIN
assets/benchmark.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

1
benchmark/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
target

2801
benchmark/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

35
benchmark/Cargo.toml Normal file
View File

@ -0,0 +1,35 @@
[package]
name = "text-generation-benchmark"
version = "0.1.0"
edition = "2021"
authors = ["Olivier Dehaene"]
description = "Text Generation Benchmarking tool"
[profile.release]
debug = 1
incremental = true
lto = "off"
panic = "abort"
[lib]
path = "src/lib.rs"
[[bin]]
name = "text-generation-benchmark"
path = "src/main.rs"
[dependencies]
average = "0.13"
clap = { version = "4.1.4", features = ["derive", "env"] }
crossterm = "0.26"
float-ord = "0.3.2"
serde = {version = "1.0.142", features = ["derive"]}
serde_json = "1.0"
text-generation-client = { path = "../router/client" }
thiserror = "1.0.38"
tokenizers = "0.13.2"
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
tui = {package = "ratatui", version = "0.20", default-features = false, features = ["crossterm"]}
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] }

30
benchmark/README.md Normal file
View File

@ -0,0 +1,30 @@
<div align="center">
# Text Generation Inference benchmarking tool
![benchmark](../assets/benchmark.png)
</div>
A lightweight benchmarking tool based inspired by [oha](https://github.com/hatoo/oha)
and powered by [tui](https://github.com/tui-rs-revival/ratatui).
## Install
```shell
make install-benchmark
```
## Run
First, start `text-generation-inference`:
```shell
text-generation-launcher --model-id bigscience/bloom-560m
```
Then run the benchmarking tool:
```shell
text-generation-benchmark --tokenizer-name bigscience/bloom-560m
```

View File

@ -0,0 +1,3 @@
[toolchain]
channel = "1.67.0"
components = ["rustfmt", "clippy"]

688
benchmark/src/app.rs Normal file
View File

@ -0,0 +1,688 @@
/// Inspired by https://github.com/hatoo/oha/blob/bb989ea3cd77727e7743e7daa60a19894bb5e901/src/monitor.rs
use crate::generation::{Decode, Message, Prefill};
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use text_generation_client::ClientError;
use tokio::sync::mpsc;
use tui::backend::Backend;
use tui::layout::{Alignment, Constraint, Direction, Layout};
use tui::style::{Color, Modifier, Style};
use tui::text::{Span, Spans};
use tui::widgets::{
Axis, BarChart, Block, Borders, Chart, Dataset, Gauge, GraphType, Paragraph, Tabs,
};
use tui::{symbols, Frame};
/// TUI powered App
pub(crate) struct App {
pub(crate) running: bool,
completed_runs: Vec<usize>,
completed_batch: usize,
current_batch: usize,
current_tab: usize,
touched_tab: bool,
zoom: bool,
is_error: bool,
data: Data,
tokenizer_name: String,
sequence_length: u32,
decode_length: u32,
n_run: usize,
batch_size: Vec<u32>,
receiver: mpsc::Receiver<Result<Message, ClientError>>,
}
impl App {
pub(crate) fn new(
receiver: mpsc::Receiver<Result<Message, ClientError>>,
tokenizer_name: String,
sequence_length: u32,
decode_length: u32,
n_run: usize,
batch_size: Vec<u32>,
) -> Self {
let data = Data::new(n_run, batch_size.len());
let current_tab = 0;
let completed_runs: Vec<usize> = (0..batch_size.len()).map(|_| 0).collect();
let completed_batch = 0;
let current_batch = 0;
let is_error = false;
Self {
running: true,
completed_runs,
completed_batch,
current_batch,
current_tab,
touched_tab: false,
zoom: false,
is_error,
data,
tokenizer_name,
sequence_length,
decode_length,
n_run,
batch_size,
receiver,
}
}
/// Handle crossterm key events
pub(crate) fn handle_key_event(&mut self, key_event: KeyEvent) {
match key_event {
// Increase and wrap tab
KeyEvent {
code: KeyCode::Right,
..
}
| KeyEvent {
code: KeyCode::Tab, ..
} => {
self.touched_tab = true;
self.current_tab = (self.current_tab + 1) % self.batch_size.len();
}
// Decrease and wrap tab
KeyEvent {
code: KeyCode::Left,
..
} => {
self.touched_tab = true;
if self.current_tab > 0 {
self.current_tab -= 1;
} else {
self.current_tab = self.batch_size.len() - 1;
}
}
// Zoom on throughput/latency fig
KeyEvent {
code: KeyCode::Char('+'),
..
} => {
self.zoom = true;
}
// Unzoom on throughput/latency fig
KeyEvent {
code: KeyCode::Char('-'),
..
} => {
self.zoom = false;
}
// Quit
KeyEvent {
code: KeyCode::Char('q'),
..
}
| KeyEvent {
code: KeyCode::Char('c'),
modifiers: KeyModifiers::CONTROL,
..
} => {
self.running = false;
}
_ => (),
}
}
/// Get all pending messages from generation task
pub(crate) fn tick(&mut self) {
while let Ok(message) = self.receiver.try_recv() {
match message {
Ok(message) => match message {
Message::Prefill(step) => self.data.push_prefill(step, self.current_batch),
Message::Decode(step) => self.data.push_decode(step, self.current_batch),
Message::EndRun => {
self.completed_runs[self.current_batch] += 1;
}
Message::EndBatch => {
self.data.end_batch(self.current_batch);
self.completed_batch += 1;
if self.current_batch < self.batch_size.len() - 1 {
// Only go to next tab if the user never touched the tab keys
if !self.touched_tab {
self.current_tab += 1;
}
self.current_batch += 1;
}
}
Message::Warmup => {}
},
Err(_) => self.is_error = true,
}
}
}
/// Render frame
pub fn render<B: Backend>(&mut self, f: &mut Frame<'_, B>) {
let batch_progress =
(self.completed_batch as f64 / self.batch_size.len() as f64).clamp(0.0, 1.0);
let run_progress =
(self.completed_runs[self.current_batch] as f64 / self.n_run as f64).clamp(0.0, 1.0);
// Vertical layout
let row5 = Layout::default()
.direction(Direction::Vertical)
.constraints(
[
Constraint::Length(1),
Constraint::Length(3),
Constraint::Length(3),
Constraint::Length(13),
Constraint::Min(10),
]
.as_ref(),
)
.split(f.size());
// Top row horizontal layout
let top = Layout::default()
.direction(Direction::Horizontal)
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
.split(row5[2]);
// Mid row horizontal layout
let mid = Layout::default()
.direction(Direction::Horizontal)
.constraints(
[
Constraint::Percentage(25),
Constraint::Percentage(25),
Constraint::Percentage(25),
Constraint::Percentage(25),
]
.as_ref(),
)
.split(row5[3]);
// Left mid row vertical layout
let prefill_text = Layout::default()
.direction(Direction::Vertical)
.constraints([Constraint::Length(8), Constraint::Length(5)].as_ref())
.split(mid[0]);
// Right mid row vertical layout
let decode_text = Layout::default()
.direction(Direction::Vertical)
.constraints([Constraint::Length(8), Constraint::Length(5)].as_ref())
.split(mid[2]);
let decode_text_latency = Layout::default()
.direction(Direction::Horizontal)
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
.split(decode_text[0]);
// Bottom row horizontal layout
let bottom = Layout::default()
.direction(Direction::Horizontal)
.constraints([Constraint::Percentage(50), Constraint::Percentage(50)].as_ref())
.split(row5[4]);
// Title
let title = Block::default()
.borders(Borders::NONE)
.title(format!(
"Model: {} | Sequence Length: {} | Decode Length: {}",
self.tokenizer_name, self.sequence_length, self.decode_length
))
.style(
Style::default()
.add_modifier(Modifier::BOLD)
.fg(Color::White),
);
f.render_widget(title, row5[0]);
// Helper
let helper = Block::default()
.borders(Borders::NONE)
.title("<- | tab | ->: change batch tab | q / CTRL + c: quit | +/-: zoom")
.title_alignment(Alignment::Right)
.style(Style::default().fg(Color::White));
f.render_widget(helper, row5[0]);
// Batch tabs
let titles = self
.batch_size
.iter()
.map(|b| {
Spans::from(vec![Span::styled(
format!("Batch: {b}"),
Style::default().fg(Color::White),
)])
})
.collect();
let tabs = Tabs::new(titles)
.block(Block::default().borders(Borders::ALL).title("Tabs"))
.select(self.current_tab)
.style(Style::default().fg(Color::LightCyan))
.highlight_style(
Style::default()
.add_modifier(Modifier::BOLD)
.bg(Color::Black),
);
f.render_widget(tabs, row5[1]);
// Total progress bar
let color = if self.is_error {
Color::Red
} else {
Color::LightGreen
};
let batch_gauge = progress_gauge(
"Total Progress",
format!("{} / {}", self.completed_batch, self.batch_size.len()),
batch_progress,
color,
);
f.render_widget(batch_gauge, top[0]);
// Batch progress Bar
let color = if self.is_error {
Color::Red
} else {
Color::LightBlue
};
let run_gauge = progress_gauge(
"Batch Progress",
format!(
"{} / {}",
self.completed_runs[self.current_batch], self.n_run
),
run_progress,
color,
);
f.render_widget(run_gauge, top[1]);
// Prefill text infos
let prefill_latency_block = latency_paragraph(
&mut self.data.prefill_latencies[self.current_tab],
"Prefill",
);
let prefill_throughput_block =
throughput_paragraph(&self.data.prefill_throughputs[self.current_tab], "Prefill");
f.render_widget(prefill_latency_block, prefill_text[0]);
f.render_widget(prefill_throughput_block, prefill_text[1]);
// Prefill latency histogram
let histo_width = 7;
let bins = if mid[1].width < 2 {
0
} else {
(mid[1].width as usize - 2) / (histo_width + 1)
}
.max(2);
let histo_data =
latency_histogram_data(&self.data.prefill_latencies[self.current_tab], bins);
let histo_data_str: Vec<(&str, u64)> =
histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect();
let prefill_histogram =
latency_histogram(&histo_data_str, "Prefill").bar_width(histo_width as u16);
f.render_widget(prefill_histogram, mid[1]);
// Decode text info
let decode_latency_block = latency_paragraph(
&mut self.data.decode_latencies[self.current_tab],
"Decode Total",
);
let decode_token_latency_block = latency_paragraph(
&mut self.data.decode_token_latencies[self.current_tab],
"Decode Token",
);
let decode_throughput_block =
throughput_paragraph(&self.data.decode_throughputs[self.current_tab], "Decode");
f.render_widget(decode_latency_block, decode_text_latency[0]);
f.render_widget(decode_token_latency_block, decode_text_latency[1]);
f.render_widget(decode_throughput_block, decode_text[1]);
// Decode latency histogram
let histo_data =
latency_histogram_data(&self.data.decode_latencies[self.current_tab], bins);
let histo_data_str: Vec<(&str, u64)> =
histo_data.iter().map(|(l, v)| (l.as_str(), *v)).collect();
let decode_histogram =
latency_histogram(&histo_data_str, "Decode").bar_width(histo_width as u16);
f.render_widget(decode_histogram, mid[3]);
// Prefill latency/throughput chart
let prefill_latency_throughput_chart = latency_throughput_chart(
&self.data.prefill_batch_latency_throughput,
&self.batch_size,
self.zoom,
"Prefill",
);
f.render_widget(prefill_latency_throughput_chart, bottom[0]);
// Decode latency/throughput chart
let decode_latency_throughput_chart = latency_throughput_chart(
&self.data.decode_batch_latency_throughput,
&self.batch_size,
self.zoom,
"Decode",
);
f.render_widget(decode_latency_throughput_chart, bottom[1]);
}
}
/// App internal data struct
struct Data {
prefill_latencies: Vec<Vec<f64>>,
prefill_throughputs: Vec<Vec<f64>>,
decode_latencies: Vec<Vec<f64>>,
decode_token_latencies: Vec<Vec<f64>>,
decode_throughputs: Vec<Vec<f64>>,
prefill_batch_latency_throughput: Vec<(f64, f64)>,
decode_batch_latency_throughput: Vec<(f64, f64)>,
}
impl Data {
fn new(n_run: usize, n_batch: usize) -> Self {
let prefill_latencies: Vec<Vec<f64>> =
(0..n_batch).map(|_| Vec::with_capacity(n_run)).collect();
let prefill_throughputs: Vec<Vec<f64>> = prefill_latencies.clone();
let decode_latencies: Vec<Vec<f64>> = prefill_latencies.clone();
let decode_token_latencies: Vec<Vec<f64>> = decode_latencies.clone();
let decode_throughputs: Vec<Vec<f64>> = prefill_throughputs.clone();
let prefill_batch_latency_throughput: Vec<(f64, f64)> = Vec::with_capacity(n_batch);
let decode_batch_latency_throughput: Vec<(f64, f64)> =
prefill_batch_latency_throughput.clone();
Self {
prefill_latencies,
prefill_throughputs,
decode_latencies,
decode_token_latencies,
decode_throughputs,
prefill_batch_latency_throughput,
decode_batch_latency_throughput,
}
}
fn push_prefill(&mut self, prefill: Prefill, batch_idx: usize) {
let latency = prefill.latency.as_millis() as f64;
self.prefill_latencies[batch_idx].push(latency);
self.prefill_throughputs[batch_idx].push(prefill.throughput);
}
fn push_decode(&mut self, decode: Decode, batch_idx: usize) {
let latency = decode.latency.as_millis() as f64;
let token_latency = decode.token_latency.as_millis() as f64;
self.decode_latencies[batch_idx].push(latency);
self.decode_token_latencies[batch_idx].push(token_latency);
self.decode_throughputs[batch_idx].push(decode.throughput);
}
fn end_batch(&mut self, batch_idx: usize) {
self.prefill_batch_latency_throughput.push((
self.prefill_latencies[batch_idx].iter().sum::<f64>()
/ self.prefill_latencies[batch_idx].len() as f64,
self.prefill_throughputs[batch_idx].iter().sum::<f64>()
/ self.prefill_throughputs[batch_idx].len() as f64,
));
self.decode_batch_latency_throughput.push((
self.decode_latencies[batch_idx].iter().sum::<f64>()
/ self.decode_latencies[batch_idx].len() as f64,
self.decode_throughputs[batch_idx].iter().sum::<f64>()
/ self.decode_throughputs[batch_idx].len() as f64,
));
}
}
/// Progress bar
fn progress_gauge(title: &str, label: String, progress: f64, color: Color) -> Gauge {
Gauge::default()
.block(Block::default().title(title).borders(Borders::ALL))
.gauge_style(Style::default().fg(color))
.label(Span::raw(label))
.ratio(progress)
}
/// Throughput paragraph
fn throughput_paragraph<'a>(throughput: &Vec<f64>, name: &'static str) -> Paragraph<'a> {
// Throughput average/high/low texts
let throughput_texts = statis_spans(throughput, "tokens/secs");
// Throughput block
Paragraph::new(throughput_texts).block(
Block::default()
.title(Span::raw(format!("{name} Throughput")))
.borders(Borders::ALL),
)
}
/// Latency paragraph
fn latency_paragraph<'a>(latency: &mut Vec<f64>, name: &'static str) -> Paragraph<'a> {
// Latency average/high/low texts
let mut latency_texts = statis_spans(latency, "ms");
// Sort latency for percentiles
float_ord::sort(latency);
let latency_percentiles = crate::utils::percentiles(latency, &[50, 90, 99]);
// Latency p50/p90/p99 texts
let colors = vec![Color::LightGreen, Color::LightYellow, Color::LightRed];
for (i, (name, value)) in latency_percentiles.iter().enumerate() {
let span = Spans::from(vec![Span::styled(
format!("{name}: {value:.2} ms"),
Style::default().fg(colors[i]),
)]);
latency_texts.push(span);
}
Paragraph::new(latency_texts).block(
Block::default()
.title(Span::raw(format!("{name} Latency")))
.borders(Borders::ALL),
)
}
/// Average/High/Low spans
fn statis_spans<'a>(data: &Vec<f64>, unit: &'static str) -> Vec<Spans<'a>> {
vec![
Spans::from(vec![Span::styled(
format!(
"Average: {:.2} {unit}",
data.iter().sum::<f64>() / data.len() as f64
),
Style::default().fg(Color::LightBlue),
)]),
Spans::from(vec![Span::styled(
format!(
"Lowest: {:.2} {unit}",
data.iter()
.min_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN)
),
Style::default().fg(Color::Reset),
)]),
Spans::from(vec![Span::styled(
format!(
"Highest: {:.2} {unit}",
data.iter()
.max_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN)
),
Style::default().fg(Color::Reset),
)]),
]
}
/// Latency histogram data
fn latency_histogram_data(latency: &[f64], bins: usize) -> Vec<(String, u64)> {
let histo_data: Vec<(String, u64)> = {
let histo = crate::utils::histogram(latency, bins);
histo
.into_iter()
.map(|(label, v)| (format!("{label:.2}"), v as u64))
.collect()
};
histo_data
}
/// Latency Histogram
fn latency_histogram<'a>(
histo_data_str: &'a Vec<(&'a str, u64)>,
name: &'static str,
) -> BarChart<'a> {
BarChart::default()
.block(
Block::default()
.title(format!("{name} latency histogram"))
.style(Style::default().fg(Color::LightYellow).bg(Color::Reset))
.borders(Borders::ALL),
)
.data(histo_data_str.as_slice())
}
/// Latency/Throughput chart
fn latency_throughput_chart<'a>(
latency_throughput: &'a Vec<(f64, f64)>,
batch_sizes: &'a [u32],
zoom: bool,
name: &'static str,
) -> Chart<'a> {
let latency_iter = latency_throughput.iter().map(|(l, _)| l);
let throughput_iter = latency_throughput.iter().map(|(_, t)| t);
// Get extreme values
let min_latency: f64 = *latency_iter
.clone()
.min_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN);
let max_latency: f64 = *latency_iter
.max_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN);
let min_throughput: f64 = *throughput_iter
.clone()
.min_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN);
let max_throughput: f64 = *throughput_iter
.max_by(|a, b| a.total_cmp(b))
.unwrap_or(&std::f64::NAN);
// Char min max values
let min_x = if zoom {
((min_latency - 0.05 * min_latency) / 100.0).floor() * 100.0
} else {
0.0
};
let max_x = ((max_latency + 0.05 * max_latency) / 100.0).ceil() * 100.0;
let step_x = (max_x - min_x) / 4.0;
// Chart min max values
let min_y = if zoom {
((min_throughput - 0.05 * min_throughput) / 100.0).floor() * 100.0
} else {
0.0
};
let max_y = ((max_throughput + 0.05 * max_throughput) / 100.0).ceil() * 100.0;
let step_y = (max_y - min_y) / 4.0;
// Labels
let mut x_labels = vec![Span::styled(
format!("{min_x:.2}"),
Style::default()
.add_modifier(Modifier::BOLD)
.fg(Color::Gray)
.bg(Color::Reset),
)];
for i in 0..3 {
x_labels.push(Span::styled(
format!("{:.2}", min_x + ((i + 1) as f64 * step_x)),
Style::default().fg(Color::Gray).bg(Color::Reset),
));
}
x_labels.push(Span::styled(
format!("{max_x:.2}"),
Style::default()
.add_modifier(Modifier::BOLD)
.fg(Color::Gray)
.bg(Color::Reset),
));
// Labels
let mut y_labels = vec![Span::styled(
format!("{min_y:.2}"),
Style::default()
.add_modifier(Modifier::BOLD)
.fg(Color::Gray)
.bg(Color::Reset),
)];
for i in 0..3 {
y_labels.push(Span::styled(
format!("{:.2}", min_y + ((i + 1) as f64 * step_y)),
Style::default().fg(Color::Gray).bg(Color::Reset),
));
}
y_labels.push(Span::styled(
format!("{max_y:.2}"),
Style::default()
.add_modifier(Modifier::BOLD)
.fg(Color::Gray)
.bg(Color::Reset),
));
// Chart dataset
let colors = color_vec();
let datasets: Vec<Dataset> = (0..latency_throughput.len())
.map(|i| {
let color_idx = i % colors.len();
Dataset::default()
.name(batch_sizes[i].to_string())
.marker(symbols::Marker::Block)
.style(Style::default().fg(colors[color_idx]))
.graph_type(GraphType::Scatter)
.data(&latency_throughput[i..(i + 1)])
})
.collect();
// Chart
Chart::new(datasets)
.style(Style::default().fg(Color::Cyan).bg(Color::Reset))
.block(
Block::default()
.title(Span::styled(
format!("{name} throughput over latency"),
Style::default().fg(Color::Gray).bg(Color::Reset),
))
.borders(Borders::ALL),
)
.x_axis(
Axis::default()
.title("ms")
.style(Style::default().fg(Color::Gray).bg(Color::Reset))
.labels(x_labels)
.bounds([min_x, max_x]),
)
.y_axis(
Axis::default()
.title("tokens/secs")
.style(Style::default().fg(Color::Gray).bg(Color::Reset))
.labels(y_labels)
.bounds([min_y, max_y]),
)
}
// Colors for latency/throughput chart
fn color_vec() -> Vec<Color> {
vec![
Color::Red,
Color::Green,
Color::Yellow,
Color::Blue,
Color::Magenta,
Color::Cyan,
Color::Gray,
Color::DarkGray,
Color::LightRed,
Color::LightGreen,
Color::LightYellow,
Color::LightBlue,
Color::LightMagenta,
Color::LightCyan,
]
}

65
benchmark/src/event.rs Normal file
View File

@ -0,0 +1,65 @@
/// Inspired by https://github.com/orhun/rust-tui-template/blob/472aa515119d4c94903eac12d9784417281dc7f5/src/event.rs
use crossterm::event;
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, mpsc};
/// Events
#[derive(Debug)]
pub(crate) enum Event {
/// Terminal tick.
Tick,
/// Key press.
Key(event::KeyEvent),
/// Terminal resize.
Resize(u16, u16),
}
pub(crate) async fn terminal_event_task(
fps: u32,
event_sender: mpsc::Sender<Event>,
mut shutdown_receiver: broadcast::Receiver<()>,
_shutdown_guard_sender: mpsc::Sender<()>,
) {
// End task if a message is received on shutdown_receiver
// _shutdown_guard_sender will be dropped once the task is finished
tokio::select! {
_ = event_loop(fps, event_sender) => {
},
_ = shutdown_receiver.recv() => {}
}
}
/// Main event loop
async fn event_loop(fps: u32, event_sender: mpsc::Sender<Event>) {
// Frame budget
let per_frame = Duration::from_secs(1) / fps;
// When was last frame executed
let mut last_frame = Instant::now();
loop {
// Sleep to avoid blocking the thread for too long
if let Some(sleep) = per_frame.checked_sub(last_frame.elapsed()) {
tokio::time::sleep(sleep).await;
}
// Get crossterm event and send a new one over the channel
if event::poll(Duration::from_secs(0)).expect("no events available") {
match event::read().expect("unable to read event") {
event::Event::Key(e) => event_sender.send(Event::Key(e)).await.unwrap_or(()),
event::Event::Resize(w, h) => {
event_sender.send(Event::Resize(w, h)).await.unwrap_or(())
}
_ => (),
}
}
// Frame budget exceeded
if last_frame.elapsed() >= per_frame {
// Send tick
event_sender.send(Event::Tick).await.unwrap_or(());
// Rest last_frame time
last_frame = Instant::now();
}
}
}

211
benchmark/src/generation.rs Normal file
View File

@ -0,0 +1,211 @@
use std::time::{Duration, Instant};
use text_generation_client::{
Batch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
StoppingCriteriaParameters,
};
use tokenizers::{Tokenizer, TruncationDirection};
use tokio::sync::{broadcast, mpsc};
const LOREM_IPSUM: &str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
#[derive(Debug, Clone)]
pub(crate) struct Prefill {
pub(crate) latency: Duration,
pub(crate) throughput: f64,
}
#[derive(Debug, Clone)]
pub(crate) struct Decode {
pub(crate) latency: Duration,
pub(crate) token_latency: Duration,
pub(crate) throughput: f64,
}
#[derive(Debug)]
pub(crate) enum Message {
Warmup,
Prefill(Prefill),
Decode(Decode),
EndRun,
EndBatch,
}
/// Benchmarking task
#[allow(clippy::too_many_arguments)]
pub(crate) async fn generation_task(
tokenizer: Tokenizer,
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
n_runs: usize,
warmups: usize,
client: ShardedClient,
run_sender: mpsc::Sender<Result<Message, ClientError>>,
mut shutdown_receiver: broadcast::Receiver<()>,
_shutdown_guard_sender: mpsc::Sender<()>,
) {
// End task if a message is received on shutdown_receiver
// _shutdown_guard_sender will be dropped once the task is finished
tokio::select! {
res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, client, run_sender.clone()) => {
if let Err(err) = res {
run_sender.send(Err(err)).await.unwrap_or(());
}
},
_ = shutdown_receiver.recv() => {}
}
}
/// Benchmark prefill/decode
#[allow(clippy::too_many_arguments)]
async fn generate_runs(
tokenizer: Tokenizer,
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
n_runs: usize,
warmups: usize,
mut client: ShardedClient,
run_sender: mpsc::Sender<Result<Message, ClientError>>,
) -> Result<(), ClientError> {
// Create a dummy sequence
let sequence = create_sequence(sequence_length, tokenizer);
for b in batch_size {
// Warmups on batch size
for _ in 0..warmups {
let (_, decode_batch) =
prefill(sequence.clone(), b, decode_length, &mut client).await?;
let _ = decode(decode_batch, &mut client).await?;
// Send warmup message
run_sender.send(Ok(Message::Warmup)).await.unwrap_or(());
}
for _ in 0..n_runs {
let (prefill, decode_batch) =
prefill(sequence.clone(), b, decode_length, &mut client).await?;
// Send prefill message
run_sender
.send(Ok(Message::Prefill(prefill)))
.await
.unwrap_or(());
let decode = decode(decode_batch, &mut client).await?;
// Send decode message
run_sender
.send(Ok(Message::Decode(decode)))
.await
.unwrap_or(());
// Send run ended message
run_sender.send(Ok(Message::EndRun)).await.unwrap_or(());
}
// Batch ended
run_sender.send(Ok(Message::EndBatch)).await.unwrap_or(());
}
Ok(())
}
// Run a prefill step
async fn prefill(
sequence: String,
batch_size: u32,
decode_length: u32,
client: &mut ShardedClient,
) -> Result<(Prefill, Batch), ClientError> {
// Create requests
let requests = (0..batch_size)
.map(|id| Request {
id: id.into(),
inputs: sequence.clone(),
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
watermark: false,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: decode_length,
stop_sequences: vec![],
ignore_eos_token: true, // Will not stop even if a eos token is generated
}),
})
.collect();
let batch = Batch {
id: 0,
requests,
size: batch_size,
};
// Run prefill
let start_time = Instant::now();
let (_, decode_batch) = client.prefill(batch.clone()).await?;
// Get latency
let latency = start_time.elapsed();
// Compute throughput from latency and batch size
let throughput = batch_size as f64 / latency.as_secs_f64();
// Decode batch cannot be empty
let decode_batch = decode_batch.expect("decode_batch is None. This is a bug.");
let step = Prefill {
latency,
throughput,
};
Ok((step, decode_batch))
}
/// Run a full decode
async fn decode(batch: Batch, client: &mut ShardedClient) -> Result<Decode, ClientError> {
let mut decode_length = 0;
let batch_size = batch.size;
let start_time = Instant::now();
// Full decode over decode length
let mut next_batch = Some(batch);
while let Some(batch) = next_batch {
let result = client.decode(vec![batch]).await?;
next_batch = result.1;
decode_length += 1;
}
// Get latency
let latency = start_time.elapsed();
let token_latency = latency / decode_length;
// Compute throughput from latency, batch size and decode length
let throughput = (batch_size * decode_length) as f64 / latency.as_secs_f64();
let step = Decode {
latency,
token_latency,
throughput,
};
Ok(step)
}
/// Create a dummy sequence of the correct length
fn create_sequence(sequence_length: u32, tokenizer: Tokenizer) -> String {
let lorem_ipsum_length = tokenizer.encode(LOREM_IPSUM, true).unwrap().len();
// Repeat lorem ipsum to cover sequence length
let string_sequence =
LOREM_IPSUM.repeat((0..sequence_length).step_by(lorem_ipsum_length).len());
// Encode sequence
let mut encoding = tokenizer.encode(string_sequence, true).unwrap();
// Truncate to sequence_length
encoding.truncate(sequence_length as usize, 0, TruncationDirection::Left);
// Decode
tokenizer
.decode(Vec::from(encoding.get_ids()), false)
.unwrap()
}

110
benchmark/src/lib.rs Normal file
View File

@ -0,0 +1,110 @@
mod app;
mod event;
mod generation;
mod utils;
use crate::app::App;
use crate::event::Event;
use crossterm::ExecutableCommand;
use std::io;
use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
use tokio::sync::{broadcast, mpsc};
use tui::backend::CrosstermBackend;
use tui::Terminal;
/// Run benchmarking app
#[allow(clippy::too_many_arguments)]
pub async fn run(
tokenizer_name: String,
tokenizer: Tokenizer,
batch_size: Vec<u32>,
sequence_length: u32,
decode_length: u32,
n_runs: usize,
warmups: usize,
client: ShardedClient,
) -> Result<(), crossterm::ErrorKind> {
// Initialize terminal properties
crossterm::terminal::enable_raw_mode()?;
io::stdout().execute(crossterm::terminal::EnterAlternateScreen)?;
io::stdout().execute(crossterm::cursor::Hide)?;
// Initialize terminal
let mut terminal = {
let backend = CrosstermBackend::new(io::stdout());
Terminal::new(backend)?
};
// Create message channel between generation_task and app
let (run_sender, run_receiver) = mpsc::channel(8);
// Crossterm event channel
let (event_sender, mut event_receiver) = mpsc::channel(8);
// Shutdown channel to terminate tasks
let (shutdown_sender, _) = broadcast::channel(1);
// Channel to check if tasks terminated
let (shutdown_guard_sender, mut shutdown_guard_receiver) = mpsc::channel(1);
// Create generation task
tokio::spawn(generation::generation_task(
tokenizer,
batch_size.clone(),
sequence_length,
decode_length,
n_runs,
warmups,
client,
run_sender,
shutdown_sender.subscribe(),
shutdown_guard_sender.clone(),
));
// Create event task
tokio::spawn(event::terminal_event_task(
250,
event_sender,
shutdown_sender.subscribe(),
shutdown_guard_sender.clone(),
));
// Drop our end of shutdown sender
drop(shutdown_guard_sender);
// Create App
let mut app = App::new(
run_receiver,
tokenizer_name,
sequence_length,
decode_length,
n_runs,
batch_size,
);
while app.running {
// Draw frame
terminal.draw(|frame| app.render(frame))?;
// Await a new event from event handling task
match event_receiver.recv().await {
None => break,
// Update app state
Some(event) => match event {
Event::Tick => app.tick(),
Event::Key(key_event) => app.handle_key_event(key_event),
_ => {}
},
}
}
// Ask tasks to shutdown
let _ = shutdown_sender.send(());
// Wait for tasks to shutdown
let _ = shutdown_guard_receiver.recv().await;
// Revert terminal to original view
io::stdout().execute(crossterm::terminal::LeaveAlternateScreen)?;
crossterm::terminal::disable_raw_mode()?;
io::stdout().execute(crossterm::cursor::Show)?;
Ok(())
}

119
benchmark/src/main.rs Normal file
View File

@ -0,0 +1,119 @@
/// Text Generation Inference benchmarking tool
///
/// Inspired by the great Oha app: https://github.com/hatoo/oha
/// and: https://github.com/orhun/rust-tui-template
use clap::Parser;
use std::path::Path;
use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter;
/// App Configuration
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(short, long, env)]
tokenizer_name: String,
#[clap(short, long)]
batch_size: Option<Vec<u32>>,
#[clap(default_value = "10", short, long, env)]
sequence_length: u32,
#[clap(default_value = "8", short, long, env)]
decode_length: u32,
#[clap(default_value = "10", short, long, env)]
runs: usize,
#[clap(default_value = "1", short, long, env)]
warmups: usize,
#[clap(default_value = "/tmp/text-generation-server-0", short, long, env)]
master_shard_uds_path: String,
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Get args
let args = Args::parse();
// Pattern match configuration
let Args {
tokenizer_name,
batch_size,
sequence_length,
decode_length,
runs,
warmups,
master_shard_uds_path,
} = args;
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
init_logging();
// Tokenizer instance
// This will only be used to validate payloads
tracing::info!("Loading tokenizer");
let local_path = Path::new(&tokenizer_name);
let tokenizer =
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
{
// Load local tokenizer
tracing::info!("Found local tokenizer");
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
} else {
// Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime
tracing::info!("Downloading tokenizer");
Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
};
tracing::info!("Tokenizer loaded");
// Launch Tokio runtime
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
// Instantiate sharded client from the master unix socket
tracing::info!("Connect to model server");
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
.await
.expect("Could not connect to server");
// Clear the cache; useful if the webserver rebooted
sharded_client
.clear_cache(None)
.await
.expect("Unable to clear cache");
tracing::info!("Connected");
// Run app
text_generation_benchmark::run(
tokenizer_name,
tokenizer,
batch_size,
sequence_length,
decode_length,
runs,
warmups,
sharded_client,
)
.await
.unwrap();
});
Ok(())
}
/// Init logging using LOG_LEVEL
fn init_logging() {
// STDOUT/STDERR layer
let fmt_layer = tracing_subscriber::fmt::layer()
.with_file(true)
.with_line_number(true);
// Filter events with LOG_LEVEL
let env_filter =
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
tracing_subscriber::registry()
.with(env_filter)
.with(fmt_layer)
.init();
}

43
benchmark/src/utils.rs Normal file
View File

@ -0,0 +1,43 @@
/// MIT License
//
// Copyright (c) 2020 hatoo
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
use std::collections::BTreeMap;
pub(crate) fn histogram(values: &[f64], bins: usize) -> Vec<(f64, usize)> {
assert!(bins >= 2);
let mut bucket: Vec<usize> = vec![0; bins];
let min = values.iter().collect::<average::Min>().min();
let max = values.iter().collect::<average::Max>().max();
let step = (max - min) / (bins - 1) as f64;
for &v in values {
let i = std::cmp::min(((v - min) / step).ceil() as usize, bins - 1);
bucket[i] += 1;
}
bucket
.into_iter()
.enumerate()
.map(|(i, v)| (min + step * i as f64, v))
.collect()
}
pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap<String, f64> {
pecents
.iter()
.map(|&p| {
let i = (f64::from(p) / 100.0 * values.len() as f64) as usize;
(format!("p{p}"), *values.get(i).unwrap_or(&std::f64::NAN))
})
.collect()
}

View File

@ -53,6 +53,9 @@ message StoppingCriteriaParameters {
uint32 max_new_tokens = 1;
/// Optional stopping sequences
repeated string stop_sequences = 2;
/// Ignore end of sequence token
/// used for benchmarking
bool ignore_eos_token = 3;
}
message Request {

View File

@ -37,7 +37,7 @@ struct Args {
max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)]
port: u16,
#[clap(default_value = "/tmp/text-generation-0", long, env)]
#[clap(default_value = "/tmp/text-generation-server-0", long, env)]
master_shard_uds_path: String,
#[clap(default_value = "bigscience/bloom", long, env)]
tokenizer_name: String,
@ -76,6 +76,8 @@ fn main() -> Result<(), std::io::Error> {
panic!("validation_workers must be > 0");
}
init_logging(otlp_endpoint, json_output);
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowOrigin
@ -89,17 +91,21 @@ fn main() -> Result<(), std::io::Error> {
// Tokenizer instance
// This will only be used to validate payloads
tracing::info!("Loading tokenizer");
let local_path = Path::new(&tokenizer_name);
let tokenizer =
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
{
// Load local tokenizer
tracing::info!("Found local tokenizer");
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
} else {
// Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime
tracing::info!("Downloading tokenizer");
Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
};
tracing::info!("Tokenizer loaded");
// Launch Tokio runtime
tokio::runtime::Builder::new_multi_thread()
@ -107,8 +113,6 @@ fn main() -> Result<(), std::io::Error> {
.build()
.unwrap()
.block_on(async {
init_logging(otlp_endpoint, json_output);
// Get pipeline tag
let model_info = reqwest::get(format!(
"https://huggingface.co/api/models/{tokenizer_name}"

View File

@ -237,6 +237,7 @@ mod tests {
watermark: false,
},
stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: false,
max_new_tokens: 0,
stop_sequences: vec![],
},

View File

@ -315,6 +315,7 @@ fn validate(
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
stop_sequences,
ignore_eos_token: false,
};
metrics::histogram!("tgi_request_input_length", input_length as f64);

View File

@ -18,7 +18,7 @@ def serve(
revision: Optional[str] = None,
sharded: bool = False,
quantize: bool = False,
uds_path: Path = "/tmp/text-generation",
uds_path: Path = "/tmp/text-generation-server",
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,

View File

@ -123,20 +123,22 @@ class StoppingCriteria:
self,
eos_token_id: int,
stop_sequence_criterias: List[StopSequenceCriteria],
max_new_tokens=20,
max_new_tokens: int = 20,
ignore_eos_token: bool = False,
):
self.eos_token_id = eos_token_id
self.stop_sequence_criterias = stop_sequence_criterias
self.max_new_tokens = max_new_tokens
self.current_tokens = 0
self.current_output = ""
self.ignore_eos_token = ignore_eos_token
def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
self.current_tokens += 1
if self.current_tokens >= self.max_new_tokens:
return True, FinishReason.FINISH_REASON_LENGTH
if last_token == self.eos_token_id:
if not self.ignore_eos_token and last_token == self.eos_token_id:
return True, FinishReason.FINISH_REASON_EOS_TOKEN
self.current_output += last_output
@ -156,5 +158,8 @@ class StoppingCriteria:
StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
]
return StoppingCriteria(
tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
tokenizer.eos_token_id,
stop_sequence_criterias,
pb.max_new_tokens,
pb.ignore_eos_token,
)