From 72ed3036fc6546d1e902224c19c63bd6c3390141 Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 7 Nov 2024 18:24:58 -0400
Subject: [PATCH] feat: support continue_final_message param in chat request

---
 router/src/infer/chat_template.rs | 34 +++++++++++++++++++++++++------
 router/src/infer/mod.rs           |  8 +++++++-
 router/src/lib.rs                 |  7 +++++++
 router/src/server.rs              | 10 +++++++--
 4 files changed, 50 insertions(+), 9 deletions(-)

diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs
index 1071d0ba..e680d600 100644
--- a/router/src/infer/chat_template.rs
+++ b/router/src/infer/chat_template.rs
@@ -54,6 +54,7 @@ impl ChatTemplate {
     pub(crate) fn apply(
         &self,
         guideline: Option<&str>,
+        continue_final_message: bool,
         mut messages: Vec<Message>,
         tools_and_prompt: Option<(Vec<Tool>, String)>,
     ) -> Result<String, InferError> {
@@ -84,8 +85,9 @@ impl ChatTemplate {
         };
 
         let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
-
-        self.template
+        let final_message_content = messages.last().map(|m| m.content.clone());
+        let mut rendered_template = self
+            .template
             .render(ChatTemplateInputs {
                 guideline,
                 messages,
@@ -94,7 +96,24 @@ impl ChatTemplate {
                 add_generation_prompt: true,
                 tools,
             })
-            .map_err(InferError::TemplateError)
+            .map_err(InferError::TemplateError)?;
+
+        if continue_final_message {
+            // find the last occurrence of the final message in the rendered chat
+            if let Some(final_message) = final_message_content {
+                rendered_template = if let Some(index) = rendered_template.rfind(&final_message) {
+                    // implementation based on feature in transformers pipeline
+                    // https://github.com/huggingface/transformers/blob/1cf17077bf2d4affed31387c0943251a4ba8fab7/src/transformers/pipelines/text_generation.py#L418
+                    rendered_template[..index + final_message.len()]
+                        .trim_end()
+                        .to_string()
+                } else {
+                    rendered_template
+                };
+            }
+        }
+
+        Ok(rendered_template)
     }
 }
 
@@ -824,8 +843,9 @@ mod tests {
                 content: MessageContent::SingleText("Hello, how are you?".to_string()),
             },
         ];
+        let continue_final_message = false;
 
-        let result = ct.apply(None, msgs, None);
+        let result = ct.apply(None, continue_final_message, msgs, None);
 
         match result {
             Ok(_) => panic!("Should have failed since no guideline is provided"),
@@ -865,9 +885,10 @@
         ];
         let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
         let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
+        let continue_final_message = false;
         let tool_prompt = "This default prompt will be used".to_string();
         let tools_and_prompt = Some((tools, tool_prompt));
-        let result = ct.apply(None, msgs, tools_and_prompt);
+        let result = ct.apply(None, continue_final_message, msgs, tools_and_prompt);
         let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
         assert_eq!(result.unwrap(), expected);
     }
@@ -899,9 +920,10 @@
         ];
         let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
         let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
+        let continue_final_message = false;
         let tool_prompt = "This default prompt will be used".to_string();
         let tools_and_prompt = Some((tools, tool_prompt));
-        let result = ct.apply(None, msgs, tools_and_prompt);
+        let result = ct.apply(None, continue_final_message, msgs, tools_and_prompt);
         let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
         assert_eq!(result.unwrap(), expected);
     }
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 557e03cb..ffff41bd 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -158,13 +158,19 @@ impl Infer {
     pub(crate) fn apply_chat_template(
         &self,
         guideline: Option<String>,
+        continue_final_message: bool,
         messages: Vec<Message>,
         tools_and_prompt: Option<(Vec<Tool>, String)>,
     ) -> Result<String, InferError> {
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(guideline.as_deref(), messages, tools_and_prompt)
+            .apply(
+                guideline.as_deref(),
+                continue_final_message,
+                messages,
+                tools_and_prompt,
+            )
             .map_err(|e| {
                 metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
                 tracing::error!("{e}");
diff --git a/router/src/lib.rs b/router/src/lib.rs
index d9cacb91..5d47cd3e 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -917,6 +917,11 @@ pub(crate) struct ChatRequest {
     #[serde(default)]
     #[schema(nullable = true, example = "null")]
     pub stream_options: Option<StreamOptions>,
+
+    /// Whether to continue the final message in the next request.
+    #[serde(default)]
+    #[schema(default = "false", example = true)]
+    pub continue_final_message: bool,
 }
 
 impl ChatRequest {
@@ -938,6 +943,7 @@ impl ChatRequest {
             frequency_penalty,
             top_p,
             top_logprobs,
+            continue_final_message,
             ..
         } = self;
 
@@ -960,6 +966,7 @@ impl ChatRequest {
             &tool_prompt,
             guideline,
             messages,
+            continue_final_message,
         )?;
 
         Ok((
diff --git a/router/src/server.rs b/router/src/server.rs
index 2058bce3..d822031e 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -2525,6 +2525,7 @@ pub enum WebServerError {
 
 type PreparedInput = (String, Option<GrammarType>, bool);
 
+#[allow(clippy::too_many_arguments)]
 pub(crate) fn prepare_chat_input(
     infer: &Infer,
     response_format: Option<GrammarType>,
@@ -2533,6 +2534,7 @@ pub(crate) fn prepare_chat_input(
     tool_prompt: &str,
     guideline: Option<String>,
     messages: Vec<Message>,
+    continue_final_message: bool,
 ) -> Result<PreparedInput, InferError> {
     if response_format.is_some() && tools.is_some() {
         return Err(InferError::ToolError(
@@ -2542,7 +2544,8 @@
 
     // when response_format is set, tools are not included when applying the chat template to generate inputs
     if let Some(format) = response_format {
-        let inputs = infer.apply_chat_template(guideline, messages, None)?;
+        let inputs =
+            infer.apply_chat_template(guideline, continue_final_message, messages, None)?;
         return Ok((inputs, Some(format), false));
     }
 
@@ -2557,6 +2560,7 @@
 
         let inputs: String = infer.apply_chat_template(
             guideline,
+            continue_final_message,
             messages,
             Some((updated_tools, tool_prompt.into())),
         )?;
@@ -2564,7 +2568,7 @@
     }
 
     // if no response_format or tools are set simply apply the chat template to generate inputs
-    let inputs = infer.apply_chat_template(guideline, messages, None)?;
+    let inputs = infer.apply_chat_template(guideline, continue_final_message, messages, None)?;
     Ok((inputs, None, false))
 }
 
@@ -2662,6 +2666,7 @@ mod tests {
                 "What is the weather like in New York?".to_string(),
             ),
         }];
+        let continue_final_message = false;
 
         let result = prepare_chat_input(
             &infer,
@@ -2671,6 +2676,7 @@
             tool_prompt,
             guideline,
             messages,
+            continue_final_message,
         );
 
         assert!(result.is_ok());
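
For intuition, the trimming step added to chat_template.rs can be reproduced in isolation: after rendering, the prompt is cut right after the last occurrence of the final message, which drops the end-of-turn suffix the template appends, so generation resumes the assistant's text instead of opening a new turn. A minimal standalone sketch of that logic follows; the helper name and the sample strings are illustrative, not part of the patch:

    // Illustrative stand-in for the continue_final_message branch above.
    fn continue_final(rendered: &str, final_message: &str) -> String {
        match rendered.rfind(final_message) {
            // keep everything up to and including the final message,
            // dropping whatever end-of-turn tokens the template added after it
            Some(index) => rendered[..index + final_message.len()]
                .trim_end()
                .to_string(),
            // if the final message cannot be located, leave the rendering as-is
            None => rendered.to_string(),
        }
    }

    fn main() {
        let rendered = "[INST] Write a haiku [/INST]Autumn leaves falling</s>";
        // the trailing </s> is trimmed away, so the model continues the haiku
        assert_eq!(
            continue_final(rendered, "Autumn leaves falling"),
            "[INST] Write a haiku [/INST]Autumn leaves falling"
        );
    }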
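
On the wire, continue_final_message is just another boolean field on ChatRequest, defaulting to false. A hypothetical request body exercising it might look like the sketch below (built with serde_json; the model name and messages are made up, and the final assistant message is the one the server will extend):

    use serde_json::json;

    fn main() {
        // Hypothetical payload for the OpenAI-style chat endpoint; with
        // continue_final_message set, the rendered prompt is trimmed so the
        // model continues the last assistant message instead of answering it.
        let body = json!({
            "model": "tgi",
            "continue_final_message": true,
            "messages": [
                { "role": "user", "content": "Finish this sentence for me." },
                { "role": "assistant", "content": "The quick brown fox" }
            ]
        });
        println!("{}", serde_json::to_string_pretty(&body).unwrap());
    }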