From 0d06aed02dba2cfadca5c9fb7e9183545c78f39e Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 9 Aug 2024 10:56:45 -0400 Subject: [PATCH] feat: add guideline to chat request and template (#2391) * feat: add guideline to chat request and template * fix: add template test and update docs --- docs/openapi.json | 7 +++++++ router/src/infer/chat_template.rs | 15 +++++++++++++++ router/src/infer/mod.rs | 3 ++- router/src/lib.rs | 6 ++++++ router/src/server.rs | 9 +++++++-- 5 files changed, 37 insertions(+), 3 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index ed9b0b96..ecd56e4d 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -819,6 +819,13 @@ "example": "1.0", "nullable": true }, + "guideline": { + "type": "string", + "description": "A guideline to be used in the chat_template", + "default": "null", + "example": "null", + "nullable": true + }, "logit_bias": { "type": "array", "items": { diff --git a/router/src/infer/chat_template.rs b/router/src/infer/chat_template.rs index 24a00352..7c2753ed 100644 --- a/router/src/infer/chat_template.rs +++ b/router/src/infer/chat_template.rs @@ -48,6 +48,7 @@ impl ChatTemplate { pub(crate) fn apply( &self, + guideline: Option<&str>, mut messages: Vec, grammar_with_prompt: Option<(GrammarType, String)>, ) -> Result { @@ -65,6 +66,7 @@ impl ChatTemplate { self.template .render(ChatTemplateInputs { + guideline, messages, bos_token: self.bos_token.as_deref(), eos_token: self.eos_token.as_deref(), @@ -731,6 +733,19 @@ mod tests { }, target: "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!", }, + ChatTemplateTestItem { + name: "google/shieldgemma-9b", + chat_template: "{{- bos_token }}\n{%- if messages[-1].role == 'user' %}\n {{- \"You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- else %}\n {{- \"You are a policy expert trying to help determine whether an AI response to a prompt is in violation of the defined safety policies.\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Human Question: \" + messages[-2].content }}\n {{- \"\\n\\n\" }}\n {{- \"\\n\" }}\n {{- \"Chatbot Response: \" + messages[-1].content }}\n {{- \"\\n\\n\\n\" }}\n {{- \"Our safety principle is defined in the below:\\n\\n\" }}\n {{- \"* \" + guideline + \"\\n\" }}\n {{- \"\\n===\\n\\n\" }}\n {{- \"Does the Chatbot Response violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\\n\\n\" }}\n{%- endif %}\n\n", + input: ChatTemplateInputs { + messages: example_chat_with_system.clone(), + add_generation_prompt: false, + bos_token: Some(""), + eos_token: Some(""), + guideline: Some("Do not use offensive language."), + ..Default::default() + }, + target: "You are a policy expert trying to help determine whether a user prompt is in violation of the defined safety policies.\n\n\nHuman Question: I'd like to show off how chat templating works!\n\n\nOur safety principle is defined in the below:\n\n* Do not use offensive language.\n\n===\n\nDoes the human question violate the above principle? Your answer must start with 'Yes' or 'No'. And then walk through step by step to be sure we answer correctly.\n\n", + }, ]; #[allow(unused_variables)] // name is unused diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 534a2647..58d5cf9a 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -138,13 +138,14 @@ impl Infer { #[instrument(skip_all)] pub(crate) fn apply_chat_template( &self, + guideline: Option, messages: Vec, grammar_with_prompt: Option<(GrammarType, String)>, ) -> Result { self.chat_template .as_ref() .ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))? - .apply(messages, grammar_with_prompt) + .apply(guideline.as_deref(), messages, grammar_with_prompt) .map_err(|e| { metrics::counter!("tgi_request_failure", "err" => "template").increment(1); tracing::error!("{e}"); diff --git a/router/src/lib.rs b/router/src/lib.rs index 66738706..0a15c495 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -858,6 +858,11 @@ pub(crate) struct ChatRequest { #[serde(default)] #[schema(nullable = true, default = "null", example = "null")] pub response_format: Option, + + /// A guideline to be used in the chat_template + #[serde(default)] + #[schema(nullable = true, default = "null", example = "null")] + pub guideline: Option, } fn default_tool_prompt() -> Option { @@ -965,6 +970,7 @@ pub(crate) struct ChatTemplateInputs<'a> { add_generation_prompt: bool, tools: Option<&'a str>, tools_prompt: Option<&'a str>, + guideline: Option<&'a str>, } #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)] diff --git a/router/src/server.rs b/router/src/server.rs index 1d1cd36a..8c0bd762 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -141,6 +141,7 @@ async fn get_chat_tokenize( tool_prompt, temperature, response_format, + guideline, .. } = req; @@ -151,6 +152,7 @@ async fn get_chat_tokenize( tools, tool_choice, &tool_prompt, + guideline, messages, )?; @@ -1123,6 +1125,7 @@ async fn chat_completions( tool_prompt, temperature, response_format, + guideline, .. } = req; @@ -1142,6 +1145,7 @@ async fn chat_completions( tools, tool_choice, &tool_prompt, + guideline, messages, )?; @@ -2402,6 +2406,7 @@ fn prepare_chat_input( tools: Option>, tool_choice: ToolChoice, tool_prompt: &str, + guideline: Option, messages: Vec, ) -> Result { if response_format.is_some() && tools.is_some() { @@ -2411,7 +2416,7 @@ fn prepare_chat_input( } if let Some(format) = response_format { - let inputs = infer.apply_chat_template(messages, None)?; + let inputs = infer.apply_chat_template(guideline, messages, None)?; return Ok((inputs, Some(format), None)); } @@ -2423,6 +2428,6 @@ fn prepare_chat_input( let tools_grammar_prompt = tool_grammar .as_ref() .map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt.into())); - let inputs = infer.apply_chat_template(messages, tools_grammar_prompt)?; + let inputs = infer.apply_chat_template(guideline, messages, tools_grammar_prompt)?; Ok((inputs, grammar, tool_grammar)) }