diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index ef4beee2..a8537818 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -94,7 +94,9 @@ impl ChatTemplate { mod tests { use crate::infer::chat_template::raise_exception; use crate::infer::ChatTemplate; - use crate::{ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken}; + use crate::{ + ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, TokenizerConfigToken, + }; use minijinja::Environment; #[test] @@ -823,4 +825,40 @@ mod tests { } } } + + #[test] + fn test_chat_template_with_default_tool_template() { + let ct = ChatTemplate::new( + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string(), + Some(TokenizerConfigToken::String("".to_string())), + Some(TokenizerConfigToken::String("".to_string())), + ); + + // convert TextMessage to Message + let msgs: Vec = vec![ + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText( + "I'd like to show off how chat templating works!".to_string(), + ), + }, + Message { + name: None, + role: "assistant".to_string(), + content: MessageContent::SingleText("Great! How can I help you today?".to_string()), + }, + Message { + name: None, + role: "user".to_string(), + content: MessageContent::SingleText("Just testing".to_string()), + }, + ]; + let tools = serde_json::json!("[]"); + let tool_prompt = "This default prompt will be used".to_string(); + let grammer_with_prompt = (GrammarType::Json(tools), tool_prompt); + let result = ct.apply(None, msgs, Some(grammer_with_prompt)); + let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\nThis default prompt will be used\n\"[]\" [/INST]".to_string(); + assert_eq!(result.unwrap(), expected); + } } diff --git a/router/src/lib.rs b/router/src/lib.rs index d7eb4475..1b2ff153 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1016,8 +1016,10 @@ impl MessageContent { pub fn push(&mut self, chunk: MessageChunk) { match self { MessageContent::SingleText(text) => { - *self = - MessageContent::MultipleChunks(vec![MessageChunk::Text { text: text.clone() }]); + *self = MessageContent::MultipleChunks(vec![ + MessageChunk::Text { text: text.clone() }, + chunk, + ]); } MessageContent::MultipleChunks(chunks) => { chunks.push(chunk); @@ -1348,6 +1350,35 @@ mod tests { ); } + #[test] + fn test_message_content_append() { + let mut content = MessageContent::SingleText("Initial text".to_string()); + let chunk = MessageChunk::Text { + text: "Additional text".to_string(), + }; + + content.push(chunk); + + match content { + MessageContent::MultipleChunks(chunks) => { + assert_eq!(chunks.len(), 2); + assert_eq!( + chunks[0], + MessageChunk::Text { + text: "Initial text".to_string() + } + ); + assert_eq!( + chunks[1], + MessageChunk::Text { + text: "Additional text".to_string() + } + ); + } + _ => panic!("Expected MultipleChunks, but got a different variant"), + } + } + #[test] fn test_chat_request() { let json = json!({