diff --git a/docs/openapi.json b/docs/openapi.json
index 02350a56..44691e4b 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -987,12 +987,6 @@
           "messages"
         ],
         "properties": {
-          "continue_final_message": {
-            "type": "boolean",
-            "description": "Whether to continue the final message in the next request.",
-            "default": "false",
-            "example": true
-          },
           "frequency_penalty": {
             "type": "number",
             "format": "float",
diff --git a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json
index a452399e..caa00f99 100644
--- a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json
+++ b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt.json
@@ -5,19 +5,19 @@
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": "Hi, I hope this is the right place for your written question. Please provide the maximum possible length to help me complete the message for you! Based",
+        "content": "\nGenerate according to: It is an elephant's one year old baby or a mouse's one year old baby. It is",
         "role": "assistant"
       }
     }
   ],
-  "created": 1731082056,
+  "created": 1732050325,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "chat.completion",
   "system_fingerprint": "2.4.1-dev0-native",
   "usage": {
     "completion_tokens": 30,
-    "prompt_tokens": 57,
-    "total_tokens": 87
+    "prompt_tokens": 37,
+    "total_tokens": 67
   }
 }
diff --git a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json
index 3e48bc37..f880dd74 100644
--- a/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json
+++ b/integration-tests/models/__snapshots__/test_continue_final_message/test_llama_completion_single_prompt_continue.json
@@ -5,19 +5,19 @@
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": ": Thanks for the awesome slides, they were just what we needed to produce the presentation we needed to deliver for our company's budgeting system",
+        "content": " Shere Kohan's fantastic exploits? \nwritten by David Shimomura & illustrated by Sarah Stern\n\nTitle: Elephant",
         "role": "assistant"
       }
     }
   ],
-  "created": 1731082129,
+  "created": 1732050326,
   "id": "",
   "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
   "object": "chat.completion",
   "system_fingerprint": "2.4.1-dev0-native",
   "usage": {
     "completion_tokens": 30,
-    "prompt_tokens": 44,
-    "total_tokens": 74
+    "prompt_tokens": 61,
+    "total_tokens": 91
   }
 }
diff --git a/integration-tests/models/test_continue_final_message.py b/integration-tests/models/test_continue_final_message.py
index 9a8b07ad..ea3d83c9 100644
--- a/integration-tests/models/test_continue_final_message.py
+++ b/integration-tests/models/test_continue_final_message.py
@@ -28,13 +28,11 @@ def test_llama_completion_single_prompt(
             "model": "tgi",
             "messages": [
                 {"role": "system", "content": "system message"},
-                {"role": "user", "content": "user message"},
-                {"role": "assistant", "content": "assistant message"},
+                {"role": "user", "content": "Which is bigger an elephant or a mouse?"},
             ],
             "max_tokens": 30,
             "stream": False,
             "seed": 1337,
-            "continue_final_message": False,
         },
         headers=llama_continue_final_message.headers,
         stream=False,
@@ -45,7 +43,7 @@
     content = response["choices"][0]["message"]["content"]
     assert (
         content
-        == "Hi, I hope this is the right place for your written question. Please provide the maximum possible length to help me complete the message for you! Based"
+        == "\nGenerate according to: It is an elephant's one year old baby or a mouse's one year old baby. It is"
     )
     assert response == response_snapshot

@@ -59,13 +57,15 @@ def test_llama_completion_single_prompt_continue(
             "model": "tgi",
             "messages": [
                 {"role": "system", "content": "system message"},
-                {"role": "user", "content": "user message"},
-                {"role": "assistant", "content": "assistant message"},
+                {"role": "user", "content": "Which is bigger an elephant or a mouse?"},
+                {
+                    "role": "assistant",
+                    "content": "the elephant, but have you heard about",
+                },
             ],
             "max_tokens": 30,
             "stream": False,
             "seed": 1337,
-            "continue_final_message": True,
         },
         headers=llama_continue_final_message.headers,
         stream=False,
@@ -76,6 +76,6 @@
     content = response["choices"][0]["message"]["content"]
     assert (
         content
-        == ": Thanks for the awesome slides, they were just what we needed to produce the presentation we needed to deliver for our company's budgeting system"
+        == " Shere Kohan's fantastic exploits? \nwritten by David Shimomura & illustrated by Sarah Stern\n\nTitle: Elephant"
     )
     assert response == response_snapshot
diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs
index e680d600..74f38dda 100644
--- a/router/src/infer/chat_template.rs
+++ b/router/src/infer/chat_template.rs
@@ -54,7 +54,6 @@ impl ChatTemplate {
     pub(crate) fn apply(
         &self,
         guideline: Option<&str>,
-        continue_final_message: bool,
         mut messages: Vec<Message>,
         tools_and_prompt: Option<(Vec<Tool>, String)>,
     ) -> Result<String, InferError> {
@@ -85,7 +84,7 @@
         };

         let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
-        let final_message_content = messages.last().map(|m| m.content.clone());
+        let final_message = messages.last().cloned();
         let mut rendered_template = self
             .template
             .render(ChatTemplateInputs {
@@ -98,20 +97,20 @@
             })
             .map_err(InferError::TemplateError)?;

-        if continue_final_message {
-            // find the last occurrence of the final message in the rendered chat
-            if let Some(final_message) = final_message_content {
-                rendered_template = if let Some(index) = rendered_template.rfind(&final_message) {
+        // if the last message is from the assistant, continue the generation prompt
+        rendered_template = match final_message {
+            Some(msg) if msg.role == "assistant" => {
+                match rendered_template.rfind(msg.content.as_str()) {
                     // implementation based on feature in transformers pipeline
                     // https://github.com/huggingface/transformers/blob/1cf17077bf2d4affed31387c0943251a4ba8fab7/src/transformers/pipelines/text_generation.py#L418
-                    rendered_template[..index + final_message.len()]
+                    Some(index) => rendered_template[..index + msg.content.len()]
                         .trim_end()
-                        .to_string()
-                } else {
-                    rendered_template
-                };
+                        .to_string(),
+                    None => rendered_template,
+                }
             }
-        }
+            _ => rendered_template,
+        };

         Ok(rendered_template)
     }
@@ -843,9 +842,8 @@ mod tests {
                 content: MessageContent::SingleText("Hello, how are you?".to_string()),
             },
         ];
-        let continue_final_message = false;

-        let result = ct.apply(None, continue_final_message, msgs, None);
+        let result = ct.apply(None, msgs, None);

        match result {
            Ok(_) => panic!("Should have failed since no guideline is provided"),
@@ -885,10 +883,9 @@ mod tests {
         ];
         let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
         let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
-        let continue_final_message = false;
         let tool_prompt = "This default prompt will be used".to_string();
         let tools_and_prompt = Some((tools, tool_prompt));
-        let result = ct.apply(None, continue_final_message, msgs, tools_and_prompt);
+        let result = ct.apply(None, msgs, tools_and_prompt);
         let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
         assert_eq!(result.unwrap(), expected);
     }
@@ -920,10 +917,9 @@ mod tests {
         ];
         let tools_string = r#"[{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"format": {"type": "string","enum": ["celsius", "fahrenheit"],"description": "The temperature unit to use. Infer this from the users location."}},"required": ["location", "format"]}}}]"#.to_string();
         let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
-        let continue_final_message = false;
         let tool_prompt = "This default prompt will be used".to_string();
         let tools_and_prompt = Some((tools, tool_prompt));
-        let result = ct.apply(None, continue_final_message, msgs, tools_and_prompt);
+        let result = ct.apply(None, msgs, tools_and_prompt);
         let expected = "<|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n    \"function\": {\n        \"arguments\": {\n            \"properties\": {\n                \"format\": {\n                    \"description\": \"The temperature unit to use. Infer this from the users location.\",\n                    \"enum\": [\n                        \"celsius\",\n                        \"fahrenheit\"\n                    ],\n                    \"type\": \"string\"\n                },\n                \"location\": {\n                    \"description\": \"The city and state, e.g. San Francisco, CA\",\n                    \"type\": \"string\"\n                }\n            },\n            \"required\": [\n                \"location\",\n                \"format\"\n            ],\n            \"type\": \"object\"\n        },\n        \"description\": \"Get the current weather\",\n        \"name\": \"get_current_weather\"\n    },\n    \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
         assert_eq!(result.unwrap(), expected);
     }
diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs
index 41c2ffe8..d3d6bc59 100644
--- a/router/src/infer/mod.rs
+++ b/router/src/infer/mod.rs
@@ -160,19 +160,13 @@ impl Infer {
     pub(crate) fn apply_chat_template(
         &self,
         guideline: Option<String>,
-        continue_final_message: bool,
         messages: Vec<Message>,
         tools_and_prompt: Option<(Vec<Tool>, String)>,
     ) -> Result<String, InferError> {
         self.chat_template
             .as_ref()
             .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
-            .apply(
-                guideline.as_deref(),
-                continue_final_message,
-                messages,
-                tools_and_prompt,
-            )
+            .apply(guideline.as_deref(), messages, tools_and_prompt)
             .map_err(|e| {
                 metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
                 tracing::error!("{e}");
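After this change, continuation is opt-in purely through message shape: when the final entry in the messages list has role "assistant", the router truncates the rendered prompt right after that message's content (mirroring the transformers pipeline behavior linked in chat_template.rs), so generation resumes mid-message instead of opening a fresh assistant turn. Below is a minimal client-side sketch of that behavior; the server address is an assumption for illustration, and the payload mirrors test_llama_completion_single_prompt_continue rather than anything this diff adds.

    # Minimal sketch, assuming a TGI server listening on localhost:8080.
    import requests

    response = requests.post(
        "http://localhost:8080/v1/chat/completions",
        json={
            "model": "tgi",
            "messages": [
                {"role": "system", "content": "system message"},
                {"role": "user", "content": "Which is bigger an elephant or a mouse?"},
                # Ending on an assistant message is what now triggers
                # continuation; no continue_final_message flag is needed.
                {
                    "role": "assistant",
                    "content": "the elephant, but have you heard about",
                },
            ],
            "max_tokens": 30,
            "stream": False,
            "seed": 1337,
        },
    )
    # The returned content picks up where the assistant message left off.
    print(response.json()["choices"][0]["message"]["content"])

Dropping the final assistant message (or following it with a user message) restores the default behavior: the chat template appends a new assistant generation prompt and the model starts a fresh reply.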