Merge commit 'refs/pull/1888/head' of github.com:huggingface/text-generation-inference into main

This commit is contained in:
drbh 2024-05-14 18:35:00 +00:00
commit 7fe123ff36
3 changed files with 30 additions and 4 deletions

View File

@ -979,24 +979,28 @@ mod tests {
content: Some("Hi!".to_string()), content: Some("Hi!".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()), content: Some("Hello how can I help?".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "user".to_string(), role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()), content: Some("What is Deep Learning?".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("magic!".to_string()), content: Some("magic!".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
@ -1049,30 +1053,35 @@ mod tests {
content: Some("Hi!".to_string()), content: Some("Hi!".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "user".to_string(), role: "user".to_string(),
content: Some("Hi again!".to_string()), content: Some("Hi again!".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()), content: Some("Hello how can I help?".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "user".to_string(), role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()), content: Some("What is Deep Learning?".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("magic!".to_string()), content: Some("magic!".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
@ -1130,24 +1139,28 @@ mod tests {
content: Some("Hi!".to_string()), content: Some("Hi!".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()), content: Some("Hello how can I help?".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "user".to_string(), role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()), content: Some("What is Deep Learning?".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("magic!".to_string()), content: Some("magic!".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
@ -1189,24 +1202,28 @@ mod tests {
content: Some("Hi!".to_string()), content: Some("Hi!".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()), content: Some("Hello how can I help?".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "user".to_string(), role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()), content: Some("What is Deep Learning?".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("magic!".to_string()), content: Some("magic!".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
@ -1234,18 +1251,21 @@ mod tests {
content: Some("Hello, how are you?".to_string()), content: Some("Hello, how are you?".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("I'm doing great. How can I help you today?".to_string()), content: Some("I'm doing great. How can I help you today?".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "user".to_string(), role: "user".to_string(),
content: Some("I'd like to show off how chat templating works!".to_string()), content: Some("I'd like to show off how chat templating works!".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
]; ];
@ -1257,6 +1277,7 @@ mod tests {
), ),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}] }]
.iter() .iter()
.chain(&example_chat) .chain(&example_chat)
@ -1401,12 +1422,14 @@ mod tests {
content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()), content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
Message { Message {
role: "user".to_string(), role: "user".to_string(),
content: Some("How many helicopters can a human eat in one sitting?".to_string()), content: Some("How many helicopters can a human eat in one sitting?".to_string()),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None,
}, },
], ],
add_generation_prompt: true, add_generation_prompt: true,

View File

@ -546,6 +546,7 @@ impl ChatCompletion {
content: output, content: output,
name: None, name: None,
tool_calls, tool_calls,
tool_call_id: None,
}, },
logprobs: return_logprobs logprobs: return_logprobs
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
@ -881,7 +882,7 @@ pub(crate) struct ChatTemplateInputs<'a> {
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
pub(crate) struct ToolCall { pub(crate) struct ToolCall {
pub id: u32, pub id: String,
pub r#type: String, pub r#type: String,
pub function: FunctionDefinition, pub function: FunctionDefinition,
} }
@ -954,13 +955,16 @@ pub(crate) struct Message {
pub role: String, pub role: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
#[schema(example = "My name is David and I")] #[schema(example = "My name is David and I")]
#[serde(deserialize_with = "message_content_serde::deserialize")] #[serde(default, deserialize_with = "message_content_serde::deserialize")]
pub content: Option<String>, pub content: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "\"David\"")] #[schema(example = "\"David\"")]
pub name: Option<String>, pub name: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>, pub tool_calls: Option<Vec<ToolCall>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "\"get_weather\"")]
pub tool_call_id: Option<String>,
} }
#[derive(Clone, Debug, Deserialize, ToSchema)] #[derive(Clone, Debug, Deserialize, ToSchema)]

View File

@ -988,7 +988,6 @@ async fn chat_completions(
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
let ChatRequest { let ChatRequest {
logprobs, logprobs,
max_tokens, max_tokens,
@ -1160,7 +1159,7 @@ async fn chat_completions(
) )
})?; })?;
let tool_calls = vec![ToolCall { let tool_calls = vec![ToolCall {
id: 0, id: "0".to_string(),
r#type: "function".to_string(), r#type: "function".to_string(),
function: FunctionDefinition { function: FunctionDefinition {
description: None, description: None,