feat: support continue_final_message param in chat request
parent 97f7a22f0b
commit 72ed3036fc
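
Before the diff, a sketch of the intended client-side usage. This is illustrative only: the payload shape follows TGI's existing OpenAI-compatible chat schema, and `continue_final_message` is the one field this commit introduces (absent means `false`).

```rust
use serde_json::json;

fn main() {
    // End the conversation with a partial assistant message and ask the
    // server to continue it instead of opening a new assistant turn.
    let body = json!({
        // Assumed OpenAI-style fields; only continue_final_message is new here.
        "model": "tgi",
        "messages": [
            { "role": "user", "content": "Write a haiku about autumn." },
            { "role": "assistant", "content": "Crisp leaves underfoot" }
        ],
        // New in this commit; defaults to false when omitted.
        "continue_final_message": true
    });
    // POST `body` to the server's chat completions route.
    println!("{body}");
}
```
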
@@ -54,6 +54,7 @@ impl ChatTemplate {
     pub(crate) fn apply(
         &self,
         guideline: Option<&str>,
+        continue_final_message: bool,
         mut messages: Vec<Message>,
         tools_and_prompt: Option<(Vec<Tool>, String)>,
     ) -> Result<String, InferError> {
@@ -84,8 +85,9 @@ impl ChatTemplate {
         };

         let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
-
-        self.template
+        let final_message_content = messages.last().map(|m| m.content.clone());
+        let mut rendered_template = self
+            .template
             .render(ChatTemplateInputs {
                 guideline,
                 messages,
@@ -94,7 +96,24 @@ impl ChatTemplate {
                 add_generation_prompt: true,
                 tools,
             })
-            .map_err(InferError::TemplateError)
+            .map_err(InferError::TemplateError)?;
+
+        if continue_final_message {
+            // find the last occurrence of the final message in the rendered chat
+            if let Some(final_message) = final_message_content {
+                rendered_template = if let Some(index) = rendered_template.rfind(&final_message) {
+                    // implementation based on feature in transformers pipeline
+                    // https://github.com/huggingface/transformers/blob/1cf17077bf2d4affed31387c0943251a4ba8fab7/src/transformers/pipelines/text_generation.py#L418
+                    rendered_template[..index + final_message.len()]
+                        .trim_end()
+                        .to_string()
+                } else {
+                    rendered_template
+                };
+            }
+        }
+
+        Ok(rendered_template)
     }
 }
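
The core of the change is the truncation above: render the full template, find the last occurrence of the final message, cut everything after it, and trim trailing whitespace. A standalone sketch of that behavior (hypothetical inputs, same `rfind` logic as the diff):

```rust
fn continue_from_final(rendered: String, final_message: &str) -> String {
    match rendered.rfind(final_message) {
        // Truncate right after the final message, dropping whatever the
        // template appended (e.g. an EOS token), then trim whitespace.
        Some(index) => rendered[..index + final_message.len()]
            .trim_end()
            .to_string(),
        // If the message is not found, leave the rendering untouched.
        None => rendered,
    }
}

fn main() {
    let rendered = "<s>[INST] Hi [/INST]Hello there</s>".to_string();
    assert_eq!(
        continue_from_final(rendered, "Hello there"),
        "<s>[INST] Hi [/INST]Hello there"
    );
}
```

Stripping the template's closing decoration is what lets the model's next tokens splice directly onto the assistant's partial message, matching the transformers pipeline behavior linked in the comment.
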
@@ -824,8 +843,9 @@ mod tests {
                 content: MessageContent::SingleText("Hello, how are you?".to_string()),
             },
         ];
+        let continue_final_message = false;

-        let result = ct.apply(None, msgs, None);
+        let result = ct.apply(None, continue_final_message, msgs, None);

         match result {
             Ok(_) => panic!("Should have failed since no guideline is provided"),
@@ -865,9 +885,10 @@ mod tests {
         ];
         let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
         let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
+        let continue_final_message = false;
         let tool_prompt = "This default prompt will be used".to_string();
         let tools_and_prompt = Some((tools, tool_prompt));
-        let result = ct.apply(None, msgs, tools_and_prompt);
+        let result = ct.apply(None, continue_final_message, msgs, tools_and_prompt);
         let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
         assert_eq!(result.unwrap(), expected);
     }
@@ -899,9 +920,10 @@ mod tests {
         ];
         let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
         let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
+        let continue_final_message = false;
         let tool_prompt = "This default prompt will be used".to_string();
         let tools_and_prompt = Some((tools, tool_prompt));
-        let result = ct.apply(None, msgs, tools_and_prompt);
+        let result = ct.apply(None, continue_final_message, msgs, tools_and_prompt);
         let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
         assert_eq!(result.unwrap(), expected);
     }

@@ -158,13 +158,19 @@ impl Infer {
     pub(crate) fn apply_chat_template(
         &self,
         guideline: Option<String>,
+        continue_final_message: bool,
         messages: Vec<Message>,
         tools_and_prompt: Option<(Vec<Tool>, String)>,
     ) -> Result<String, InferError> {
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(guideline.as_deref(), messages, tools_and_prompt)
+            .apply(
+                guideline.as_deref(),
+                continue_final_message,
+                messages,
+                tools_and_prompt,
+            )
             .map_err(|e| {
                 metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
                 tracing::error!("{e}");

@@ -917,6 +917,11 @@ pub(crate) struct ChatRequest {
     #[serde(default)]
     #[schema(nullable = true, example = "null")]
     pub stream_options: Option<StreamOptions>,
+
+    /// Whether to continue the final message in the next request.
+    #[serde(default)]
+    #[schema(default = "false", example = true)]
+    pub continue_final_message: bool,
 }

 impl ChatRequest {
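
Because the field carries `#[serde(default)]`, requests from existing clients that never send it still deserialize cleanly. A minimal check of that behavior against a stand-in struct (hypothetical; it mirrors only the relevant field):

```rust
use serde::Deserialize;

#[derive(Deserialize)]
struct ChatRequestLite {
    // Same attribute as the real field: absent in JSON -> bool::default().
    #[serde(default)]
    continue_final_message: bool,
}

fn main() {
    // Legacy payload without the field: defaults to false.
    let old: ChatRequestLite = serde_json::from_str("{}").unwrap();
    assert!(!old.continue_final_message);

    // Opting in explicitly.
    let new: ChatRequestLite =
        serde_json::from_str(r#"{"continue_final_message": true}"#).unwrap();
    assert!(new.continue_final_message);
}
```
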
@@ -938,6 +943,7 @@ impl ChatRequest {
             frequency_penalty,
             top_p,
             top_logprobs,
+            continue_final_message,
             ..
         } = self;

@@ -960,6 +966,7 @@ impl ChatRequest {
             &tool_prompt,
             guideline,
             messages,
+            continue_final_message,
         )?;

         Ok((

@@ -2525,6 +2525,7 @@ pub enum WebServerError {

 type PreparedInput = (String, Option<GrammarType>, bool);

+#[allow(clippy::too_many_arguments)]
 pub(crate) fn prepare_chat_input(
     infer: &Infer,
     response_format: Option<GrammarType>,
@@ -2533,6 +2534,7 @@ pub(crate) fn prepare_chat_input(
     tool_prompt: &str,
     guideline: Option<String>,
     messages: Vec<Message>,
+    continue_final_message: bool,
 ) -> Result<PreparedInput, InferError> {
     if response_format.is_some() && tools.is_some() {
         return Err(InferError::ToolError(
@@ -2542,7 +2544,8 @@ pub(crate) fn prepare_chat_input(

     // when response_format is set, tools are not included when applying the chat template to generate inputs
     if let Some(format) = response_format {
-        let inputs = infer.apply_chat_template(guideline, messages, None)?;
+        let inputs =
+            infer.apply_chat_template(guideline, continue_final_message, messages, None)?;
         return Ok((inputs, Some(format), false));
     }

@@ -2557,6 +2560,7 @@ pub(crate) fn prepare_chat_input(

     let inputs: String = infer.apply_chat_template(
         guideline,
+        continue_final_message,
         messages,
         Some((updated_tools, tool_prompt.into())),
     )?;
@@ -2564,7 +2568,7 @@ pub(crate) fn prepare_chat_input(
     }

     // if no response_format or tools are set simply apply the chat template to generate inputs
-    let inputs = infer.apply_chat_template(guideline, messages, None)?;
+    let inputs = infer.apply_chat_template(guideline, continue_final_message, messages, None)?;
     Ok((inputs, None, false))
 }

@@ -2662,6 +2666,7 @@ mod tests {
                 "What is the weather like in New York?".to_string(),
             ),
         }];
+        let continue_final_message = false;

         let result = prepare_chat_input(
             &infer,
@@ -2671,6 +2676,7 @@ mod tests {
             tool_prompt,
             guideline,
             messages,
+            continue_final_message,
         );

         assert!(result.is_ok());