fix: enable chat requests in vertex endpoint (#2481)
* fix: enable chat requests in vertex endpoint * feat: avoid unwrap and pre allocate future vec
This commit is contained in:
parent
de2cdeca53
commit
47d7e34458
|
@ -55,13 +55,20 @@ impl std::str::FromStr for Attention {
|
|||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema)]
|
||||
pub(crate) struct VertexInstance {
|
||||
pub(crate) struct GenerateVertexInstance {
|
||||
#[schema(example = "What is Deep Learning?")]
|
||||
pub inputs: String,
|
||||
#[schema(nullable = true, default = "null", example = "null")]
|
||||
pub parameters: Option<GenerateParameters>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema)]
|
||||
#[serde(untagged)]
|
||||
enum VertexInstance {
|
||||
Generate(GenerateVertexInstance),
|
||||
Chat(ChatRequest),
|
||||
}
|
||||
|
||||
#[derive(Deserialize, ToSchema)]
|
||||
pub(crate) struct VertexRequest {
|
||||
#[serde(rename = "instances")]
|
||||
|
|
|
@ -8,7 +8,7 @@ use crate::kserve::{
|
|||
kserve_model_metadata, kserve_model_metadata_ready,
|
||||
};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::{default_tool_prompt, ChatTokenizeResponse};
|
||||
use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
|
||||
use crate::{
|
||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||
|
@ -1406,12 +1406,12 @@ async fn vertex_compatibility(
|
|||
));
|
||||
}
|
||||
|
||||
// Process all instances
|
||||
let predictions = req
|
||||
.instances
|
||||
.iter()
|
||||
.map(|instance| {
|
||||
let generate_request = GenerateRequest {
|
||||
// Prepare futures for all instances
|
||||
let mut futures = Vec::with_capacity(req.instances.len());
|
||||
|
||||
for instance in req.instances.iter() {
|
||||
let generate_request = match instance {
|
||||
VertexInstance::Generate(instance) => GenerateRequest {
|
||||
inputs: instance.inputs.clone(),
|
||||
add_special_tokens: true,
|
||||
parameters: GenerateParameters {
|
||||
|
@ -1422,31 +1422,117 @@ async fn vertex_compatibility(
|
|||
decoder_input_details: true,
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
},
|
||||
VertexInstance::Chat(instance) => {
|
||||
let ChatRequest {
|
||||
model,
|
||||
max_tokens,
|
||||
messages,
|
||||
seed,
|
||||
stop,
|
||||
stream,
|
||||
tools,
|
||||
tool_choice,
|
||||
tool_prompt,
|
||||
temperature,
|
||||
response_format,
|
||||
guideline,
|
||||
presence_penalty,
|
||||
frequency_penalty,
|
||||
top_p,
|
||||
top_logprobs,
|
||||
..
|
||||
} = instance.clone();
|
||||
|
||||
async {
|
||||
generate_internal(
|
||||
Extension(infer.clone()),
|
||||
compute_type.clone(),
|
||||
Json(generate_request),
|
||||
span.clone(),
|
||||
)
|
||||
.await
|
||||
.map(|(_, Json(generation))| generation.generated_text)
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Incomplete generation".into(),
|
||||
error_type: "Incomplete generation".into(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||
let max_new_tokens = max_tokens.or(Some(100));
|
||||
let tool_prompt = tool_prompt
|
||||
.filter(|s| !s.is_empty())
|
||||
.unwrap_or_else(default_tool_prompt);
|
||||
let stop = stop.unwrap_or_default();
|
||||
// enable greedy only when temperature is 0
|
||||
let (do_sample, temperature) = match temperature {
|
||||
Some(temperature) if temperature == 0.0 => (false, None),
|
||||
other => (true, other),
|
||||
};
|
||||
let (inputs, grammar, _using_tools) = match prepare_chat_input(
|
||||
&infer,
|
||||
response_format,
|
||||
tools,
|
||||
tool_choice,
|
||||
&tool_prompt,
|
||||
guideline,
|
||||
messages,
|
||||
) {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ErrorResponse {
|
||||
error: format!("Failed to prepare chat input: {}", e),
|
||||
error_type: "Input preparation error".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
GenerateRequest {
|
||||
inputs: inputs.to_string(),
|
||||
add_special_tokens: false,
|
||||
parameters: GenerateParameters {
|
||||
best_of: None,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
frequency_penalty,
|
||||
top_k: None,
|
||||
top_p,
|
||||
typical_p: None,
|
||||
do_sample,
|
||||
max_new_tokens,
|
||||
return_full_text: None,
|
||||
stop,
|
||||
truncate: None,
|
||||
watermark: false,
|
||||
details: true,
|
||||
decoder_input_details: !stream,
|
||||
seed,
|
||||
top_n_tokens: top_logprobs,
|
||||
grammar,
|
||||
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
||||
},
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect::<FuturesUnordered<_>>()
|
||||
.try_collect::<Vec<_>>()
|
||||
.await?;
|
||||
};
|
||||
|
||||
let infer_clone = infer.clone();
|
||||
let compute_type_clone = compute_type.clone();
|
||||
let span_clone = span.clone();
|
||||
|
||||
futures.push(async move {
|
||||
generate_internal(
|
||||
Extension(infer_clone),
|
||||
compute_type_clone,
|
||||
Json(generate_request),
|
||||
span_clone,
|
||||
)
|
||||
.await
|
||||
.map(|(_, Json(generation))| generation.generated_text)
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Incomplete generation".into(),
|
||||
error_type: "Incomplete generation".into(),
|
||||
}),
|
||||
)
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
// execute all futures in parallel, collect results, returning early if any error occurs
|
||||
let results = futures::future::join_all(futures).await;
|
||||
let predictions: Result<Vec<_>, _> = results.into_iter().collect();
|
||||
let predictions = predictions?;
|
||||
|
||||
let response = VertexResponse { predictions };
|
||||
Ok((HeaderMap::new(), Json(response)).into_response())
|
||||
|
|
Loading…
Reference in New Issue