diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs
index ef4beee2..a8537818 100644
--- a/router/src/infer/chat_template.rs
+++ b/router/src/infer/chat_template.rs
@@ -94,7 +94,9 @@ impl ChatTemplate {
mod tests {
use crate::infer::chat_template::raise_exception;
use crate::infer::ChatTemplate;
- use crate::{ChatTemplateInputs, Message, MessageContent, TextMessage, TokenizerConfigToken};
+ use crate::{
+ ChatTemplateInputs, GrammarType, Message, MessageContent, TextMessage, TokenizerConfigToken,
+ };
use minijinja::Environment;
#[test]
@@ -823,4 +825,40 @@ mod tests {
}
}
}
+
+ #[test]
+ fn test_chat_template_with_default_tool_template() {
+ let ct = ChatTemplate::new(
+ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string(),
+ Some(TokenizerConfigToken::String("".to_string())),
+ Some(TokenizerConfigToken::String("".to_string())),
+ );
+
+ // convert TextMessage to Message
+ let msgs: Vec = vec![
+ Message {
+ name: None,
+ role: "user".to_string(),
+ content: MessageContent::SingleText(
+ "I'd like to show off how chat templating works!".to_string(),
+ ),
+ },
+ Message {
+ name: None,
+ role: "assistant".to_string(),
+ content: MessageContent::SingleText("Great! How can I help you today?".to_string()),
+ },
+ Message {
+ name: None,
+ role: "user".to_string(),
+ content: MessageContent::SingleText("Just testing".to_string()),
+ },
+ ];
+ let tools = serde_json::json!("[]");
+ let tool_prompt = "This default prompt will be used".to_string();
+ let grammer_with_prompt = (GrammarType::Json(tools), tool_prompt);
+ let result = ct.apply(None, msgs, Some(grammer_with_prompt));
+ let expected = "[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today? [INST] Just testing\n---\nThis default prompt will be used\n\"[]\" [/INST]".to_string();
+ assert_eq!(result.unwrap(), expected);
+ }
}
diff --git a/router/src/lib.rs b/router/src/lib.rs
index d7eb4475..1b2ff153 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -1016,8 +1016,10 @@ impl MessageContent {
pub fn push(&mut self, chunk: MessageChunk) {
match self {
MessageContent::SingleText(text) => {
- *self =
- MessageContent::MultipleChunks(vec![MessageChunk::Text { text: text.clone() }]);
+ *self = MessageContent::MultipleChunks(vec![
+ MessageChunk::Text { text: text.clone() },
+ chunk,
+ ]);
}
MessageContent::MultipleChunks(chunks) => {
chunks.push(chunk);
@@ -1348,6 +1350,35 @@ mod tests {
);
}
+ #[test]
+ fn test_message_content_append() {
+ let mut content = MessageContent::SingleText("Initial text".to_string());
+ let chunk = MessageChunk::Text {
+ text: "Additional text".to_string(),
+ };
+
+ content.push(chunk);
+
+ match content {
+ MessageContent::MultipleChunks(chunks) => {
+ assert_eq!(chunks.len(), 2);
+ assert_eq!(
+ chunks[0],
+ MessageChunk::Text {
+ text: "Initial text".to_string()
+ }
+ );
+ assert_eq!(
+ chunks[1],
+ MessageChunk::Text {
+ text: "Additional text".to_string()
+ }
+ );
+ }
+ _ => panic!("Expected MultipleChunks, but got a different variant"),
+ }
+ }
+
#[test]
fn test_chat_request() {
let json = json!({