diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs
index 6172d377..a7550060 100644
--- a/benchmark/src/main.rs
+++ b/benchmark/src/main.rs
@@ -14,36 +14,85 @@ use tracing_subscriber::EnvFilter;
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
 struct Args {
+    /// The name of the tokenizer (as in model_id on the huggingface hub, or local path).
     #[clap(short, long, env)]
     tokenizer_name: String,
+
+    /// The revision to use for the tokenizer if on the hub.
     #[clap(default_value = "main", long, env)]
     revision: String,
+
+    /// The various batch sizes to benchmark. The idea is to use enough
+    /// batching to start seeing increased latency; this usually means you're
+    /// moving from memory bound (usually the case at BS=1) to compute bound,
+    /// which is a sweet spot for the maximum batch size of the model under test.
     #[clap(short, long)]
     batch_size: Option<Vec<u32>>,
+
+    /// The length, in tokens, of the initial prompt sent to the
+    /// text-generation-server. Longer prompts slow down the benchmark; the
+    /// latency usually grows somewhat linearly with this for the prefill step.
+    ///
+    /// Most importantly, the prefill step is usually not the one dominating
+    /// your runtime, so it's ok to keep it short.
     #[clap(default_value = "10", short, long, env)]
     sequence_length: u32,
+
+    /// How many tokens will be generated by the server and averaged out
+    /// to give the `decode` latency. This is the *critical* number you want to
+    /// optimize for, as LLMs spend most of their time decoding.
+    ///
+    /// Decode latency is usually quite stable.
     #[clap(default_value = "8", short, long, env)]
     decode_length: u32,
+
+    /// How many runs to average over.
     #[clap(default_value = "10", short, long, env)]
     runs: usize,
+
+    /// Number of warmup cycles.
     #[clap(default_value = "1", short, long, env)]
     warmups: usize,
-    #[clap(long, env)]
-    temperature: Option<f32>,
-    #[clap(long, env)]
-    top_k: Option<u32>,
-    #[clap(long, env)]
-    top_p: Option<f32>,
-    #[clap(long, env)]
-    typical_p: Option<f32>,
-    #[clap(long, env)]
-    repetition_penalty: Option<f32>,
-    #[clap(long, env)]
-    watermark: bool,
-    #[clap(long, env)]
-    do_sample: bool,
+
+    /// The location of the gRPC socket. This benchmark tool bypasses the router
+    /// completely and talks directly to the gRPC processes.
     #[clap(default_value = "/tmp/text-generation-server-0", short, long, env)]
     master_shard_uds_path: String,
+
+    /// Generation parameter, in case you want to specifically test/debug a
+    /// particular decoding strategy; for full docs refer to `text-generation-server`.
+    #[clap(long, env)]
+    temperature: Option<f32>,
+
+    /// Generation parameter, in case you want to specifically test/debug a
+    /// particular decoding strategy; for full docs refer to `text-generation-server`.
+    #[clap(long, env)]
+    top_k: Option<u32>,
+
+    /// Generation parameter, in case you want to specifically test/debug a
+    /// particular decoding strategy; for full docs refer to `text-generation-server`.
+    #[clap(long, env)]
+    top_p: Option<f32>,
+
+    /// Generation parameter, in case you want to specifically test/debug a
+    /// particular decoding strategy; for full docs refer to `text-generation-server`.
+    #[clap(long, env)]
+    typical_p: Option<f32>,
+
+    /// Generation parameter, in case you want to specifically test/debug a
+    /// particular decoding strategy; for full docs refer to `text-generation-server`.
+    #[clap(long, env)]
+    repetition_penalty: Option<f32>,
+
+    /// Generation parameter, in case you want to specifically test/debug a
+    /// particular decoding strategy; for full docs refer to `text-generation-server`.
+    #[clap(long, env)]
+    watermark: bool,
+
+    /// Generation parameter, in case you want to specifically test/debug a
+    /// particular decoding strategy; for full docs refer to `text-generation-server`.
+    #[clap(long, env)]
+    do_sample: bool,
 }
 
 fn main() -> Result<(), Box<dyn std::error::Error>> {
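Below is a minimal, self-contained sketch (not part of the patch above) showing how the clap derive attributes used in this struct map onto command-line flags, repeated options, and environment variables. The reduced field set mirrors the diff, but the binary name, model id, and batch-size values in the example invocation are illustrative only; it assumes clap with the `derive` and `env` features enabled.

```rust
use clap::Parser;

#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    /// Tokenizer name (model_id on the huggingface hub, or a local path).
    #[clap(short, long, env)]
    tokenizer_name: String,

    /// Initial prompt length in tokens (drives prefill latency).
    #[clap(default_value = "10", short, long, env)]
    sequence_length: u32,

    /// Number of generated tokens averaged into the `decode` latency.
    #[clap(default_value = "8", short, long, env)]
    decode_length: u32,

    /// Batch sizes to sweep; pass the flag several times to benchmark several values.
    #[clap(short, long)]
    batch_size: Option<Vec<u32>>,
}

fn main() {
    // Simulate a command line; the flag values here are examples only.
    // Each field can also be filled from an environment variable
    // (e.g. TOKENIZER_NAME) because of the `env` attribute.
    let args = Args::parse_from([
        "text-generation-benchmark",
        "--tokenizer-name",
        "bigscience/bloom-560m",
        "--batch-size",
        "1",
        "--batch-size",
        "8",
    ]);
    println!("{args:?}");
}
```

Because `batch_size` is an `Option<Vec<u32>>`, clap treats repeated `--batch-size` flags as an appended list, which is what lets the benchmark sweep several batch sizes in a single run.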