fix: remove continue_final_message chat request param

This commit is contained in:
David Holtz 2024-11-19 21:24:18 +00:00
parent 1895d7b745
commit fa577c9be2
8 changed files with 33 additions and 62 deletions

View File

@ -984,12 +984,6 @@
"messages" "messages"
], ],
"properties": { "properties": {
"continue_final_message": {
"type": "boolean",
"description": "Whether to continue the final message in the next request.",
"default": "false",
"example": true
},
"frequency_penalty": { "frequency_penalty": {
"type": "number", "type": "number",
"format": "float", "format": "float",

View File

@ -5,19 +5,19 @@
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"message": { "message": {
"content": "Hi, I hope this is the right place for your written question. Please provide the maximum possible length to help me complete the message for you! Based", "content": "\nGenerate according to: It is an elephant's one year old baby or a mouse's one year old baby. It is",
"role": "assistant" "role": "assistant"
} }
} }
], ],
"created": 1731082056, "created": 1732050325,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "2.4.1-dev0-native", "system_fingerprint": "2.4.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 30, "completion_tokens": 30,
"prompt_tokens": 57, "prompt_tokens": 37,
"total_tokens": 87 "total_tokens": 67
} }
} }

View File

@ -5,19 +5,19 @@
"index": 0, "index": 0,
"logprobs": null, "logprobs": null,
"message": { "message": {
"content": ": Thanks for the awesome slides, they were just what we needed to produce the presentation we needed to deliver for our company's budgeting system", "content": " Shere Kohan's fantastic exploits? written by David Shimomura & illustrated by Sarah Stern\n\nTitle: Elephant",
"role": "assistant" "role": "assistant"
} }
} }
], ],
"created": 1731082129, "created": 1732050326,
"id": "", "id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "chat.completion", "object": "chat.completion",
"system_fingerprint": "2.4.1-dev0-native", "system_fingerprint": "2.4.1-dev0-native",
"usage": { "usage": {
"completion_tokens": 30, "completion_tokens": 30,
"prompt_tokens": 44, "prompt_tokens": 61,
"total_tokens": 74 "total_tokens": 91
} }
} }

View File

@ -28,13 +28,11 @@ def test_llama_completion_single_prompt(
"model": "tgi", "model": "tgi",
"messages": [ "messages": [
{"role": "system", "content": "system message"}, {"role": "system", "content": "system message"},
{"role": "user", "content": "user message"}, {"role": "user", "content": "Which is bigger an elephant or a mouse?"},
{"role": "assistant", "content": "assistant message"},
], ],
"max_tokens": 30, "max_tokens": 30,
"stream": False, "stream": False,
"seed": 1337, "seed": 1337,
"continue_final_message": False,
}, },
headers=llama_continue_final_message.headers, headers=llama_continue_final_message.headers,
stream=False, stream=False,
@ -45,7 +43,7 @@ def test_llama_completion_single_prompt(
content = response["choices"][0]["message"]["content"] content = response["choices"][0]["message"]["content"]
assert ( assert (
content content
== "Hi, I hope this is the right place for your written question. Please provide the maximum possible length to help me complete the message for you! Based" == "\nGenerate according to: It is an elephant's one year old baby or a mouse's one year old baby. It is"
) )
assert response == response_snapshot assert response == response_snapshot
@ -59,13 +57,15 @@ def test_llama_completion_single_prompt_continue(
"model": "tgi", "model": "tgi",
"messages": [ "messages": [
{"role": "system", "content": "system message"}, {"role": "system", "content": "system message"},
{"role": "user", "content": "user message"}, {"role": "user", "content": "Which is bigger an elephant or a mouse?"},
{"role": "assistant", "content": "assistant message"}, {
"role": "assistant",
"content": "the elephant, but have you heard about",
},
], ],
"max_tokens": 30, "max_tokens": 30,
"stream": False, "stream": False,
"seed": 1337, "seed": 1337,
"continue_final_message": True,
}, },
headers=llama_continue_final_message.headers, headers=llama_continue_final_message.headers,
stream=False, stream=False,
@ -76,6 +76,6 @@ def test_llama_completion_single_prompt_continue(
content = response["choices"][0]["message"]["content"] content = response["choices"][0]["message"]["content"]
assert ( assert (
content content
== ": Thanks for the awesome slides, they were just what we needed to produce the presentation we needed to deliver for our company's budgeting system" == " Shere Kohan's fantastic exploits? written by David Shimomura & illustrated by Sarah Stern\n\nTitle: Elephant"
) )
assert response == response_snapshot assert response == response_snapshot

View File

@ -54,7 +54,6 @@ impl ChatTemplate {
pub(crate) fn apply( pub(crate) fn apply(
&self, &self,
guideline: Option<&str>, guideline: Option<&str>,
continue_final_message: bool,
mut messages: Vec<Message>, mut messages: Vec<Message>,
tools_and_prompt: Option<(Vec<Tool>, String)>, tools_and_prompt: Option<(Vec<Tool>, String)>,
) -> Result<String, InferError> { ) -> Result<String, InferError> {
@ -85,7 +84,7 @@ impl ChatTemplate {
}; };
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect(); let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
let final_message_content = messages.last().map(|m| m.content.clone()); let final_message = messages.last().cloned();
let mut rendered_template = self let mut rendered_template = self
.template .template
.render(ChatTemplateInputs { .render(ChatTemplateInputs {
@ -98,20 +97,20 @@ impl ChatTemplate {
}) })
.map_err(InferError::TemplateError)?; .map_err(InferError::TemplateError)?;
if continue_final_message { // if the last message is from the assistant, continue the generation prompt
// find the last occurrence of the final message in the rendered chat rendered_template = match final_message {
if let Some(final_message) = final_message_content { Some(msg) if msg.role == "assistant" => {
rendered_template = if let Some(index) = rendered_template.rfind(&final_message) { match rendered_template.rfind(msg.content.as_str()) {
// implementation based on feature in transformers pipeline // implementation based on feature in transformers pipeline
// https://github.com/huggingface/transformers/blob/1cf17077bf2d4affed31387c0943251a4ba8fab7/src/transformers/pipelines/text_generation.py#L418 // https://github.com/huggingface/transformers/blob/1cf17077bf2d4affed31387c0943251a4ba8fab7/src/transformers/pipelines/text_generation.py#L418
rendered_template[..index + final_message.len()] Some(index) => rendered_template[..index + msg.content.len()]
.trim_end() .trim_end()
.to_string() .to_string(),
} else { None => rendered_template,
rendered_template }
};
} }
} _ => rendered_template,
};
Ok(rendered_template) Ok(rendered_template)
} }
@ -843,9 +842,8 @@ mod tests {
content: MessageContent::SingleText("Hello, how are you?".to_string()), content: MessageContent::SingleText("Hello, how are you?".to_string()),
}, },
]; ];
let continue_final_message = false;
let result = ct.apply(None, continue_final_message, msgs, None); let result = ct.apply(None, msgs, None);
match result { match result {
Ok(_) => panic!("Should have failed since no guideline is provided"), Ok(_) => panic!("Should have failed since no guideline is provided"),
@ -885,10 +883,9 @@ mod tests {
]; ];
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap(); let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
let continue_final_message = false;
let tool_prompt = "This default prompt will be used".to_string(); let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt)); let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(None, continue_final_message, msgs, tools_and_prompt); let result = ct.apply(None, msgs, tools_and_prompt);
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string(); let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);
} }
@ -920,10 +917,9 @@ mod tests {
]; ];
let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string(); let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap(); let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
let continue_final_message = false;
let tool_prompt = "This default prompt will be used".to_string(); let tool_prompt = "This default prompt will be used".to_string();
let tools_and_prompt = Some((tools, tool_prompt)); let tools_and_prompt = Some((tools, tool_prompt));
let result = ct.apply(None, continue_final_message, msgs, tools_and_prompt); let result = ct.apply(None, msgs, tools_and_prompt);
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string(); let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
assert_eq!(result.unwrap(), expected); assert_eq!(result.unwrap(), expected);
} }

View File

