feat: use existing add_generation_prompt variable from config in template (#1533)
This PR adds support for reading `add_generation_prompt` from the config and using it in the chat template. If `add_generation_prompt` does not exist, we default to false.
This commit is contained in:
parent
0da00be52c
commit
1734540211
|
@ -198,6 +198,7 @@ impl Infer {
|
||||||
messages,
|
messages,
|
||||||
eos_token: eos_token.as_deref(),
|
eos_token: eos_token.as_deref(),
|
||||||
bos_token: bos_token.as_deref(),
|
bos_token: bos_token.as_deref(),
|
||||||
|
add_generation_prompt: true,
|
||||||
})
|
})
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
metrics::increment_counter!("tgi_request_failure", "err" => "template");
|
||||||
|
@ -806,21 +807,14 @@ mod tests {
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
eos_token: Some("[EOS]"),
|
eos_token: Some("[EOS]"),
|
||||||
|
add_generation_prompt: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
result,
|
result,
|
||||||
r#"### User:
|
"### User:\nHi!\n\n### Assistant:\nHello how can I help?### User:\nWhat is Deep Learning?\n\n### Assistant:\nmagic!### Assistant:\n"
|
||||||
Hi!
|
|
||||||
|
|
||||||
### Assistant:
|
|
||||||
Hello how can I help?### User:
|
|
||||||
What is Deep Learning?
|
|
||||||
|
|
||||||
### Assistant:
|
|
||||||
magic!"#
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -878,6 +872,7 @@ magic!"#
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
eos_token: Some("[EOS]"),
|
eos_token: Some("[EOS]"),
|
||||||
|
add_generation_prompt: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
|
let result = tmpl.unwrap().render(chat_template_inputs); //.err().unwrap();
|
||||||
|
@ -943,9 +938,60 @@ magic!"#
|
||||||
],
|
],
|
||||||
bos_token: Some("[BOS]"),
|
bos_token: Some("[BOS]"),
|
||||||
eos_token: Some("[EOS]"),
|
eos_token: Some("[EOS]"),
|
||||||
|
add_generation_prompt: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||||
assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
|
assert_eq!(result, "[BOS][INST] Hi! [/INST]Hello how can I help?[EOS][INST] What is Deep Learning? [/INST]magic![EOS]");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_template_valid_with_add_generation_prompt() {
|
||||||
|
let mut env = Environment::new();
|
||||||
|
env.add_function("raise_exception", raise_exception);
|
||||||
|
|
||||||
|
let source = r#"
|
||||||
|
{% for message in messages %}
|
||||||
|
{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
|
||||||
|
{% endfor %}
|
||||||
|
{% if add_generation_prompt %}
|
||||||
|
{{ '<|im_start|>assistant\n' }}
|
||||||
|
{% endif %}"#;
|
||||||
|
|
||||||
|
// trim all the whitespace
|
||||||
|
let source = source
|
||||||
|
.lines()
|
||||||
|
.map(|line| line.trim())
|
||||||
|
.collect::<Vec<&str>>()
|
||||||
|
.join("");
|
||||||
|
|
||||||
|
let tmpl = env.template_from_str(&source);
|
||||||
|
|
||||||
|
let chat_template_inputs = ChatTemplateInputs {
|
||||||
|
messages: vec![
|
||||||
|
Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "Hi!".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: "Hello how can I help?".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "What is Deep Learning?".to_string(),
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: "magic!".to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
bos_token: Some("[BOS]"),
|
||||||
|
eos_token: Some("[EOS]"),
|
||||||
|
add_generation_prompt: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = tmpl.unwrap().render(chat_template_inputs).unwrap();
|
||||||
|
assert_eq!(result, "<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\nHello how can I help?<|im_end|>\n<|im_start|>user\nWhat is Deep Learning?<|im_end|>\n<|im_start|>assistant\nmagic!<|im_end|>\n<|im_start|>assistant\n");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -398,6 +398,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
|
||||||
messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
bos_token: Option<&'a str>,
|
bos_token: Option<&'a str>,
|
||||||
eos_token: Option<&'a str>,
|
eos_token: Option<&'a str>,
|
||||||
|
add_generation_prompt: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
|
|
Loading…
Reference in New Issue