fix: remove accidentally included guideline from rebase
This commit is contained in:
parent
4069955e44
commit
8770b39c20
|
@ -2,7 +2,6 @@ use crate::infer::InferError;
|
|||
use crate::{ChatTemplateInputs, Message, MessageChunk, TextMessage, TokenizerConfigToken, Tool};
|
||||
use minijinja::{Environment, ErrorKind, Template};
|
||||
use minijinja_contrib::pycompat;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Raise a exception (custom function) used in the chat templates
|
||||
pub(crate) fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
|
||||
|
@ -15,7 +14,6 @@ pub(crate) struct ChatTemplate {
|
|||
bos_token: Option<String>,
|
||||
eos_token: Option<String>,
|
||||
use_default_tool_template: bool,
|
||||
variables: HashSet<String>,
|
||||
}
|
||||
|
||||
impl ChatTemplate {
|
||||
|
@ -47,21 +45,14 @@ impl ChatTemplate {
|
|||
bos_token: bos_token.map(|token| token.as_str().to_string()),
|
||||
eos_token: eos_token.map(|token| token.as_str().to_string()),
|
||||
use_default_tool_template,
|
||||
variables,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn apply(
|
||||
&self,
|
||||
guideline: Option<&str>,
|
||||
mut messages: Vec<Message>,
|
||||
tools_and_prompt: Option<(Vec<Tool>, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
// check if guideline is expected but not provided
|
||||
if self.variables.contains("guideline") && guideline.is_none() {
|
||||
return Err(InferError::MissingTemplateVariable("guideline".to_string()));
|
||||
}
|
||||
|
||||
let tools = match tools_and_prompt {
|
||||
Some((tools, tool_prompt)) => {
|
||||
// check if the `tools` variable is used in the template
|
||||
|
@ -88,7 +79,6 @@ impl ChatTemplate {
|
|||
let mut rendered_template = self
|
||||
.template
|
||||
.render(ChatTemplateInputs {
|
||||
guideline,
|
||||
messages,
|
||||
bos_token: self.bos_token.as_deref(),
|
||||
eos_token: self.eos_token.as_deref(),
|
||||
|
@ -782,7 +772,6 @@ mod tests {
|
|||
add_generation_prompt: false,
|
||||
bos_token: Some("<s>"),
|
||||
eos_token: Some("</s>"),
|
||||
guideline: Some("Do not use offensive language."),
|
||||
..Default::default()
|
||||
},
|
||||
target: "<s>You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\n\n<start_of_turn>\nHuman Question: I'd like to show off how chat templating works!\n<end_of_turn>\n\nOur safety principle is defined in the below:\n\n* Do not use offensive language.\n\n===\n\nDoes the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\n\n",
|
||||
|
@ -843,7 +832,7 @@ mod tests {
|
|||
},
|
||||
];
|
||||
|
||||
let result = ct.apply(None, msgs, None);
|
||||
let result = ct.apply(msgs, None);
|
||||
|
||||
match result {
|
||||
Ok(_) => panic!("Should have failed since no guideline is provided"),
|
||||
|
@ -885,7 +874,7 @@ mod tests {
|
|||
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
|
||||
let tool_prompt = "This default prompt will be used".to_string();
|
||||
let tools_and_prompt = Some((tools, tool_prompt));
|
||||
let result = ct.apply(None, msgs, tools_and_prompt);
|
||||
let result = ct.apply(msgs, tools_and_prompt);
|
||||
let expected = "<s>[INST] I'd like to show off how chat templating works! [/INST]Great! How can I help you today?</s> [INST] Just testing\n---\n[{\"type\":\"function\",\"function\":{\"description\":\"Get the current weather\",\"name\":\"get_current_weather\",\"arguments\":{\"properties\":{\"format\":{\"description\":\"The temperature unit to use. Infer this from the users location.\",\"enum\":[\"celsius\",\"fahrenheit\"],\"type\":\"string\"},\"location\":{\"description\":\"The city and state, e.g. San Francisco, CA\",\"type\":\"string\"}},\"required\":[\"location\",\"format\"],\"type\":\"object\"}}}]\nThis default prompt will be used [/INST]".to_string();
|
||||
assert_eq!(result.unwrap(), expected);
|
||||
}
|
||||
|
@ -919,7 +908,7 @@ mod tests {
|
|||
let tools: Vec<Tool> = serde_json::from_str(&tools_string).unwrap();
|
||||
let tool_prompt = "This default prompt will be used".to_string();
|
||||
let tools_and_prompt = Some((tools, tool_prompt));
|
||||
let result = ct.apply(None, msgs, tools_and_prompt);
|
||||
let result = ct.apply(msgs, tools_and_prompt);
|
||||
let expected = "<s><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYoure a helpful assistant! Answer the users question best you can.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.Do not use variables.\n\n{\n \"function\": {\n \"arguments\": {\n \"properties\": {\n \"format\": {\n \"description\": \"The temperature unit to use. Infer this from the users location.\",\n \"enum\": [\n \"celsius\",\n \"fahrenheit\"\n ],\n \"type\": \"string\"\n },\n \"location\": {\n \"description\": \"The city and state, e.g. San Francisco, CA\",\n \"type\": \"string\"\n }\n },\n \"required\": [\n \"location\",\n \"format\"\n ],\n \"type\": \"object\"\n },\n \"description\": \"Get the current weather\",\n \"name\": \"get_current_weather\"\n },\n \"type\": \"function\"\n}\n\nWhat is the weather like in Brooklyn, New York?\n---\nThis default prompt will be used<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".to_string();
|
||||
assert_eq!(result.unwrap(), expected);
|
||||
}
|
||||
|
|
|
@ -159,14 +159,13 @@ impl Infer {
|
|||
#[instrument(skip_all)]
|
||||
pub(crate) fn apply_chat_template(
|
||||
&self,
|
||||
guideline: Option<String>,
|
||||
messages: Vec<Message>,
|
||||
tools_and_prompt: Option<(Vec<Tool>, String)>,
|
||||
) -> Result<String, InferError> {
|
||||
self.chat_template
|
||||
.as_ref()
|
||||
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
|
||||
.apply(guideline.as_deref(), messages, tools_and_prompt)
|
||||
.apply(messages, tools_and_prompt)
|
||||
.map_err(|e| {
|
||||
metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
|
||||
tracing::error!("{e}");
|
||||
|
|
Loading…
Reference in New Issue