feat(benchmark): add support for private tokenizers (#262)
This commit is contained in:
parent
b0b97fd9a7
commit
0e9d249b79
|
@ -5,7 +5,7 @@
|
|||
use clap::Parser;
|
||||
use std::path::Path;
|
||||
use text_generation_client::ShardedClient;
|
||||
use tokenizers::Tokenizer;
|
||||
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
@ -16,6 +16,8 @@ use tracing_subscriber::EnvFilter;
|
|||
struct Args {
|
||||
#[clap(short, long, env)]
|
||||
tokenizer_name: String,
|
||||
#[clap(default_value = "main", long, env)]
|
||||
revision: String,
|
||||
#[clap(short, long)]
|
||||
batch_size: Option<Vec<u32>>,
|
||||
#[clap(default_value = "10", short, long, env)]
|
||||
|
@ -36,6 +38,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
// Pattern match configuration
|
||||
let Args {
|
||||
tokenizer_name,
|
||||
revision,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
decode_length,
|
||||
|
@ -59,10 +62,19 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
tracing::info!("Found local tokenizer");
|
||||
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
|
||||
} else {
|
||||
tracing::info!("Downloading tokenizer");
|
||||
|
||||
// Parse Huggingface hub token
|
||||
let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
|
||||
|
||||
// Download and instantiate tokenizer
|
||||
// We need to download it outside of the Tokio runtime
|
||||
tracing::info!("Downloading tokenizer");
|
||||
Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
|
||||
let params = FromPretrainedParameters {
|
||||
revision,
|
||||
auth_token,
|
||||
..Default::default()
|
||||
};
|
||||
Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).unwrap()
|
||||
};
|
||||
tracing::info!("Tokenizer loaded");
|
||||
|
||||
|
|
Loading…
Reference in New Issue