diff --git a/router/src/lib.rs b/router/src/lib.rs index 14bb8270..386b0556 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1157,6 +1157,12 @@ pub(crate) struct GenerateResponse { pub details: Option
, } +#[derive(Serialize, ToSchema)] +pub(crate) struct ChatTokenizeResponse { + pub(crate) tokenize_response: TokenizeResponse, + pub(crate) templated_text: String, +} + #[derive(Serialize, ToSchema)] #[serde(transparent)] pub(crate) struct TokenizeResponse(Vec); diff --git a/router/src/server.rs b/router/src/server.rs index dcbaa2ad..7655182a 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -8,6 +8,7 @@ use crate::kserve::{ kserve_model_metadata, kserve_model_metadata_ready, }; use crate::validation::ValidationError; +use crate::ChatTokenizeResponse; use crate::{ usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, @@ -22,7 +23,7 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; @@ -115,6 +116,107 @@ async fn get_model_info(info: Extension) -> Json { Json(info.0) } +#[utoipa::path( + post, + tag = "Text Generation Inference", + path = "/chat_tokenize", + request_body = ChatRequest, + responses((status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse)) +)] +async fn get_chat_tokenize( + Extension(infer): Extension, + Json(req): Json, +) -> Result<(HeaderMap, Json), (StatusCode, Json)> { + metrics::counter!("tgi_request_count").increment(1); + + let ChatRequest { + model, + max_tokens, + messages, + seed, + stop, + stream, + tools, + tool_choice, + tool_prompt, + temperature, + response_format, + .. + } = req; + + let tool_prompt = tool_prompt.unwrap_or_default(); + let (inputs, _grammar, _tool_grammar) = prepare_chat_input( + &infer, + response_format, + tools, + tool_choice, + &tool_prompt, + messages, + )?; + + let generate_request = GenerateRequest { + inputs, + parameters: GenerateParameters { + best_of: None, + temperature, + repetition_penalty: None, + frequency_penalty: None, + top_k: None, + top_p: None, + typical_p: None, + do_sample: true, + max_new_tokens: max_tokens, + return_full_text: None, + stop: stop.unwrap_or_default(), + truncate: None, + watermark: false, + details: false, + decoder_input_details: !stream, + seed, + top_n_tokens: None, + grammar: _grammar, + adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from), + }, + }; + + let input = generate_request.inputs.clone(); + let encoding = infer.tokenize(generate_request).await?; + if let Some(encoding) = encoding { + let tokens: Vec = encoding + .get_ids() + .iter() + .zip(encoding.get_offsets()) + .map(|(&id, &(start, stop))| { + let text = input + .chars() + .skip(start) + .take(stop - start) + .collect::(); + SimpleToken { + id, + text, + start, + stop, + } + }) + .collect(); + + let resp = ChatTokenizeResponse { + tokenize_response: TokenizeResponse(tokens), + templated_text: input, + }; + Ok((HeaderMap::new(), Json(resp))) + } else { + Err(( + StatusCode::NOT_FOUND, + Json(ErrorResponse { + error: "No fast tokenizer or tokenizer.json for this model".to_string(), + error_type: "no fast tokenizer".to_string(), + }), + )) + } +} + #[utoipa::path( get, tag = "Text Generation Inference", @@ -1034,63 +1136,14 @@ async fn chat_completions( Some(temperature) if temperature == 0.0 => (false, None), other => (true, other), }; - - // response_format and tools are mutually exclusive - if response_format.is_some() && tools.as_ref().is_some() { - metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); - return Err(( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: "Grammar and tools are mutually exclusive".to_string(), - error_type: "grammar and tools".to_string(), - }), - )); - } - - // extract tool grammar if present - let tool_grammar = match ToolGrammar::apply(tools, tool_choice) { - Ok(grammar) => grammar, - Err(err) => { - metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); - tracing::error!("{err}"); - return Err(( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: err.to_string(), - error_type: err.error_type().to_string(), - }), - )); - } - }; - - // determine the appropriate arguments for apply_chat_template - let tools_grammar_prompt = tool_grammar - .as_ref() - .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt)); - - let (tools_grammar_prompt, grammar) = match response_format { - Some(response_format) => (None, Some(response_format)), - None => ( - tools_grammar_prompt.clone(), - tools_grammar_prompt.map(|(grammar, _)| grammar.clone()), - ), - }; - - // apply chat template to flatten the request into a single input - let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) { - Ok(inputs) => inputs, - Err(err) => { - metrics::counter!("tgi_request_failure", "err" => "validation").increment(1); - tracing::error!("{err}"); - return Err(( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: err.to_string(), - error_type: err.error_type().to_string(), - }), - )); - } - }; + let (inputs, grammar, tool_grammar) = prepare_chat_input( + &infer, + response_format, + tools, + tool_choice, + &tool_prompt, + messages, + )?; // build the request passing some parameters let generate_request = GenerateRequest { @@ -1360,8 +1413,11 @@ async fn tokenize( .iter() .zip(encoding.get_offsets()) .map(|(&id, &(start, stop))| { - let text: String = - String::from_utf8_lossy(&input.as_bytes()[start..stop]).to_string(); + let text = input + .chars() + .skip(start) + .take(stop - start) + .collect::(); SimpleToken { id, text, @@ -2036,6 +2092,7 @@ async fn start( } let info_routes = Router::new() .route("/", get(health)) + .route("/chat_tokenize", post(get_chat_tokenize)) .route("/info", get(get_model_info)) .route("/health", get(health)) .route("/ping", get(health)) @@ -2332,3 +2389,36 @@ fn create_post_processor( Ok(post_processor) } + +type PreparedInput = (String, Option, Option); + +fn prepare_chat_input( + infer: &Infer, + response_format: Option, + tools: Option>, + tool_choice: ToolChoice, + tool_prompt: &str, + messages: Vec, +) -> Result { + if response_format.is_some() && tools.is_some() { + return Err(InferError::ToolError( + "Grammar and tools are mutually exclusive".into(), + )); + } + + if let Some(format) = response_format { + let inputs = infer.apply_chat_template(messages, None)?; + return Ok((inputs, Some(format), None)); + } + + // if tools are set, apply the tool grammar and then the chat template + let tool_grammar: Option = ToolGrammar::apply(tools, tool_choice)?; + let grammar = tool_grammar + .as_ref() + .map(|t| GrammarType::Json(serde_json::json!(t))); + let tools_grammar_prompt = tool_grammar + .as_ref() + .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into())); + let inputs = infer.apply_chat_template(messages, tools_grammar_prompt)?; + Ok((inputs, grammar, tool_grammar)) +}