feat(benchmark): add support for private tokenizers (#262)

This commit is contained in:
OlivierDehaene 2023-04-29 12:17:30 +02:00 committed by GitHub
parent b0b97fd9a7
commit 0e9d249b79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 3 deletions

View File

@ -5,7 +5,7 @@
use clap::Parser;
use std::path::Path;
use text_generation_client::ShardedClient;
use tokenizers::Tokenizer;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter;
@ -16,6 +16,8 @@ use tracing_subscriber::EnvFilter;
struct Args {
#[clap(short, long, env)]
tokenizer_name: String,
#[clap(default_value = "main", long, env)]
revision: String,
#[clap(short, long)]
batch_size: Option<Vec<u32>>,
#[clap(default_value = "10", short, long, env)]
@ -36,6 +38,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// Pattern match configuration
let Args {
tokenizer_name,
revision,
batch_size,
sequence_length,
decode_length,
@ -59,10 +62,19 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing::info!("Found local tokenizer");
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
} else {
tracing::info!("Downloading tokenizer");
// Parse Huggingface hub token
let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
// Download and instantiate tokenizer
// We need to download it outside of the Tokio runtime
tracing::info!("Downloading tokenizer");
Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
let params = FromPretrainedParameters {
revision,
auth_token,
..Default::default()
};
Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).unwrap()
};
tracing::info!("Tokenizer loaded");