Handle images in chat api (#1828)

This PR allows for messages to be formatted as simple strings, or as an
array of objects including image urls. This is done by formatting
content arrays into a simple string.

Example using `llava-hf/llava-v1.6-mistral-7b-hf` 

```bash
curl localhost: 3000/v1/chat/completions \
-X POST \
-H 'Content-Type: application/json' \
-d '{
    "model": "tgi",
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Whats in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
                    }
                }
            ]
        }
    ],
    "stream": false,
    "max_tokens": 20,
    "seed": 42
}'
```

is equivlant to this more simple request

```bash
curl localhost: 3000/v1/chat/completions \
-X POST \
-H 'Content-Type: application/json' \
-d '{
    "model": "tgi",
    "messages": [
        {
            "role": "user",
            "content": "Whats in this image?\n![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)"
        }
    ],
    "stream": false,
    "max_tokens": 20,
    "seed": 42
}'
```

output
```
# {"id":"","object":"text_completion","created":1714406985,"model":"llava-hf/llava-v1.6-mistral-7b-hf","system_fingerprint":"2.0.1-native","choices":[{"index":0,"message":{"role":"assistant","content":" This is an illustration of an anthropomorphic rabbit in a spacesuit, standing on what"},"logprobs":null,"finish_reason":"length"}],"usage":{"prompt_tokens":2945,"completion_tokens":20,"total_tokens":2965}}%
```

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
drbh 2024-04-30 06:18:32 -04:00 committed by GitHub
parent b2c982750a
commit c99ecd77ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 64 additions and 1 deletions

View File

@ -884,12 +884,75 @@ pub(crate) struct ToolCall {
pub function: FunctionDefinition, pub function: FunctionDefinition,
} }
#[derive(Clone, Deserialize, ToSchema, Serialize)] #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
pub(crate) struct Text {
#[serde(default)]
pub text: String,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
pub(crate) struct ImageUrl {
#[serde(default)]
pub url: String,
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)]
pub(crate) struct Content {
pub r#type: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub image_url: Option<ImageUrl>,
}
mod message_content_serde {
use super::*;
use serde::de;
use serde::Deserializer;
use serde_json::Value;
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(Some(s)),
Value::Array(arr) => {
let results: Result<Vec<String>, _> = arr
.into_iter()
.map(|v| {
let content: Content =
serde_json::from_value(v).map_err(de::Error::custom)?;
match content.r#type.as_str() {
"text" => Ok(content.text.unwrap_or_default()),
"image_url" => {
if let Some(url) = content.image_url {
Ok(format!("![]({})", url.url))
} else {
Ok(String::new())
}
}
_ => Err(de::Error::custom("invalid content type")),
}
})
.collect();
results.map(|strings| Some(strings.join("")))
}
Value::Null => Ok(None),
_ => Err(de::Error::custom("invalid token format")),
}
}
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug)]
pub(crate) struct Message { pub(crate) struct Message {
#[schema(example = "user")] #[schema(example = "user")]
pub role: String, pub role: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
#[schema(example = "My name is David and I")] #[schema(example = "My name is David and I")]
#[serde(deserialize_with = "message_content_serde::deserialize")]
pub content: Option<String>, pub content: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "\"David\"")] #[schema(example = "\"David\"")]