feat(benchmark): add support for private tokenizers (#262)
This commit is contained in:
parent
b0b97fd9a7
commit
0e9d249b79
|
@ -5,7 +5,7 @@
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use text_generation_client::ShardedClient;
|
use text_generation_client::ShardedClient;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::{FromPretrainedParameters, Tokenizer};
|
||||||
use tracing_subscriber::layer::SubscriberExt;
|
use tracing_subscriber::layer::SubscriberExt;
|
||||||
use tracing_subscriber::util::SubscriberInitExt;
|
use tracing_subscriber::util::SubscriberInitExt;
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
|
@ -16,6 +16,8 @@ use tracing_subscriber::EnvFilter;
|
||||||
struct Args {
|
struct Args {
|
||||||
#[clap(short, long, env)]
|
#[clap(short, long, env)]
|
||||||
tokenizer_name: String,
|
tokenizer_name: String,
|
||||||
|
#[clap(default_value = "main", long, env)]
|
||||||
|
revision: String,
|
||||||
#[clap(short, long)]
|
#[clap(short, long)]
|
||||||
batch_size: Option<Vec<u32>>,
|
batch_size: Option<Vec<u32>>,
|
||||||
#[clap(default_value = "10", short, long, env)]
|
#[clap(default_value = "10", short, long, env)]
|
||||||
|
@ -36,6 +38,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// Pattern match configuration
|
// Pattern match configuration
|
||||||
let Args {
|
let Args {
|
||||||
tokenizer_name,
|
tokenizer_name,
|
||||||
|
revision,
|
||||||
batch_size,
|
batch_size,
|
||||||
sequence_length,
|
sequence_length,
|
||||||
decode_length,
|
decode_length,
|
||||||
|
@ -59,10 +62,19 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
tracing::info!("Found local tokenizer");
|
tracing::info!("Found local tokenizer");
|
||||||
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
|
Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
|
||||||
} else {
|
} else {
|
||||||
|
tracing::info!("Downloading tokenizer");
|
||||||
|
|
||||||
|
// Parse Huggingface hub token
|
||||||
|
let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
|
||||||
|
|
||||||
// Download and instantiate tokenizer
|
// Download and instantiate tokenizer
|
||||||
// We need to download it outside of the Tokio runtime
|
// We need to download it outside of the Tokio runtime
|
||||||
tracing::info!("Downloading tokenizer");
|
let params = FromPretrainedParameters {
|
||||||
Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
|
revision,
|
||||||
|
auth_token,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).unwrap()
|
||||||
};
|
};
|
||||||
tracing::info!("Tokenizer loaded");
|
tracing::info!("Tokenizer loaded");
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue