fix: enable chat requests in vertex endpoint (#2481)

* fix: enable chat requests in vertex endpoint

* feat: avoid unwrap and pre allocate future vec
Author: drbh · 2024-09-02 10:00:52 -04:00 (committed by GitHub)
parent de2cdeca53
commit 47d7e34458
2 changed files with 124 additions and 31 deletions
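
For orientation, the change means a single Vertex `instances` array can now mix the original generate-style entries with chat-style entries. The sketch below is illustrative only and not taken from the repository: field names mirror the structs in this diff, while the values and message format are assumptions.

// Illustrative payload only: one generate-style instance and one chat-style
// instance in the same `instances` array.
use serde_json::json;

fn main() {
    let body = json!({
        "instances": [
            // GenerateVertexInstance shape: `inputs` plus optional `parameters`
            {
                "inputs": "What is Deep Learning?",
                "parameters": { "max_new_tokens": 32 }
            },
            // ChatRequest shape (a subset of the fields destructured in the
            // vertex_compatibility handler below)
            {
                "messages": [{ "role": "user", "content": "What is Deep Learning?" }],
                "max_tokens": 32,
                "temperature": 0.0
            }
        ]
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}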

router/src/lib.rs

@@ -55,13 +55,20 @@ impl std::str::FromStr for Attention {
 }
 
 #[derive(Clone, Deserialize, ToSchema)]
-pub(crate) struct VertexInstance {
+pub(crate) struct GenerateVertexInstance {
     #[schema(example = "What is Deep Learning?")]
     pub inputs: String,
     #[schema(nullable = true, default = "null", example = "null")]
     pub parameters: Option<GenerateParameters>,
 }
 
+#[derive(Clone, Deserialize, ToSchema)]
+#[serde(untagged)]
+enum VertexInstance {
+    Generate(GenerateVertexInstance),
+    Chat(ChatRequest),
+}
+
 #[derive(Deserialize, ToSchema)]
 pub(crate) struct VertexRequest {
     #[serde(rename = "instances")]
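
A note on the mechanics: `#[serde(untagged)]` makes serde try the variants in declaration order and keep the first one that deserializes, so a body with `inputs` matches `Generate` while one with `messages` falls through to `Chat`. A minimal, self-contained sketch (not the real types; `ChatLike` is a hypothetical stand-in for `ChatRequest`):

// Standalone sketch of the untagged-enum dispatch.
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct GenerateLike {
    inputs: String,
}

#[derive(Debug, Deserialize)]
struct ChatLike {
    messages: Vec<serde_json::Value>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum InstanceLike {
    // Variants are tried in this order; the first successful parse wins.
    Generate(GenerateLike),
    Chat(ChatLike),
}

fn main() -> Result<(), serde_json::Error> {
    let a: InstanceLike = serde_json::from_str(r#"{"inputs": "What is Deep Learning?"}"#)?;
    let b: InstanceLike =
        serde_json::from_str(r#"{"messages": [{"role": "user", "content": "Hi"}]}"#)?;
    println!("{a:?}\n{b:?}");
    // A body matching neither shape fails with serde's generic
    // "data did not match any variant" error.
    Ok(())
}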

router/src/server.rs

@@ -8,7 +8,7 @@ use crate::kserve::{
     kserve_model_metadata, kserve_model_metadata_ready,
 };
 use crate::validation::ValidationError;
-use crate::{default_tool_prompt, ChatTokenizeResponse};
+use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
 use crate::{
     usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@@ -1406,12 +1406,12 @@ async fn vertex_compatibility(
         ));
     }
 
-    // Process all instances
-    let predictions = req
-        .instances
-        .iter()
-        .map(|instance| {
-            let generate_request = GenerateRequest {
+    // Prepare futures for all instances
+    let mut futures = Vec::with_capacity(req.instances.len());
+
+    for instance in req.instances.iter() {
+        let generate_request = match instance {
+            VertexInstance::Generate(instance) => GenerateRequest {
                 inputs: instance.inputs.clone(),
                 add_special_tokens: true,
                 parameters: GenerateParameters {
@@ -1422,31 +1422,117 @@ async fn vertex_compatibility(
                     decoder_input_details: true,
                     ..Default::default()
                 },
-            };
-
-            async {
-                generate_internal(
-                    Extension(infer.clone()),
-                    compute_type.clone(),
-                    Json(generate_request),
-                    span.clone(),
-                )
-                .await
-                .map(|(_, Json(generation))| generation.generated_text)
-                .map_err(|_| {
-                    (
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        Json(ErrorResponse {
-                            error: "Incomplete generation".into(),
-                            error_type: "Incomplete generation".into(),
-                        }),
-                    )
-                })
-            }
-        })
-        .collect::<FuturesUnordered<_>>()
-        .try_collect::<Vec<_>>()
-        .await?;
+            },
+            VertexInstance::Chat(instance) => {
+                let ChatRequest {
+                    model,
+                    max_tokens,
+                    messages,
+                    seed,
+                    stop,
+                    stream,
+                    tools,
+                    tool_choice,
+                    tool_prompt,
+                    temperature,
+                    response_format,
+                    guideline,
+                    presence_penalty,
+                    frequency_penalty,
+                    top_p,
+                    top_logprobs,
+                    ..
+                } = instance.clone();
+
+                let repetition_penalty = presence_penalty.map(|x| x + 2.0);
+                let max_new_tokens = max_tokens.or(Some(100));
+                let tool_prompt = tool_prompt
+                    .filter(|s| !s.is_empty())
+                    .unwrap_or_else(default_tool_prompt);
+                let stop = stop.unwrap_or_default();
+                // enable greedy only when temperature is 0
+                let (do_sample, temperature) = match temperature {
+                    Some(temperature) if temperature == 0.0 => (false, None),
+                    other => (true, other),
+                };
+
+                let (inputs, grammar, _using_tools) = match prepare_chat_input(
+                    &infer,
+                    response_format,
+                    tools,
+                    tool_choice,
+                    &tool_prompt,
+                    guideline,
+                    messages,
+                ) {
+                    Ok(result) => result,
+                    Err(e) => {
+                        return Err((
+                            StatusCode::BAD_REQUEST,
+                            Json(ErrorResponse {
+                                error: format!("Failed to prepare chat input: {}", e),
+                                error_type: "Input preparation error".to_string(),
+                            }),
+                        ));
+                    }
+                };
+
+                GenerateRequest {
+                    inputs: inputs.to_string(),
+                    add_special_tokens: false,
+                    parameters: GenerateParameters {
+                        best_of: None,
+                        temperature,
+                        repetition_penalty,
+                        frequency_penalty,
+                        top_k: None,
+                        top_p,
+                        typical_p: None,
+                        do_sample,
+                        max_new_tokens,
+                        return_full_text: None,
+                        stop,
+                        truncate: None,
+                        watermark: false,
+                        details: true,
+                        decoder_input_details: !stream,
+                        seed,
+                        top_n_tokens: top_logprobs,
+                        grammar,
+                        adapter_id: model.filter(|m| *m != "tgi").map(String::from),
+                    },
+                }
+            }
+        };
+
+        let infer_clone = infer.clone();
+        let compute_type_clone = compute_type.clone();
+        let span_clone = span.clone();
+        futures.push(async move {
+            generate_internal(
+                Extension(infer_clone),
+                compute_type_clone,
+                Json(generate_request),
+                span_clone,
+            )
+            .await
+            .map(|(_, Json(generation))| generation.generated_text)
+            .map_err(|_| {
+                (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(ErrorResponse {
+                        error: "Incomplete generation".into(),
+                        error_type: "Incomplete generation".into(),
+                    }),
+                )
+            })
+        });
+    }
+
+    // execute all futures in parallel, collect results, returning early if any error occurs
+    let results = futures::future::join_all(futures).await;
+    let predictions: Result<Vec<_>, _> = results.into_iter().collect();
+    let predictions = predictions?;
 
     let response = VertexResponse { predictions };
     Ok((HeaderMap::new(), Json(response)).into_response())
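
The second commit message ("avoid unwrap and pre allocate future vec") refers to the collection pattern at the end of the handler: the futures `Vec` is sized up front with `Vec::with_capacity`, and the per-instance `Result`s are folded with `collect`, so the first error propagates through `?` instead of being unwrapped. A minimal standalone sketch of that pattern, assuming only the `futures` crate:

// Sketch of the collection pattern used above: pre-allocate the Vec of
// futures, run them with join_all, then let collect() stop at the first Err
// so no unwrap is needed.
async fn run_all(inputs: Vec<u32>) -> Result<Vec<String>, String> {
    let mut futures = Vec::with_capacity(inputs.len());
    for x in inputs {
        futures.push(async move {
            if x == 0 {
                Err("zero is not allowed".to_string())
            } else {
                Ok(format!("processed {x}"))
            }
        });
    }

    // join_all drives every future to completion (none is cancelled when one
    // fails); collect::<Result<Vec<_>, _>>() keeps the Ok values or returns
    // the first Err.
    let results = futures::future::join_all(futures).await;
    let outputs: Result<Vec<_>, _> = results.into_iter().collect();
    Ok(outputs?)
}

fn main() {
    let out = futures::executor::block_on(run_all(vec![1, 2, 3]));
    println!("{out:?}");
}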