diff --git a/router/src/lib.rs b/router/src/lib.rs
index a1e1dadf..d8029c72 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -55,13 +55,20 @@ impl std::str::FromStr for Attention {
 }
 
 #[derive(Clone, Deserialize, ToSchema)]
-pub(crate) struct VertexInstance {
+pub(crate) struct GenerateVertexInstance {
     #[schema(example = "What is Deep Learning?")]
     pub inputs: String,
     #[schema(nullable = true, default = "null", example = "null")]
     pub parameters: Option<GenerateParameters>,
 }
 
+#[derive(Clone, Deserialize, ToSchema)]
+#[serde(untagged)]
+enum VertexInstance {
+    Generate(GenerateVertexInstance),
+    Chat(ChatRequest),
+}
+
 #[derive(Deserialize, ToSchema)]
 pub(crate) struct VertexRequest {
     #[serde(rename = "instances")]
diff --git a/router/src/server.rs b/router/src/server.rs
index d3d34215..fac56a77 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -8,7 +8,7 @@ use crate::kserve::{
     kserve_model_metadata, kserve_model_metadata_ready,
 };
 use crate::validation::ValidationError;
-use crate::{default_tool_prompt, ChatTokenizeResponse};
+use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
 use crate::{
     usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@@ -1406,12 +1406,12 @@ async fn vertex_compatibility(
         ));
     }
 
-    // Process all instances
-    let predictions = req
-        .instances
-        .iter()
-        .map(|instance| {
-            let generate_request = GenerateRequest {
+    // Prepare futures for all instances
+    let mut futures = Vec::with_capacity(req.instances.len());
+
+    for instance in req.instances.iter() {
+        let generate_request = match instance {
+            VertexInstance::Generate(instance) => GenerateRequest {
                 inputs: instance.inputs.clone(),
                 add_special_tokens: true,
                 parameters: GenerateParameters {
@@ -1422,31 +1422,117 @@ async fn vertex_compatibility(
                     decoder_input_details: true,
                     ..Default::default()
                 },
-            };
+            },
+            VertexInstance::Chat(instance) => {
+                let ChatRequest {
+                    model,
+                    max_tokens,
+                    messages,
+                    seed,
+                    stop,
+                    stream,
+                    tools,
+                    tool_choice,
+                    tool_prompt,
+                    temperature,
+                    response_format,
+                    guideline,
+                    presence_penalty,
+                    frequency_penalty,
+                    top_p,
+                    top_logprobs,
+                    ..
+                } = instance.clone();
 
-            async {
-                generate_internal(
-                    Extension(infer.clone()),
-                    compute_type.clone(),
-                    Json(generate_request),
-                    span.clone(),
-                )
-                .await
-                .map(|(_, Json(generation))| generation.generated_text)
-                .map_err(|_| {
-                    (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        Json(ErrorResponse {
-                            error: "Incomplete generation".into(),
-                            error_type: "Incomplete generation".into(),
-                        }),
-                    )
-                })
+                let repetition_penalty = presence_penalty.map(|x| x + 2.0);
+                let max_new_tokens = max_tokens.or(Some(100));
+                let tool_prompt = tool_prompt
+                    .filter(|s| !s.is_empty())
+                    .unwrap_or_else(default_tool_prompt);
+                let stop = stop.unwrap_or_default();
+                // enable greedy only when temperature is 0
+                let (do_sample, temperature) = match temperature {
+                    Some(temperature) if temperature == 0.0 => (false, None),
+                    other => (true, other),
+                };
+                let (inputs, grammar, _using_tools) = match prepare_chat_input(
+                    &infer,
+                    response_format,
+                    tools,
+                    tool_choice,
+                    &tool_prompt,
+                    guideline,
+                    messages,
+                ) {
+                    Ok(result) => result,
+                    Err(e) => {
+                        return Err((
+                            StatusCode::BAD_REQUEST,
+                            Json(ErrorResponse {
+                                error: format!("Failed to prepare chat input: {}", e),
+                                error_type: "Input preparation error".to_string(),
+                            }),
+                        ));
+                    }
+                };
+
+                GenerateRequest {
+                    inputs: inputs.to_string(),
+                    add_special_tokens: false,
+                    parameters: GenerateParameters {
+                        best_of: None,
+                        temperature,
+                        repetition_penalty,
+                        frequency_penalty,
+                        top_k: None,
+                        top_p,
+                        typical_p: None,
+                        do_sample,
+                        max_new_tokens,
+                        return_full_text: None,
+                        stop,
+                        truncate: None,
+                        watermark: false,
+                        details: true,
+                        decoder_input_details: !stream,
+                        seed,
+                        top_n_tokens: top_logprobs,
+                        grammar,
+                        adapter_id: model.filter(|m| *m != "tgi").map(String::from),
+                    },
+                }
             }
-        })
-        .collect::<FuturesUnordered<_>>()
-        .try_collect::<Vec<_>>()
-        .await?;
+        };
+
+        let infer_clone = infer.clone();
+        let compute_type_clone = compute_type.clone();
+        let span_clone = span.clone();
+
+        futures.push(async move {
+            generate_internal(
+                Extension(infer_clone),
+                compute_type_clone,
+                Json(generate_request),
+                span_clone,
+            )
+            .await
+            .map(|(_, Json(generation))| generation.generated_text)
+            .map_err(|_| {
+                (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(ErrorResponse {
+                        error: "Incomplete generation".into(),
+                        error_type: "Incomplete generation".into(),
+                    }),
+                )
+            })
+        });
+    }
+
+    // execute all futures in parallel, collect results, returning early if any error occurs
+    let results = futures::future::join_all(futures).await;
+    let predictions: Result<Vec<_>, _> = results.into_iter().collect();
+    let predictions = predictions?;
 
     let response = VertexResponse { predictions };
     Ok((HeaderMap::new(), Json(response)).into_response())
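
For reviewers who want to poke at the deserialization behavior in isolation, below is a minimal, self-contained sketch (not code from this diff) of how the new untagged `VertexInstance` enum routes the two payload shapes. It assumes only `serde` (with the derive feature) and `serde_json`, and uses simplified stand-ins for `GenerateVertexInstance` and `ChatRequest` that carry just the fields needed to show variant selection.

```rust
// Sketch only: simplified stand-ins for the real router types.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct GenerateVertexInstance {
    inputs: String,
}

#[derive(Debug, Deserialize)]
struct Message {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct ChatRequest {
    messages: Vec<Message>,
}

// With `untagged`, serde tries the variants in declaration order: a payload
// with `inputs` becomes `Generate`, an OpenAI-style payload with `messages`
// fails the first variant and becomes `Chat`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum VertexInstance {
    Generate(GenerateVertexInstance),
    Chat(ChatRequest),
}

#[derive(Debug, Deserialize)]
struct VertexRequest {
    instances: Vec<VertexInstance>,
}

fn main() {
    // A single Vertex request may now mix plain generate and chat instances.
    let body = r#"{
        "instances": [
            {"inputs": "What is Deep Learning?"},
            {"messages": [{"role": "user", "content": "What is Deep Learning?"}]}
        ]
    }"#;
    let req: VertexRequest = serde_json::from_str(body).unwrap();
    for instance in &req.instances {
        match instance {
            VertexInstance::Generate(g) => println!("generate: {}", g.inputs),
            VertexInstance::Chat(c) => println!("chat with {} message(s)", c.messages.len()),
        }
    }
}
```

One consequence of `#[serde(untagged)]` worth keeping in mind: variant order matters, and a payload that matches neither shape surfaces as serde's generic "data did not match any variant" error rather than a field-specific message.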