diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index 443af77f..03f61dcd 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -5,7 +5,7 @@ use clap::Parser; use std::path::Path; use text_generation_client::ShardedClient; -use tokenizers::Tokenizer; +use tokenizers::{FromPretrainedParameters, Tokenizer}; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::EnvFilter; @@ -16,6 +16,8 @@ use tracing_subscriber::EnvFilter; struct Args { #[clap(short, long, env)] tokenizer_name: String, + #[clap(default_value = "main", long, env)] + revision: String, #[clap(short, long)] batch_size: Option>, #[clap(default_value = "10", short, long, env)] @@ -36,6 +38,7 @@ fn main() -> Result<(), Box> { // Pattern match configuration let Args { tokenizer_name, + revision, batch_size, sequence_length, decode_length, @@ -59,10 +62,19 @@ fn main() -> Result<(), Box> { tracing::info!("Found local tokenizer"); Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap() } else { + tracing::info!("Downloading tokenizer"); + + // Parse Huggingface hub token + let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok(); + // Download and instantiate tokenizer // We need to download it outside of the Tokio runtime - tracing::info!("Downloading tokenizer"); - Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap() + let params = FromPretrainedParameters { + revision, + auth_token, + ..Default::default() + }; + Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).unwrap() }; tracing::info!("Tokenizer loaded");