fix: enable chat requests in vertex endpoint (#2481)

* fix: enable chat requests in vertex endpoint

* feat: avoid unwrap and pre allocate future vec
This commit is contained in:
drbh 2024-09-02 10:00:52 -04:00 committed by GitHub
parent de2cdeca53
commit 47d7e34458
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 124 additions and 31 deletions

View File

@ -55,13 +55,20 @@ impl std::str::FromStr for Attention {
}
#[derive(Clone, Deserialize, ToSchema)]
pub(crate) struct VertexInstance {
pub(crate) struct GenerateVertexInstance {
#[schema(example = "What is Deep Learning?")]
pub inputs: String,
#[schema(nullable = true, default = "null", example = "null")]
pub parameters: Option<GenerateParameters>,
}
#[derive(Clone, Deserialize, ToSchema)]
#[serde(untagged)]
enum VertexInstance {
Generate(GenerateVertexInstance),
Chat(ChatRequest),
}
#[derive(Deserialize, ToSchema)]
pub(crate) struct VertexRequest {
#[serde(rename = "instances")]

View File

@ -8,7 +8,7 @@ use crate::kserve::{
kserve_model_metadata, kserve_model_metadata_ready,
};
use crate::validation::ValidationError;
use crate::{default_tool_prompt, ChatTokenizeResponse};
use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@ -1406,12 +1406,12 @@ async fn vertex_compatibility(
));
}
// Process all instances
let predictions = req
.instances
.iter()
.map(|instance| {
let generate_request = GenerateRequest {
// Prepare futures for all instances
let mut futures = Vec::with_capacity(req.instances.len());
for instance in req.instances.iter() {
let generate_request = match instance {
VertexInstance::Generate(instance) => GenerateRequest {
inputs: instance.inputs.clone(),
add_special_tokens: true,
parameters: GenerateParameters {
@ -1422,14 +1422,98 @@ async fn vertex_compatibility(
decoder_input_details: true,
..Default::default()
},
},
VertexInstance::Chat(instance) => {
let ChatRequest {
model,
max_tokens,
messages,
seed,
stop,
stream,
tools,
tool_choice,
tool_prompt,
temperature,
response_format,
guideline,
presence_penalty,
frequency_penalty,
top_p,
top_logprobs,
..
} = instance.clone();
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
let max_new_tokens = max_tokens.or(Some(100));
let tool_prompt = tool_prompt
.filter(|s| !s.is_empty())
.unwrap_or_else(default_tool_prompt);
let stop = stop.unwrap_or_default();
// enable greedy only when temperature is 0
let (do_sample, temperature) = match temperature {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
let (inputs, grammar, _using_tools) = match prepare_chat_input(
&infer,
response_format,
tools,
tool_choice,
&tool_prompt,
guideline,
messages,
) {
Ok(result) => result,
Err(e) => {
return Err((
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!("Failed to prepare chat input: {}", e),
error_type: "Input preparation error".to_string(),
}),
));
}
};
async {
GenerateRequest {
inputs: inputs.to_string(),
add_special_tokens: false,
parameters: GenerateParameters {
best_of: None,
temperature,
repetition_penalty,
frequency_penalty,
top_k: None,
top_p,
typical_p: None,
do_sample,
max_new_tokens,
return_full_text: None,
stop,
truncate: None,
watermark: false,
details: true,
decoder_input_details: !stream,
seed,
top_n_tokens: top_logprobs,
grammar,
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
},
}
}
};
let infer_clone = infer.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();
futures.push(async move {
generate_internal(
Extension(infer.clone()),
compute_type.clone(),
Extension(infer_clone),
compute_type_clone,
Json(generate_request),
span.clone(),
span_clone,
)
.await
.map(|(_, Json(generation))| generation.generated_text)
@ -1442,11 +1526,13 @@ async fn vertex_compatibility(
}),
)
})
});
}
})
.collect::<FuturesUnordered<_>>()
.try_collect::<Vec<_>>()
.await?;
// execute all futures in parallel, collect results, returning early if any error occurs
let results = futures::future::join_all(futures).await;
let predictions: Result<Vec<_>, _> = results.into_iter().collect();
let predictions = predictions?;
let response = VertexResponse { predictions };
Ok((HeaderMap::new(), Json(response)).into_response())