diff --git a/router/src/infer.rs b/router/src/infer.rs index 5f078ba0..4da0da0a 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -198,6 +198,7 @@ impl Infer { messages, eos_token: eos_token.as_deref(), bos_token: bos_token.as_deref(), + add_generation_prompt: true, }) .map_err(|e| { metrics::increment_counter!("tgi_request_failure", "err" => "template"); @@ -806,21 +807,14 @@ mod tests { ], bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), + add_generation_prompt: true, }; let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); assert_eq!( result, - r#"### User: -Hi! - -### Assistant: -Hello how can I help?### User: -What is Deep Learning? - -### Assistant: -magic!"# + "### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\nmagic!### Assistant:\n" ); } @@ -878,6 +872,7 @@ magic!"# ], bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), + add_generation_prompt: true, }; let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap(); @@ -943,9 +938,60 @@ magic!"# ], bos_token: Some("[BOS]"), eos_token: Some("[EOS]"), + add_generation_prompt: true, }; let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]"); } + + #[test] + fn test_chat_template_valid_with_add_generation_prompt() { + let mut env = Environment::new(); + env.add_function("raise_exception", raise_exception); + + let source = r#" + {% for message in messages %} + {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}} + {% endfor %} + {% if add_generation_prompt %} + {{ '<|im_start|>assistant\n' }} + {% endif %}"#; + + // trim all the whitespace + let source = source + .lines() + .map(|line| line.trim()) + .collect::>() + .join(""); + + let tmpl = env.template_from_str(&source); + + let chat_template_inputs = ChatTemplateInputs { + messages: vec![ + Message { + role: "user".to_string(), + content: "Hi!".to_string(), + }, + Message { + role: "assistant".to_string(), + content: "Hello how can I help?".to_string(), + }, + Message { + role: "user".to_string(), + content: "What is Deep Learning?".to_string(), + }, + Message { + role: "assistant".to_string(), + content: "magic!".to_string(), + }, + ], + bos_token: Some("[BOS]"), + eos_token: Some("[EOS]"), + add_generation_prompt: true, + }; + + let result = tmpl.unwrap().render(chat_template_inputs).unwrap(); + assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n"); + } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 07360e78..e85519cc 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -398,6 +398,7 @@ pub(crate) struct ChatTemplateInputs<'a> { messages: Vec, bos_token: Option<&'a str>, eos_token: Option<&'a str>, + add_generation_prompt: bool, } #[derive(Clone, Deserialize, ToSchema, Serialize)]