feat: implement a templated endpoint for visibility into chat requests (#2333)
* feat: implement a templated endpoint for visibility into chat requests
* feat: improve to tokenize too
* fix: adjust return type
* feat: simplify prepare_chat_input logic and adjust start stop chars
This commit is contained in:
parent
29b8d19cdf
commit
e11f5f1c38
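For orientation, here is a rough sketch of how a client might exercise the new route once the router is running. Everything in it is illustrative rather than part of this commit: the localhost:3000 address assumes the router's default port, the payload is just a plausible ChatRequest, and the reqwest/tokio/serde_json dependencies are only needed for this standalone example.

// Hypothetical client-side sketch (not part of this commit): POST a ChatRequest
// to the new /chat_tokenize route and print the templated text plus tokens.
// The address and payload below are illustrative assumptions.
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = json!({
        "model": "tgi",
        "messages": [{ "role": "user", "content": "What is Deep Learning?" }]
    });
    let resp: serde_json::Value = reqwest::Client::new()
        .post("http://localhost:3000/chat_tokenize")
        .json(&body)
        .send()
        .await?
        .json()
        .await?;
    // Expected shape (per ChatTokenizeResponse): "templated_text" plus a token list.
    println!("{}", serde_json::to_string_pretty(&resp)?);
    Ok(())
}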
@@ -1157,6 +1157,12 @@ pub(crate) struct GenerateResponse {
     pub details: Option<Details>,
 }
 
+#[derive(Serialize, ToSchema)]
+pub(crate) struct ChatTokenizeResponse {
+    pub(crate) tokenize_response: TokenizeResponse,
+    pub(crate) templated_text: String,
+}
+
 #[derive(Serialize, ToSchema)]
 #[serde(transparent)]
 pub(crate) struct TokenizeResponse(Vec<SimpleToken>);
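A quick aside on the shape of this payload: because TokenizeResponse is marked #[serde(transparent)], it serializes as the bare token array rather than as a wrapper object. The standalone sketch below uses simplified local mirrors of the structs above (not the crate's actual, crate-private definitions) to show the resulting JSON.

// Standalone sketch with simplified mirror types (illustrative, not the crate's
// real definitions): the transparent newtype serializes as a plain JSON list,
// so the chat_tokenize response nests the array directly under "tokenize_response".
use serde::Serialize;

#[derive(Serialize)]
struct SimpleToken {
    id: u32,
    text: String,
    start: usize,
    stop: usize,
}

#[derive(Serialize)]
#[serde(transparent)]
struct TokenizeResponse(Vec<SimpleToken>);

#[derive(Serialize)]
struct ChatTokenizeResponse {
    tokenize_response: TokenizeResponse,
    templated_text: String,
}

fn main() {
    let resp = ChatTokenizeResponse {
        tokenize_response: TokenizeResponse(vec![SimpleToken {
            id: 1,
            text: "<s>".to_string(),
            start: 0,
            stop: 3,
        }]),
        templated_text: "<s>...".to_string(),
    };
    // Prints: {"tokenize_response":[{"id":1,"text":"<s>","start":0,"stop":3}],"templated_text":"<s>..."}
    println!("{}", serde_json::to_string(&resp).unwrap());
}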
@@ -8,6 +8,7 @@ use crate::kserve::{
     kserve_model_metadata, kserve_model_metadata_ready,
 };
 use crate::validation::ValidationError;
+use crate::ChatTokenizeResponse;
 use crate::{
     usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
     GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
@@ -22,7 +23,7 @@ use crate::{
     CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
     VertexResponse,
 };
-use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType};
+use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools};
 use async_stream::__private::AsyncStream;
 use axum::extract::Extension;
 use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
@@ -115,6 +116,107 @@ async fn get_model_info(info: Extension<Info>) -> Json<Info> {
     Json(info.0)
 }
 
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/chat_tokenize",
+    request_body = ChatRequest,
+    responses((status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse))
+)]
+async fn get_chat_tokenize(
+    Extension(infer): Extension<Infer>,
+    Json(req): Json<ChatRequest>,
+) -> Result<(HeaderMap, Json<ChatTokenizeResponse>), (StatusCode, Json<ErrorResponse>)> {
+    metrics::counter!("tgi_request_count").increment(1);
+
+    let ChatRequest {
+        model,
+        max_tokens,
+        messages,
+        seed,
+        stop,
+        stream,
+        tools,
+        tool_choice,
+        tool_prompt,
+        temperature,
+        response_format,
+        ..
+    } = req;
+
+    let tool_prompt = tool_prompt.unwrap_or_default();
+    let (inputs, _grammar, _tool_grammar) = prepare_chat_input(
+        &infer,
+        response_format,
+        tools,
+        tool_choice,
+        &tool_prompt,
+        messages,
+    )?;
+
+    let generate_request = GenerateRequest {
+        inputs,
+        parameters: GenerateParameters {
+            best_of: None,
+            temperature,
+            repetition_penalty: None,
+            frequency_penalty: None,
+            top_k: None,
+            top_p: None,
+            typical_p: None,
+            do_sample: true,
+            max_new_tokens: max_tokens,
+            return_full_text: None,
+            stop: stop.unwrap_or_default(),
+            truncate: None,
+            watermark: false,
+            details: false,
+            decoder_input_details: !stream,
+            seed,
+            top_n_tokens: None,
+            grammar: _grammar,
+            adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
+        },
+    };
+
+    let input = generate_request.inputs.clone();
+    let encoding = infer.tokenize(generate_request).await?;
+    if let Some(encoding) = encoding {
+        let tokens: Vec<SimpleToken> = encoding
+            .get_ids()
+            .iter()
+            .zip(encoding.get_offsets())
+            .map(|(&id, &(start, stop))| {
+                let text = input
+                    .chars()
+                    .skip(start)
+                    .take(stop - start)
+                    .collect::<String>();
+                SimpleToken {
+                    id,
+                    text,
+                    start,
+                    stop,
+                }
+            })
+            .collect();
+
+        let resp = ChatTokenizeResponse {
+            tokenize_response: TokenizeResponse(tokens),
+            templated_text: input,
+        };
+        Ok((HeaderMap::new(), Json(resp)))
+    } else {
+        Err((
+            StatusCode::NOT_FOUND,
+            Json(ErrorResponse {
+                error: "No fast tokenizer or tokenizer.json for this model".to_string(),
+                error_type: "no fast tokenizer".to_string(),
+            }),
+        ))
+    }
+}
+
 #[utoipa::path(
     get,
     tag = "Text Generation Inference",
@@ -1034,63 +1136,14 @@ async fn chat_completions(
         Some(temperature) if temperature == 0.0 => (false, None),
         other => (true, other),
     };
 
-    // response_format and tools are mutually exclusive
-    if response_format.is_some() && tools.as_ref().is_some() {
-        metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
-        return Err((
-            StatusCode::UNPROCESSABLE_ENTITY,
-            Json(ErrorResponse {
-                error: "Grammar and tools are mutually exclusive".to_string(),
-                error_type: "grammar and tools".to_string(),
-            }),
-        ));
-    }
-
-    // extract tool grammar if present
-    let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
-        Ok(grammar) => grammar,
-        Err(err) => {
-            metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
-            tracing::error!("{err}");
-            return Err((
-                StatusCode::UNPROCESSABLE_ENTITY,
-                Json(ErrorResponse {
-                    error: err.to_string(),
-                    error_type: err.error_type().to_string(),
-                }),
-            ));
-        }
-    };
-
-    // determine the appropriate arguments for apply_chat_template
-    let tools_grammar_prompt = tool_grammar
-        .as_ref()
-        .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
-
-    let (tools_grammar_prompt, grammar) = match response_format {
-        Some(response_format) => (None, Some(response_format)),
-        None => (
-            tools_grammar_prompt.clone(),
-            tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
-        ),
-    };
-
-    // apply chat template to flatten the request into a single input
-    let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
-        Ok(inputs) => inputs,
-        Err(err) => {
-            metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
-            tracing::error!("{err}");
-            return Err((
-                StatusCode::UNPROCESSABLE_ENTITY,
-                Json(ErrorResponse {
-                    error: err.to_string(),
-                    error_type: err.error_type().to_string(),
-                }),
-            ));
-        }
-    };
-
+    let (inputs, grammar, tool_grammar) = prepare_chat_input(
+        &infer,
+        response_format,
+        tools,
+        tool_choice,
+        &tool_prompt,
+        messages,
+    )?;
 
     // build the request passing some parameters
     let generate_request = GenerateRequest {
@@ -1360,8 +1413,11 @@ async fn tokenize(
         .iter()
         .zip(encoding.get_offsets())
         .map(|(&id, &(start, stop))| {
-            let text: String =
-                String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string();
+            let text = input
+                .chars()
+                .skip(start)
+                .take(stop - start)
+                .collect::<String>();
             SimpleToken {
                 id,
                 text,
@@ -2036,6 +2092,7 @@ async fn start(
     }
     let info_routes = Router::new()
         .route("/", get(health))
+        .route("/chat_tokenize", post(get_chat_tokenize))
         .route("/info", get(get_model_info))
         .route("/health", get(health))
         .route("/ping", get(health))
@@ -2332,3 +2389,36 @@ fn create_post_processor(
 
     Ok(post_processor)
 }
+
+type PreparedInput = (String, Option<GrammarType>, Option<Tools>);
+
+fn prepare_chat_input(
+    infer: &Infer,
+    response_format: Option<GrammarType>,
+    tools: Option<Vec<Tool>>,
+    tool_choice: ToolChoice,
+    tool_prompt: &str,
+    messages: Vec<Message>,
+) -> Result<PreparedInput, InferError> {
+    if response_format.is_some() && tools.is_some() {
+        return Err(InferError::ToolError(
+            "Grammar and tools are mutually exclusive".into(),
+        ));
+    }
+
+    if let Some(format) = response_format {
+        let inputs = infer.apply_chat_template(messages, None)?;
+        return Ok((inputs, Some(format), None));
+    }
+
+    // if tools are set, apply the tool grammar and then the chat template
+    let tool_grammar: Option<Tools> = ToolGrammar::apply(tools, tool_choice)?;
+    let grammar = tool_grammar
+        .as_ref()
+        .map(|t| GrammarType::Json(serde_json::json!(t)));
+    let tools_grammar_prompt = tool_grammar
+        .as_ref()
+        .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into()));
+    let inputs = infer.apply_chat_template(messages, tools_grammar_prompt)?;
+    Ok((inputs, grammar, tool_grammar))
+}
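For readers skimming the new helper, the sketch below restates its precedence rule with stub types standing in for the crate's GrammarType, Tools, and InferError, so it compiles on its own: response_format and tools together are rejected, response_format alone short-circuits, and tools alone fall through to the tool-grammar path. The real function additionally renders the chat template via infer.apply_chat_template, which this simplification leaves out.

// Simplified re-statement of prepare_chat_input's branching (illustrative only;
// stub unit types replace the crate's real GrammarType/Tools/InferError).
#[derive(Debug, PartialEq)]
enum Choice {
    ResponseFormat,
    ToolGrammar,
    Plain,
}

fn choose(response_format: Option<()>, tools: Option<()>) -> Result<Choice, String> {
    if response_format.is_some() && tools.is_some() {
        // mirrors the InferError::ToolError branch above
        return Err("Grammar and tools are mutually exclusive".to_string());
    }
    if response_format.is_some() {
        return Ok(Choice::ResponseFormat);
    }
    if tools.is_some() {
        return Ok(Choice::ToolGrammar);
    }
    Ok(Choice::Plain)
}

fn main() {
    assert_eq!(choose(Some(()), None), Ok(Choice::ResponseFormat));
    assert_eq!(choose(None, Some(())), Ok(Choice::ToolGrammar));
    assert!(choose(Some(()), Some(())).is_err());
    println!("precedence checks passed");
}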