feat: use existing add_generation_prompt variable from config in temp… (#1533)
This PR adds support to read the `add_generation_prompt` from the config and use it in the chat template. If `add_generation_prompt` does not exist we default to false
This commit is contained in:
parent
0da00be52c
commit
1734540211
|
@ -198,6 +198,7 @@ impl Infer {
|
|||
messages,
|
||||
eos_token: eos_token.as_deref(),
|
||||
bos_token: bos_token.as_deref(),
|
||||
add_generation_prompt: true,
|
||||
})
|
||||
.map_err(|e| {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
||||
|
@ -806,21 +807,14 @@ mod tests {
|
|||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
eos_token: Some("[EOS]"),
|
||||
add_generation_prompt: true,
|
||||
};
|
||||
|
||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
result,
|
||||
r#"### User:
|
||||
Hi!
|
||||
|
||||
### Assistant:
|
||||
Hello how can I help?### User:
|
||||
What is Deep Learning?
|
||||
|
||||
### Assistant:
|
||||
magic!"#
|
||||
"### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\nmagic!### Assistant:\n"
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -878,6 +872,7 @@ magic!"#
|
|||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
eos_token: Some("[EOS]"),
|
||||
add_generation_prompt: true,
|
||||
};
|
||||
|
||||
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
|
||||
|
@ -943,9 +938,60 @@ magic!"#
|
|||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
eos_token: Some("[EOS]"),
|
||||
add_generation_prompt: true,
|
||||
};
|
||||
|
||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||
assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chat_template_valid_with_add_generation_prompt() {
|
||||
let mut env = Environment::new();
|
||||
env.add_function("raise_exception", raise_exception);
|
||||
|
||||
let source = r#"
|
||||
{% for message in messages %}
|
||||
{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
|
||||
{% endfor %}
|
||||
{% if add_generation_prompt %}
|
||||
{{ '<|im_start|>assistant\n' }}
|
||||
{% endif %}"#;
|
||||
|
||||
// trim all the whitespace
|
||||
let source = source
|
||||
.lines()
|
||||
.map(|line| line.trim())
|
||||
.collect::<Vec<&str>>()
|
||||
.join("");
|
||||
|
||||
let tmpl = env.template_from_str(&source);
|
||||
|
||||
let chat_template_inputs = ChatTemplateInputs {
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "Hi!".to_string(),
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "Hello how can I help?".to_string(),
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "What is Deep Learning?".to_string(),
|
||||
},
|
||||
Message {
|
||||
role: "assistant".to_string(),
|
||||
content: "magic!".to_string(),
|
||||
},
|
||||
],
|
||||
bos_token: Some("[BOS]"),
|
||||
eos_token: Some("[EOS]"),
|
||||
add_generation_prompt: true,
|
||||
};
|
||||
|
||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||
assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -398,6 +398,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
|||
messages: Vec<Message>,
|
||||
bos_token: Option<&'a str>,
|
||||
eos_token: Option<&'a str>,
|
||||
add_generation_prompt: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||
|
|
Loading…
Reference in New Issue