feat: use existing add_generation_prompt variable from config in temp… (#1533)

This PR adds support for reading `add_generation_prompt` from the config
and using it in the chat template. If `add_generation_prompt` is not
present in the config, we default to `false`.
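As a minimal sketch of that fallback behavior (the struct and field layout below are illustrative assumptions, not necessarily the router's exact types), serde's `#[serde(default)]` yields `false` whenever the key is missing from the JSON config:

use serde::Deserialize;

// Illustrative tokenizer-config struct; the router's real type may differ.
#[allow(dead_code)]
#[derive(Deserialize)]
struct TokenizerConfig {
    chat_template: Option<String>,
    // `bool::default()` is `false`, so a config without this key
    // deserializes as `add_generation_prompt: false`.
    #[serde(default)]
    add_generation_prompt: bool,
}

fn main() {
    // Key absent: falls back to false.
    let cfg: TokenizerConfig =
        serde_json::from_str(r#"{ "chat_template": "{{ messages }}" }"#).unwrap();
    assert!(!cfg.add_generation_prompt);

    // Key present: honored as written.
    let cfg: TokenizerConfig = serde_json::from_str(
        r#"{ "chat_template": "{{ messages }}", "add_generation_prompt": true }"#,
    )
    .unwrap();
    assert!(cfg.add_generation_prompt);
}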
drbh 2024-02-07 03:35:53 -05:00 committed by GitHub
parent 0da00be52c
commit 1734540211
2 changed files with 56 additions and 9 deletions


@@ -198,6 +198,7 @@ impl Infer {
messages,
eos_token: eos_token.as_deref(),
bos_token: bos_token.as_deref(),
add_generation_prompt: true,
})
.map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template");
@@ -806,21 +807,14 @@ mod tests {
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
};
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!(
result,
r#"### User:
Hi!
### Assistant:
Hello how can I help?### User:
What is Deep Learning?
### Assistant:
magic!"#
"### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\nmagic!### Assistant:\n"
);
}
@@ -878,6 +872,7 @@ magic!"#
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
};
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
@@ -943,9 +938,60 @@ magic!"#
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
};
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
}
#[test]
fn test_chat_template_valid_with_add_generation_prompt() {
let mut env = Environment::new();
env.add_function("raise_exception", raise_exception);
let source = r#"
{% for message in messages %}
{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
{% endfor %}
{% if add_generation_prompt %}
{{ '<|im_start|>assistant\n' }}
{% endif %}"#;
// trim all the whitespace
let source = source
.lines()
.map(|line| line.trim())
.collect::<Vec<&str>>()
.join("");
let tmpl = env.template_from_str(&source);
let chat_template_inputs = ChatTemplateInputs {
messages: vec![
Message {
role: "user".to_string(),
content: "Hi!".to_string(),
},
Message {
role: "assistant".to_string(),
content: "Hello how can I help?".to_string(),
},
Message {
role: "user".to_string(),
content: "What is Deep Learning?".to_string(),
},
Message {
role: "assistant".to_string(),
content: "magic!".to_string(),
},
],
bos_token: Some("[BOS]"),
eos_token: Some("[EOS]"),
add_generation_prompt: true,
};
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n");
}
}
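A side note on what the new flag changes at render time: with `add_generation_prompt` false, this ChatML-style template stops after the last message; with it true (as asserted in the new test above), the render gains a trailing `<|im_start|>assistant\n` that cues the model to reply. A small self-contained sketch of the false case (a hypothetical illustration, not part of the diff):

use minijinja::{context, Environment};

fn main() {
    let env = Environment::new();
    // The same ChatML-style template as in the new test, pre-trimmed to one
    // line; backslash continuations keep the Rust string free of stray spaces.
    let source = "{% for message in messages %}\
                  {{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}\
                  {% endfor %}\
                  {% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}";
    let tmpl = env.template_from_str(source).unwrap();

    // Flag off: the rendered prompt ends after the last turn.
    let result = tmpl
        .render(context! {
            messages => vec![context! { role => "user", content => "Hi!" }],
            add_generation_prompt => false,
        })
        .unwrap();
    assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n");
    // Flipping the flag to true would append "<|im_start|>assistant\n",
    // exactly the suffix asserted in
    // test_chat_template_valid_with_add_generation_prompt.
}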


@@ -398,6 +398,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
messages: Vec<Message>,
bos_token: Option<&'a str>,
eos_token: Option<&'a str>,
add_generation_prompt: bool,
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]