2024-07-03 02:27:53 -06:00
use std ::collections ::HashMap ;
2024-07-01 07:53:23 -06:00
use clap ::Parser ;
2024-07-03 02:27:53 -06:00
use tokenizers ::{ FromPretrainedParameters , Tokenizer } ;
2024-07-01 07:53:23 -06:00
use text_generation_backends_trtllm ::{ errors ::TensorRtLlmBackendError , TrtLLmBackend } ;
use text_generation_router ::server ;
/// App Configuration
//
// Field notes below are deliberately plain `//` comments: clap's derive macro turns
// `///` doc comments on fields into `--help` text, which would alter the CLI output.
//
// Every option marked `long, env` can be given either as a `--kebab-case` flag or
// through the environment variable derived from the field name.
//
// Default values are written without surrounding whitespace: clap feeds the literal
// through the field's value parser, and Rust's integer parsing does not trim input.
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    #[clap(default_value = "128", long, env)]
    max_concurrent_requests: usize,
    #[clap(default_value = "2", long, env)]
    max_best_of: usize,
    #[clap(default_value = "4", long, env)]
    max_stop_sequences: usize,
    #[clap(default_value = "5", long, env)]
    max_top_n_tokens: u32,
    // `main` rejects a value that is not strictly below `max_total_tokens`.
    #[clap(default_value = "1024", long, env)]
    max_input_tokens: usize,
    #[clap(default_value = "2048", long, env)]
    max_total_tokens: usize,
    // `main` requires this to be at least `max_input_tokens` and, when
    // `max_batch_total_tokens` is set, at most that value.
    #[clap(default_value = "4096", long, env)]
    max_batch_prefill_tokens: u32,
    // Optional: no default, and `main` only checks it when present.
    #[clap(long, env)]
    max_batch_total_tokens: Option<u32>,
    #[clap(default_value = "0.0.0.0", long, env)]
    hostname: String,
    // `short` additionally exposes this option as `-p`.
    #[clap(default_value = "3000", long, short, env)]
    port: u16,
    // Handed to `Tokenizer::from_pretrained` in `main`.
    #[clap(long, env, required = true)]
    tokenizer_name: String,
    #[clap(long, env)]
    tokenizer_config_path: Option<String>,
    // `main` falls back to a default revision when this is not supplied.
    #[clap(long, env)]
    revision: Option<String>,
    // Handed to `TrtLLmBackend::new` in `main`.
    #[clap(long, env)]
    model_id: String,
    // `main` rejects 0.
    #[clap(default_value = "2", long, env)]
    validation_workers: usize,
    // The next three options are forwarded to `init_logging` in `main`.
    #[clap(long, env)]
    json_output: bool,
    #[clap(long, env)]
    otlp_endpoint: Option<String>,
    #[clap(default_value = "text-generation-inference.router", long, env)]
    otlp_service_name: String,
    #[clap(long, env)]
    cors_allow_origin: Option<Vec<String>>,
    #[clap(long, env, default_value_t = false)]
    messages_api_enabled: bool,
    #[clap(default_value = "4", long, env)]
    max_client_batch_size: usize,
    // Passed as `auth_token` in the `FromPretrainedParameters` built by `main`.
    #[clap(long, env)]
    auth_token: Option<String>,
}
#[ tokio::main ]
async fn main ( ) -> Result < ( ) , TensorRtLlmBackendError > {
// Get args
let args = Args ::parse ( ) ;
// Pattern match configuration
let Args {
max_concurrent_requests ,
max_best_of ,
max_stop_sequences ,
max_top_n_tokens ,
max_input_tokens ,
max_total_tokens ,
max_batch_prefill_tokens ,
max_batch_total_tokens ,
hostname ,
port ,
tokenizer_name ,
tokenizer_config_path ,
revision ,
model_id ,
validation_workers ,
json_output ,
otlp_endpoint ,
otlp_service_name ,
cors_allow_origin ,
messages_api_enabled ,
max_client_batch_size ,
2024-07-10 07:50:28 -06:00
auth_token
2024-07-01 07:53:23 -06:00
} = args ;
// Launch Tokio runtime
text_generation_router ::logging ::init_logging ( otlp_endpoint , otlp_service_name , json_output ) ;
// Validate args
if max_input_tokens > = max_total_tokens {
return Err ( TensorRtLlmBackendError ::ArgumentValidation (
" `max_input_tokens` must be < `max_total_tokens` " . to_string ( ) ,
) ) ;
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err ( TensorRtLlmBackendError ::ArgumentValidation ( format! ( " `max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {max_batch_prefill_tokens} and {max_input_tokens} " ) ) ) ;
}
if validation_workers = = 0 {
return Err ( TensorRtLlmBackendError ::ArgumentValidation (
" `validation_workers` must be > 0 " . to_string ( ) ,
) ) ;
}
if let Some ( ref max_batch_total_tokens ) = max_batch_total_tokens {
if max_batch_prefill_tokens > * max_batch_total_tokens {
return Err ( TensorRtLlmBackendError ::ArgumentValidation ( format! ( " `max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {max_batch_prefill_tokens} and {max_batch_total_tokens} " ) ) ) ;
}
if max_total_tokens as u32 > * max_batch_total_tokens {
return Err ( TensorRtLlmBackendError ::ArgumentValidation ( format! ( " `max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_batch_total_tokens} " ) ) ) ;
}
}
// Run server
2024-07-03 02:27:53 -06:00
let tokenizer = Tokenizer ::from_pretrained (
tokenizer_name . clone ( ) ,
Some ( FromPretrainedParameters {
revision : revision . clone ( ) . unwrap_or ( String ::from ( " main " ) ) ,
user_agent : HashMap ::new ( ) ,
auth_token ,
} ) ,
) ? ;
let backend = TrtLLmBackend ::new ( tokenizer , model_id ) ? ;
2024-07-01 07:53:23 -06:00
server ::run (
backend ,
max_concurrent_requests ,
max_best_of ,
max_stop_sequences ,
max_top_n_tokens ,
max_input_tokens ,
max_total_tokens ,
validation_workers ,
tokenizer_name ,
tokenizer_config_path ,
revision ,
hostname ,
port ,
cors_allow_origin ,
false ,
None ,
None ,
messages_api_enabled ,
true ,
max_client_batch_size ,
)
. await ? ;
Ok ( ( ) )
}