@ -158,19 +158,13 @@ impl Infer {
pub(crate) fn apply_chat_template( pub(crate) fn apply_chat_template(
&self, &self,
guideline: Option<String>, guideline: Option<String>,
continue_final_message: bool,
messages: Vec<Message>, messages: Vec<Message>,
tools_and_prompt: Option<(Vec<Tool>, String)>, tools_and_prompt: Option<(Vec<Tool>, String)>,
) -> Result<String, InferError> { ) -> Result<String, InferError> {
self.chat_template self.chat_template
.as_ref() .as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply( .apply(guideline.as_deref(), messages, tools_and_prompt)
guideline.as_deref(),
continue_final_message,
messages,
tools_and_prompt,
)
.map_err(|e| { .map_err(|e| {
metrics::counter!("tgi_request_failure", "err" => "template").increment(1); metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
tracing::error!("{e}"); tracing::error!("{e}");

View File

@ -917,11 +917,6 @@ pub(crate) struct ChatRequest {
#[serde(default)] #[serde(default)]
#[schema(nullable = true, example = "null")] #[schema(nullable = true, example = "null")]
pub stream_options: Option<StreamOptions>, pub stream_options: Option<StreamOptions>,
/// Whether to continue the final message in the next request.
#[serde(default)]
#[schema(default = "false", example = true)]
pub continue_final_message: bool,
} }
impl ChatRequest { impl ChatRequest {
@ -943,7 +938,6 @@ impl ChatRequest {
frequency_penalty, frequency_penalty,
top_p, top_p,
top_logprobs, top_logprobs,
continue_final_message,
.. ..
} = self; } = self;
@ -966,7 +960,6 @@ impl ChatRequest {
&tool_prompt, &tool_prompt,
guideline, guideline,
messages, messages,
continue_final_message,
)?; )?;
Ok(( Ok((

View File

@ -2525,7 +2525,6 @@ pub enum WebServerError {
type PreparedInput = (String, Option<GrammarType>, bool); type PreparedInput = (String, Option<GrammarType>, bool);
#[allow(clippy::too_many_arguments)]
pub(crate) fn prepare_chat_input( pub(crate) fn prepare_chat_input(
infer: &Infer, infer: &Infer,
response_format: Option<GrammarType>, response_format: Option<GrammarType>,
@ -2534,7 +2533,6 @@ pub(crate) fn prepare_chat_input(
tool_prompt: &str, tool_prompt: &str,
guideline: Option<String>, guideline: Option<String>,
messages: Vec<Message>, messages: Vec<Message>,
continue_final_message: bool,
) -> Result<PreparedInput, InferError> { ) -> Result<PreparedInput, InferError> {
if response_format.is_some() && tools.is_some() { if response_format.is_some() && tools.is_some() {
return Err(InferError::ToolError( return Err(InferError::ToolError(
@ -2544,8 +2542,7 @@ pub(crate) fn prepare_chat_input(
// when response_format is set, tools are not included when applying the chat template to generate inputs // when response_format is set, tools are not included when applying the chat template to generate inputs
if let Some(format) = response_format { if let Some(format) = response_format {
let inputs = let inputs = infer.apply_chat_template(guideline, messages, None)?;
infer.apply_chat_template(guideline, continue_final_message, messages, None)?;
return Ok((inputs, Some(format), false)); return Ok((inputs, Some(format), false));
} }
@ -2560,7 +2557,6 @@ pub(crate) fn prepare_chat_input(
let inputs: String = infer.apply_chat_template( let inputs: String = infer.apply_chat_template(
guideline, guideline,
continue_final_message,
messages, messages,
Some((updated_tools, tool_prompt.into())), Some((updated_tools, tool_prompt.into())),
)?; )?;
@ -2568,7 +2564,7 @@ pub(crate) fn prepare_chat_input(
} }
// if no response_format or tools are set simply apply the chat template to generate inputs // if no response_format or tools are set simply apply the chat template to generate inputs
let inputs = infer.apply_chat_template(guideline, continue_final_message, messages, None)?; let inputs = infer.apply_chat_template(guideline, messages, None)?;
Ok((inputs, None, false)) Ok((inputs, None, false))
} }
@ -2666,7 +2662,6 @@ mod tests {
"What is the weather like in New York?".to_string(), "What is the weather like in New York?".to_string(),
), ),
}]; }];
let continue_final_message = false;
let result = prepare_chat_input( let result = prepare_chat_input(
&infer, &infer,
@ -2676,7 +2671,6 @@ mod tests {
tool_prompt, tool_prompt,
guideline, guideline,
messages, messages,
continue_final_message,
); );
assert!(result.is_ok()); assert!(result.is_ok());