From c99ecd77ecc079a67c176b46b61c7a2d85ac068f Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 30 Apr 2024 06:18:32 -0400 Subject: [PATCH] Handle images in chat api (#1828) This PR allows for messages to be formatted as simple strings, or as an array of objects including image urls. This is done by formatting content arrays into a simple string. Example using `llava-hf/llava-v1.6-mistral-7b-hf` ```bash curl localhost: 3000/v1/chat/completions \ -X POST \ -H 'Content-Type: application/json' \ -d '{ "model": "tgi", "messages": [ { "role": "user", "content": [ { "type": "text", "text": "Whats in this image?" }, { "type": "image_url", "image_url": { "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png" } } ] } ], "stream": false, "max_tokens": 20, "seed": 42 }' ``` is equivlant to this more simple request ```bash curl localhost: 3000/v1/chat/completions \ -X POST \ -H 'Content-Type: application/json' \ -d '{ "model": "tgi", "messages": [ { "role": "user", "content": "Whats in this image?\n![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)" } ], "stream": false, "max_tokens": 20, "seed": 42 }' ``` output ``` # {"id":"","object":"text_completion","created":1714406985,"model":"llava-hf/llava-v1.6-mistral-7b-hf","system_fingerprint":"2.0.1-native","choices":[{"index":0,"message":{"role":"assistant","content":" This is an illustration of an anthropomorphic rabbit in a spacesuit, standing on what"},"logprobs":null,"finish_reason":"length"}],"usage":{"prompt_tokens":2945,"completion_tokens":20,"total_tokens":2965}}% ``` --------- Co-authored-by: Nicolas Patry --- router/src/lib.rs | 65 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 9b9097f6..fac4c14e 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -884,12 +884,75 @@ pub(crate) struct ToolCall { pub function: FunctionDefinition, } -#[derive(Clone, Deserialize, ToSchema, Serialize)] +#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] +pub(crate) struct Text { + #[serde(default)] + pub text: String, +} + +#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] +pub(crate) struct ImageUrl { + #[serde(default)] + pub url: String, +} + +#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] +pub(crate) struct Content { + pub r#type: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub text: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub image_url: Option, +} + +mod message_content_serde { + use super::*; + use serde::de; + use serde::Deserializer; + use serde_json::Value; + + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let value = Value::deserialize(deserializer)?; + match value { + Value::String(s) => Ok(Some(s)), + Value::Array(arr) => { + let results: Result, _> = arr + .into_iter() + .map(|v| { + let content: Content = + serde_json::from_value(v).map_err(de::Error::custom)?; + match content.r#type.as_str() { + "text" => Ok(content.text.unwrap_or_default()), + "image_url" => { + if let Some(url) = content.image_url { + Ok(format!("![]({})", url.url)) + } else { + Ok(String::new()) + } + } + _ => Err(de::Error::custom("invalid content type")), + } + }) + .collect(); + + results.map(|strings| Some(strings.join(""))) + } + Value::Null => Ok(None), + _ => Err(de::Error::custom("invalid token format")), + } + } +} + +#[derive(Clone, Deserialize, ToSchema, Serialize, Debug)] pub(crate) struct Message { #[schema(example = "user")] pub role: String, #[serde(skip_serializing_if = "Option::is_none")] #[schema(example = "My name is David and I")] + #[serde(deserialize_with = "message_content_serde::deserialize")] pub content: Option, #[serde(default, skip_serializing_if = "Option::is_none")] #[schema(example = "\"David\"")]