209 lines
7.1 KiB
Rust
209 lines
7.1 KiB
Rust
//! Text Generation Inference benchmarking tool
//!
//! Inspired by the great Oha app: https://github.com/hatoo/oha
//! and: https://github.com/orhun/rust-tui-template
|
|
use clap::Parser;
|
|
use std::path::Path;
|
|
use text_generation_client::ShardedClient;
|
|
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
|
use tracing_subscriber::layer::SubscriberExt;
|
|
use tracing_subscriber::util::SubscriberInitExt;
|
|
use tracing_subscriber::EnvFilter;
|
|
|
|
/// App Configuration
|
|
#[derive(Parser, Debug)]
|
|
#[clap(author, version, about, long_about = None)]
|
|
struct Args {
|
|
/// The name of the tokenizer (as in model_id on the huggingface hub, or local path).
|
|
#[clap(short, long, env)]
|
|
tokenizer_name: String,
|
|
|
|
/// The revision to use for the tokenizer if on the hub.
|
|
#[clap(default_value = "main", long, env)]
|
|
revision: String,
|
|
|
|
/// The various batch sizes to benchmark for, the idea is to get enough
|
|
/// batching to start seeing increased latency, this usually means you're
|
|
/// moving from memory bound (usual as BS=1) to compute bound, and this is
|
|
/// a sweet spot for the maximum batch size for the model under test
|
|
#[clap(short, long)]
|
|
batch_size: Option<Vec<u32>>,
|
|
|
|
/// This is the initial prompt sent to the text-generation-server length
|
|
/// in token. Longer prompt will slow down the benchmark. Usually the
|
|
/// latency grows somewhat linearly with this for the prefill step.
|
|
///
|
|
/// Most importantly, the prefill step is usually not the one dominating
|
|
/// your runtime, so it's ok to keep it short.
|
|
#[clap(default_value = "10", short, long, env)]
|
|
sequence_length: u32,
|
|
|
|
/// This is how many tokens will be generated by the server and averaged out
|
|
/// to give the `decode` latency. This is the *critical* number you want to optimize for
|
|
/// LLM spend most of their time doing decoding.
|
|
///
|
|
/// Decode latency is usually quite stable.
|
|
#[clap(default_value = "8", short, long, env)]
|
|
decode_length: u32,
|
|
|
|
///How many runs should we average from
|
|
#[clap(default_value = "10", short, long, env)]
|
|
runs: usize,
|
|
|
|
/// Number of warmup cycles
|
|
#[clap(default_value = "1", short, long, env)]
|
|
warmups: usize,
|
|
|
|
/// The location of the grpc socket. This benchmark tool bypasses the router
|
|
/// completely and directly talks to the gRPC processes
|
|
#[clap(default_value = "/tmp/text-generation-server-0", short, long, env)]
|
|
master_shard_uds_path: String,
|
|
|
|
/// Generation parameter in case you want to specifically test/debug particular
|
|
/// decoding strategies, for full doc refer to the `text-generation-server`
|
|
#[clap(long, env)]
|
|
temperature: Option<f32>,
|
|
|
|
/// Generation parameter in case you want to specifically test/debug particular
|
|
/// decoding strategies, for full doc refer to the `text-generation-server`
|
|
#[clap(long, env)]
|
|
top_k: Option<u32>,
|
|
|
|
/// Generation parameter in case you want to specifically test/debug particular
|
|
/// decoding strategies, for full doc refer to the `text-generation-server`
|
|
#[clap(long, env)]
|
|
top_p: Option<f32>,
|
|
|
|
/// Generation parameter in case you want to specifically test/debug particular
|
|
/// decoding strategies, for full doc refer to the `text-generation-server`
|
|
#[clap(long, env)]
|
|
typical_p: Option<f32>,
|
|
|
|
/// Generation parameter in case you want to specifically test/debug particular
|
|
/// decoding strategies, for full doc refer to the `text-generation-server`
|
|
#[clap(long, env)]
|
|
repetition_penalty: Option<f32>,
|
|
|
|
/// Generation parameter in case you want to specifically test/debug particular
|
|
/// decoding strategies, for full doc refer to the `text-generation-server`
|
|
#[clap(long, env)]
|
|
watermark: bool,
|
|
|
|
/// Generation parameter in case you want to specifically test/debug particular
|
|
/// decoding strategies, for full doc refer to the `text-generation-server`
|
|
#[clap(long, env)]
|
|
do_sample: bool,
|
|
}
|
|
|
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
init_logging();
|
|
|
|
// Get args
|
|
let args = Args::parse();
|
|
// Pattern match configuration
|
|
let Args {
|
|
tokenizer_name,
|
|
revision,
|
|
batch_size,
|
|
sequence_length,
|
|
decode_length,
|
|
runs,
|
|
warmups,
|
|
temperature,
|
|
top_k,
|
|
top_p,
|
|
typical_p,
|
|
repetition_penalty,
|
|
watermark,
|
|
do_sample,
|
|
master_shard_uds_path,
|
|
} = args;
|
|
|
|
let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
|
|
|
|
// Tokenizer instance
|
|
// This will only be used to validate payloads
|
|
tracing::info!("Loading tokenizer");
|
|
let local_path = Path::new(&tokenizer_name);
|
|
let tokenizer =
|
|
if local_path.exists() && local_path.is_dir() && local_path.join("tokenizer.json").exists()
|
|
{
|
|
// Load local tokenizer
|
|
tracing::info!("Found local tokenizer");
|
|
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
|
|
} else {
|
|
tracing::info!("Downloading tokenizer");
|
|
|
|
// Parse Huggingface hub token
|
|
let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
|
|
|
|
// Download and instantiate tokenizer
|
|
// We need to download it outside of the Tokio runtime
|
|
let params = FromPretrainedParameters {
|
|
revision,
|
|
auth_token,
|
|
..Default::default()
|
|
};
|
|
Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).unwrap()
|
|
};
|
|
tracing::info!("Tokenizer loaded");
|
|
|
|
// Launch Tokio runtime
|
|
tokio::runtime::Builder::new_multi_thread()
|
|
.enable_all()
|
|
.build()
|
|
.unwrap()
|
|
.block_on(async {
|
|
// Instantiate sharded client from the master unix socket
|
|
tracing::info!("Connect to model server");
|
|
let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
|
|
.await
|
|
.expect("Could not connect to server");
|
|
// Clear the cache; useful if the webserver rebooted
|
|
sharded_client
|
|
.clear_cache(None)
|
|
.await
|
|
.expect("Unable to clear cache");
|
|
tracing::info!("Connected");
|
|
|
|
// Run app
|
|
text_generation_benchmark::run(
|
|
tokenizer_name,
|
|
tokenizer,
|
|
batch_size,
|
|
sequence_length,
|
|
decode_length,
|
|
runs,
|
|
warmups,
|
|
temperature,
|
|
top_k,
|
|
top_p,
|
|
typical_p,
|
|
repetition_penalty,
|
|
watermark,
|
|
do_sample,
|
|
sharded_client,
|
|
)
|
|
.await
|
|
.unwrap();
|
|
});
|
|
Ok(())
|
|
}
|
|
|
|
/// Init logging using LOG_LEVEL
|
|
fn init_logging() {
|
|
// STDOUT/STDERR layer
|
|
let fmt_layer = tracing_subscriber::fmt::layer()
|
|
.with_file(true)
|
|
.with_line_number(true);
|
|
|
|
// Filter events with LOG_LEVEL
|
|
let env_filter =
|
|
EnvFilter::try_from_env("LOG_LEVEL").unwrap_or_else(|_| EnvFilter::new("info"));
|
|
|
|
tracing_subscriber::registry()
|
|
.with(env_filter)
|
|
.with(fmt_layer)
|
|
.init();
|
|
}
|