fix: enable chat requests in vertex endpoint (#2481)
* fix: enable chat requests in vertex endpoint * feat: avoid unwrap and pre allocate future vec
This commit is contained in:
parent
de2cdeca53
commit
47d7e34458
|
@ -55,13 +55,20 @@ impl std::str::FromStr for Attention {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema)]
|
#[derive(Clone, Deserialize, ToSchema)]
|
||||||
pub(crate) struct VertexInstance {
|
pub(crate) struct GenerateVertexInstance {
|
||||||
#[schema(example = "What is Deep Learning?")]
|
#[schema(example = "What is Deep Learning?")]
|
||||||
pub inputs: String,
|
pub inputs: String,
|
||||||
#[schema(nullable = true, default = "null", example = "null")]
|
#[schema(nullable = true, default = "null", example = "null")]
|
||||||
pub parameters: Option<GenerateParameters>,
|
pub parameters: Option<GenerateParameters>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum VertexInstance {
|
||||||
|
Generate(GenerateVertexInstance),
|
||||||
|
Chat(ChatRequest),
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, ToSchema)]
|
#[derive(Deserialize, ToSchema)]
|
||||||
pub(crate) struct VertexRequest {
|
pub(crate) struct VertexRequest {
|
||||||
#[serde(rename = "instances")]
|
#[serde(rename = "instances")]
|
||||||
|
|
|
@ -8,7 +8,7 @@ use crate::kserve::{
|
||||||
kserve_model_metadata, kserve_model_metadata_ready,
|
kserve_model_metadata, kserve_model_metadata_ready,
|
||||||
};
|
};
|
||||||
use crate::validation::ValidationError;
|
use crate::validation::ValidationError;
|
||||||
use crate::{default_tool_prompt, ChatTokenizeResponse};
|
use crate::{default_tool_prompt, ChatTokenizeResponse, VertexInstance};
|
||||||
use crate::{
|
use crate::{
|
||||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||||
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||||
|
@ -1406,12 +1406,12 @@ async fn vertex_compatibility(
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process all instances
|
// Prepare futures for all instances
|
||||||
let predictions = req
|
let mut futures = Vec::with_capacity(req.instances.len());
|
||||||
.instances
|
|
||||||
.iter()
|
for instance in req.instances.iter() {
|
||||||
.map(|instance| {
|
let generate_request = match instance {
|
||||||
let generate_request = GenerateRequest {
|
VertexInstance::Generate(instance) => GenerateRequest {
|
||||||
inputs: instance.inputs.clone(),
|
inputs: instance.inputs.clone(),
|
||||||
add_special_tokens: true,
|
add_special_tokens: true,
|
||||||
parameters: GenerateParameters {
|
parameters: GenerateParameters {
|
||||||
|
@ -1422,14 +1422,98 @@ async fn vertex_compatibility(
|
||||||
decoder_input_details: true,
|
decoder_input_details: true,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
|
},
|
||||||
|
VertexInstance::Chat(instance) => {
|
||||||
|
let ChatRequest {
|
||||||
|
model,
|
||||||
|
max_tokens,
|
||||||
|
messages,
|
||||||
|
seed,
|
||||||
|
stop,
|
||||||
|
stream,
|
||||||
|
tools,
|
||||||
|
tool_choice,
|
||||||
|
tool_prompt,
|
||||||
|
temperature,
|
||||||
|
response_format,
|
||||||
|
guideline,
|
||||||
|
presence_penalty,
|
||||||
|
frequency_penalty,
|
||||||
|
top_p,
|
||||||
|
top_logprobs,
|
||||||
|
..
|
||||||
|
} = instance.clone();
|
||||||
|
|
||||||
|
let repetition_penalty = presence_penalty.map(|x| x + 2.0);
|
||||||
|
let max_new_tokens = max_tokens.or(Some(100));
|
||||||
|
let tool_prompt = tool_prompt
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.unwrap_or_else(default_tool_prompt);
|
||||||
|
let stop = stop.unwrap_or_default();
|
||||||
|
// enable greedy only when temperature is 0
|
||||||
|
let (do_sample, temperature) = match temperature {
|
||||||
|
Some(temperature) if temperature == 0.0 => (false, None),
|
||||||
|
other => (true, other),
|
||||||
|
};
|
||||||
|
let (inputs, grammar, _using_tools) = match prepare_chat_input(
|
||||||
|
&infer,
|
||||||
|
response_format,
|
||||||
|
tools,
|
||||||
|
tool_choice,
|
||||||
|
&tool_prompt,
|
||||||
|
guideline,
|
||||||
|
messages,
|
||||||
|
) {
|
||||||
|
Ok(result) => result,
|
||||||
|
Err(e) => {
|
||||||
|
return Err((
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(ErrorResponse {
|
||||||
|
error: format!("Failed to prepare chat input: {}", e),
|
||||||
|
error_type: "Input preparation error".to_string(),
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
async {
|
GenerateRequest {
|
||||||
|
inputs: inputs.to_string(),
|
||||||
|
add_special_tokens: false,
|
||||||
|
parameters: GenerateParameters {
|
||||||
|
best_of: None,
|
||||||
|
temperature,
|
||||||
|
repetition_penalty,
|
||||||
|
frequency_penalty,
|
||||||
|
top_k: None,
|
||||||
|
top_p,
|
||||||
|
typical_p: None,
|
||||||
|
do_sample,
|
||||||
|
max_new_tokens,
|
||||||
|
return_full_text: None,
|
||||||
|
stop,
|
||||||
|
truncate: None,
|
||||||
|
watermark: false,
|
||||||
|
details: true,
|
||||||
|
decoder_input_details: !stream,
|
||||||
|
seed,
|
||||||
|
top_n_tokens: top_logprobs,
|
||||||
|
grammar,
|
||||||
|
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let infer_clone = infer.clone();
|
||||||
|
let compute_type_clone = compute_type.clone();
|
||||||
|
let span_clone = span.clone();
|
||||||
|
|
||||||
|
futures.push(async move {
|
||||||
generate_internal(
|
generate_internal(
|
||||||
Extension(infer.clone()),
|
Extension(infer_clone),
|
||||||
compute_type.clone(),
|
compute_type_clone,
|
||||||
Json(generate_request),
|
Json(generate_request),
|
||||||
span.clone(),
|
span_clone,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map(|(_, Json(generation))| generation.generated_text)
|
.map(|(_, Json(generation))| generation.generated_text)
|
||||||
|
@ -1442,11 +1526,13 @@ async fn vertex_compatibility(
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
});
|
||||||
}
|
}
|
||||||
})
|
|
||||||
.collect::<FuturesUnordered<_>>()
|
// execute all futures in parallel, collect results, returning early if any error occurs
|
||||||
.try_collect::<Vec<_>>()
|
let results = futures::future::join_all(futures).await;
|
||||||
.await?;
|
let predictions: Result<Vec<_>, _> = results.into_iter().collect();
|
||||||
|
let predictions = predictions?;
|
||||||
|
|
||||||
let response = VertexResponse { predictions };
|
let response = VertexResponse { predictions };
|
||||||
Ok((HeaderMap::new(), Json(response)).into_response())
|
Ok((HeaderMap::new(), Json(response)).into_response())
|
||||||
|
|
Loading…
Reference in New Issue