feat: accept list as prompt and use first string (#1702)

This PR allows the `CompletionRequest.prompt` to be sent as a string or
array of strings. When an array is sent the first value will be used if
it's a string; otherwise the according error will be thrown

Fixes:
https://github.com/huggingface/text-generation-inference/issues/1690
Similar to: https://github.com/vllm-project/vllm/pull/323/files
This commit is contained in:
drbh 2024-04-17 04:41:12 -04:00 committed by GitHub
parent e4d31a40db
commit 06c3d4b1ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 1186 additions and 105 deletions

View File

@ -59,6 +59,17 @@ class ChatCompletionComplete(BaseModel):
usage: Optional[Any] = None usage: Optional[Any] = None
class CompletionComplete(BaseModel):
# Index of the chat completion
index: int
# Message associated with the chat completion
text: str
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
class Function(BaseModel): class Function(BaseModel):
name: Optional[str] name: Optional[str]
arguments: str arguments: str
@ -104,6 +115,16 @@ class ChatComplete(BaseModel):
usage: Any usage: Any
class Completion(BaseModel):
# Completion details
id: str
object: str
created: int
model: str
system_fingerprint: str
choices: List[CompletionComplete]
class ChatRequest(BaseModel): class ChatRequest(BaseModel):
# Model identifier # Model identifier
model: str model: str

View File

@ -398,6 +398,15 @@ Options:
-e, --env -e, --env
Display a lot of information about your runtime environment Display a lot of information about your runtime environment
```
## MAX_CLIENT_BATCH_SIZE
```shell
--max-client-batch-size <MAX_CLIENT_BATCH_SIZE>
Control the maximum number of inputs that a client can send in a single request
[env: MAX_CLIENT_BATCH_SIZE=]
[default: 4]
``` ```
## HELP ## HELP
```shell ```shell

View File

@ -9,6 +9,7 @@ import json
import math import math
import time import time
import random import random
import re
from docker.errors import NotFound from docker.errors import NotFound
from typing import Optional, List, Dict from typing import Optional, List, Dict
@ -26,6 +27,7 @@ from text_generation.types import (
ChatComplete, ChatComplete,
ChatCompletionChunk, ChatCompletionChunk,
ChatCompletionComplete, ChatCompletionComplete,
Completion,
) )
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None) DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@ -69,17 +71,22 @@ class ResponseComparator(JSONSnapshotExtension):
data = json.loads(data) data = json.loads(data)
if isinstance(data, Dict) and "choices" in data: if isinstance(data, Dict) and "choices" in data:
choices = data["choices"] choices = data["choices"]
if ( if isinstance(choices, List) and len(choices) >= 1:
isinstance(choices, List) if "delta" in choices[0]:
and len(choices) >= 1
and "delta" in choices[0]
):
return ChatCompletionChunk(**data) return ChatCompletionChunk(**data)
if "text" in choices[0]:
return Completion(**data)
return ChatComplete(**data) return ChatComplete(**data)
if isinstance(data, Dict): if isinstance(data, Dict):
return Response(**data) return Response(**data)
if isinstance(data, List): if isinstance(data, List):
if (
len(data) > 0
and "object" in data[0]
and data[0]["object"] == "text_completion"
):
return [Completion(**d) for d in data]
return [Response(**d) for d in data] return [Response(**d) for d in data]
raise NotImplementedError raise NotImplementedError
@ -161,6 +168,9 @@ class ResponseComparator(JSONSnapshotExtension):
) )
) )
def eq_completion(response: Completion, other: Completion) -> bool:
return response.choices[0].text == other.choices[0].text
def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool: def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
return ( return (
response.choices[0].message.content == other.choices[0].message.content response.choices[0].message.content == other.choices[0].message.content
@ -184,6 +194,11 @@ class ResponseComparator(JSONSnapshotExtension):
if not isinstance(snapshot_data, List): if not isinstance(snapshot_data, List):
snapshot_data = [snapshot_data] snapshot_data = [snapshot_data]
if isinstance(serialized_data[0], Completion):
return len(snapshot_data) == len(serialized_data) and all(
[eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]
)
if isinstance(serialized_data[0], ChatComplete): if isinstance(serialized_data[0], ChatComplete):
return len(snapshot_data) == len(serialized_data) and all( return len(snapshot_data) == len(serialized_data) and all(
[eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)] [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]

View File

@ -0,0 +1,38 @@
{
"choices": [
{
"finish_reason": "eos_token",
"index": 1,
"logprobs": null,
"text": " PR for more information?"
},
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": "le Business Incubator is providing a workspace"
},
{
"finish_reason": "length",
"index": 2,
"logprobs": null,
"text": " severely flawed and often has a substandard"
},
{
"finish_reason": "length",
"index": 3,
"logprobs": null,
"text": "hd20220811-"
}
],
"created": 1713284455,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"usage": {
"completion_tokens": 36,
"prompt_tokens": 8,
"total_tokens": 44
}
}

View File

@ -0,0 +1,602 @@
[
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "hd"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "aho"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "2"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "2"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "2"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "ima"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "."
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "."
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": "."
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "\n"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Sarah"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " Yes"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " And"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "i"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "'"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": ","
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " what"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "'"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": "s"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " Moh"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " is"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": "m"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " Room"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "s"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " the"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " tired"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": ":"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": "'"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " capital"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " of"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 0,
"logprobs": null,
"text": " She"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 1,
"logprobs": null,
"text": " scale"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 2,
"logprobs": null,
"text": " of"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
},
{
"choices": [
{
"finish_reason": "",
"index": 3,
"logprobs": null,
"text": " being"
}
],
"created": 1713284431,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native"
}
]

View File

@ -0,0 +1,20 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": " PR for flake8"
}
],
"created": 1713284454,
"id": "",
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"object": "text_completion",
"system_fingerprint": "2.0.0-native",
"usage": {
"completion_tokens": 5,
"prompt_tokens": 6,
"total_tokens": 11
}
}

View File

@ -0,0 +1,109 @@
import pytest
import requests
import json
from aiohttp import ClientSession
from text_generation.types import (
Completion,
)
@pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher):
with launcher(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_completion(flash_llama_completion_handle):
await flash_llama_completion_handle.health(300)
return flash_llama_completion_handle.client
# NOTE: since `v1/completions` is a deprecated inferface/endpoint we do not provide a convience
# method for it. Instead, we use the `requests` library to make the HTTP request directly.
def test_flash_llama_completion_single_prompt(
flash_llama_completion, response_snapshot
):
response = requests.post(
f"{flash_llama_completion.base_url}/v1/completions",
json={
"model": "tgi",
"prompt": "Say this is a test",
"max_tokens": 5,
"seed": 0,
},
headers=flash_llama_completion.headers,
stream=False,
)
response = response.json()
assert len(response["choices"]) == 1
assert response == response_snapshot
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
response = requests.post(
f"{flash_llama_completion.base_url}/v1/completions",
json={
"model": "tgi",
"prompt": ["Say", "this", "is", "a"],
"max_tokens": 10,
"seed": 0,
},
headers=flash_llama_completion.headers,
stream=False,
)
response = response.json()
assert len(response["choices"]) == 4
all_indexes = [choice["index"] for choice in response["choices"]]
all_indexes.sort()
assert all_indexes == [0, 1, 2, 3]
assert response == response_snapshot
async def test_flash_llama_completion_many_prompts_stream(
flash_llama_completion, response_snapshot
):
request = {
"model": "tgi",
"prompt": [
"What color is the sky?",
"Is water wet?",
"What is the capital of France?",
"def mai",
],
"max_tokens": 10,
"seed": 0,
"stream": True,
}
url = f"{flash_llama_completion.base_url}/v1/completions"
chunks = []
async with ClientSession(headers=flash_llama_completion.headers) as session:
async with session.post(url, json=request) as response:
# iterate over the stream
async for chunk in response.content.iter_any():
# remove "data:"
chunk = chunk.decode().split("\n\n")
# remove "data:" if present
chunk = [c.replace("data:", "") for c in chunk]
# remove empty strings
chunk = [c for c in chunk if c]
# parse json
chunk = [json.loads(c) for c in chunk]
for c in chunk:
chunks.append(Completion(**c))
assert "choices" in c
assert 0 <= c["choices"][0]["index"] <= 4
assert response.status == 200
assert chunks == response_snapshot

View File

@ -414,6 +414,10 @@ struct Args {
/// Display a lot of information about your runtime environment /// Display a lot of information about your runtime environment
#[clap(long, short, action)] #[clap(long, short, action)]
env: bool, env: bool,
/// Control the maximum number of inputs that a client can send in a single request
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
} }
#[derive(Debug)] #[derive(Debug)]
@ -1078,6 +1082,8 @@ fn spawn_webserver(
// Start webserver // Start webserver
tracing::info!("Starting Webserver"); tracing::info!("Starting Webserver");
let mut router_args = vec![ let mut router_args = vec![
"--max-client-batch-size".to_string(),
args.max_client_batch_size.to_string(),
"--max-concurrent-requests".to_string(), "--max-concurrent-requests".to_string(),
args.max_concurrent_requests.to_string(), args.max_concurrent_requests.to_string(),
"--max-best-of".to_string(), "--max-best-of".to_string(),

View File

@ -155,6 +155,8 @@ pub struct Info {
pub max_batch_size: Option<usize>, pub max_batch_size: Option<usize>,
#[schema(example = "2")] #[schema(example = "2")]
pub validation_workers: usize, pub validation_workers: usize,
#[schema(example = "32")]
pub max_client_batch_size: usize,
/// Router Info /// Router Info
#[schema(example = "0.5.0")] #[schema(example = "0.5.0")]
pub version: &'static str, pub version: &'static str,
@ -280,6 +282,34 @@ fn default_parameters() -> GenerateParameters {
} }
} }
mod prompt_serde {
use serde::{self, Deserialize, Deserializer};
use serde_json::Value;
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: Deserializer<'de>,
{
let value = Value::deserialize(deserializer)?;
match value {
Value::String(s) => Ok(vec![s]),
Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
"Empty array detected. Do not use an empty array for the prompt.",
)),
Value::Array(arr) => arr
.iter()
.map(|v| match v {
Value::String(s) => Ok(s.to_owned()),
_ => Err(serde::de::Error::custom("Expected a string")),
})
.collect(),
_ => Err(serde::de::Error::custom(
"Expected a string or an array of strings",
)),
}
}
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] #[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
pub struct CompletionRequest { pub struct CompletionRequest {
/// UNUSED /// UNUSED
@ -289,7 +319,8 @@ pub struct CompletionRequest {
/// The prompt to generate completions for. /// The prompt to generate completions for.
#[schema(example = "What is Deep Learning?")] #[schema(example = "What is Deep Learning?")]
pub prompt: String, #[serde(deserialize_with = "prompt_serde::deserialize")]
pub prompt: Vec<String>,
/// The maximum number of tokens that can be generated in the chat completion. /// The maximum number of tokens that can be generated in the chat completion.
#[serde(default)] #[serde(default)]

View File

@ -78,6 +78,8 @@ struct Args {
messages_api_enabled: bool, messages_api_enabled: bool,
#[clap(long, env, default_value_t = false)] #[clap(long, env, default_value_t = false)]
disable_grammar_support: bool, disable_grammar_support: bool,
#[clap(default_value = "4", long, env)]
max_client_batch_size: usize,
} }
#[tokio::main] #[tokio::main]
@ -112,6 +114,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_edge, ngrok_edge,
messages_api_enabled, messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size,
} = args; } = args;
// Launch Tokio runtime // Launch Tokio runtime
@ -393,6 +396,7 @@ async fn main() -> Result<(), RouterError> {
tokenizer_config, tokenizer_config,
messages_api_enabled, messages_api_enabled,
disable_grammar_support, disable_grammar_support,
max_client_batch_size,
) )
.await?; .await?;
Ok(()) Ok(())

View File

@ -16,6 +16,7 @@ use crate::{
CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse, CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
}; };
use crate::{FunctionDefinition, ToolCall, ToolType}; use crate::{FunctionDefinition, ToolCall, ToolType};
use async_stream::__private::AsyncStream;
use axum::extract::Extension; use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode}; use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::sse::{Event, KeepAlive, Sse};
@ -23,8 +24,8 @@ use axum::response::{IntoResponse, Response};
use axum::routing::{get, post}; use axum::routing::{get, post};
use axum::{http, Json, Router}; use axum::{http, Json, Router};
use axum_tracing_opentelemetry::middleware::OtelAxumLayer; use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
use futures::stream::FuturesUnordered;
use futures::stream::StreamExt; use futures::stream::StreamExt;
use futures::stream::{FuturesOrdered, FuturesUnordered};
use futures::Stream; use futures::Stream;
use futures::TryStreamExt; use futures::TryStreamExt;
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
@ -35,7 +36,9 @@ use std::sync::atomic::AtomicBool;
use std::sync::Arc; use std::sync::Arc;
use text_generation_client::{ShardInfo, ShardedClient}; use text_generation_client::{ShardInfo, ShardedClient};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use tokio::select;
use tokio::signal; use tokio::signal;
use tokio::sync::oneshot;
use tokio::time::Instant; use tokio::time::Instant;
use tower_http::cors::{AllowOrigin, CorsLayer}; use tower_http::cors::{AllowOrigin, CorsLayer};
use tracing::{info_span, instrument, Instrument}; use tracing::{info_span, instrument, Instrument};
@ -161,6 +164,15 @@ async fn generate(
Json(req): Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> { ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current(); let span = tracing::Span::current();
generate_internal(infer, ComputeType(compute_type), Json(req), span).await
}
async fn generate_internal(
infer: Extension<Infer>,
ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>,
span: tracing::Span,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let start_time = Instant::now(); let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
@ -359,12 +371,13 @@ async fn generate_stream(
HeaderMap, HeaderMap,
Sse<impl Stream<Item = Result<Event, Infallible>>>, Sse<impl Stream<Item = Result<Event, Infallible>>>,
) { ) {
let span = tracing::Span::current();
let on_message_callback = |stream_token: StreamResponse| { let on_message_callback = |stream_token: StreamResponse| {
let event = Event::default(); let event = Event::default();
event.json_data(stream_token).unwrap() event.json_data(stream_token).unwrap()
}; };
let (headers, response_stream) = let (headers, response_stream) =
generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await; generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
(headers, sse) (headers, sse)
} }
@ -374,8 +387,8 @@ async fn generate_stream_internal(
ComputeType(compute_type): ComputeType, ComputeType(compute_type): ComputeType,
Json(req): Json<GenerateRequest>, Json(req): Json<GenerateRequest>,
on_message_callback: impl Fn(StreamResponse) -> Event, on_message_callback: impl Fn(StreamResponse) -> Event,
span: tracing::Span,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) { ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
let span = tracing::Span::current();
let start_time = Instant::now(); let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
@ -581,6 +594,7 @@ async fn completions(
Extension(info): Extension<Info>, Extension(info): Extension<Info>,
Json(req): Json<CompletionRequest>, Json(req): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
let stream = req.stream; let stream = req.stream;
@ -600,9 +614,25 @@ async fn completions(
)); ));
} }
// build the request passing some parameters if req.prompt.len() > info.max_client_batch_size {
let generate_request = GenerateRequest { metrics::increment_counter!("tgi_request_failure", "err" => "validation");
inputs: req.prompt.to_string(), return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
error: format!(
"Number of prompts exceeds the maximum allowed batch size of {}",
info.max_client_batch_size
),
error_type: "batch size exceeded".to_string(),
}),
));
}
let generate_requests: Vec<GenerateRequest> = req
.prompt
.iter()
.map(|prompt| GenerateRequest {
inputs: prompt.to_string(),
parameters: GenerateParameters { parameters: GenerateParameters {
best_of: None, best_of: None,
temperature: req.temperature, temperature: req.temperature,
@ -623,9 +653,25 @@ async fn completions(
top_n_tokens: None, top_n_tokens: None,
grammar: None, grammar: None,
}, },
}; })
.collect();
let mut x_compute_type = None;
let mut x_compute_characters = 0u32;
let mut x_accel_buffering = None;
if stream { if stream {
let mut response_streams = FuturesOrdered::new();
for (index, generate_request) in generate_requests.into_iter().enumerate() {
let model_id = info.model_id.clone();
let system_fingerprint =
format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
let infer_clone = infer.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();
// Create a future for each generate_stream_internal call.
let generate_future = async move {
let on_message_callback = move |stream_token: StreamResponse| { let on_message_callback = move |stream_token: StreamResponse| {
let event = Event::default(); let event = Event::default();
@ -642,50 +688,158 @@ async fn completions(
choices: vec![CompletionComplete { choices: vec![CompletionComplete {
finish_reason: "".to_string(), finish_reason: "".to_string(),
index: 0, index: index as u32,
logprobs: None, logprobs: None,
text: stream_token.token.text, text: stream_token.token.text,
}], }],
model: info.model_id.clone(), model: model_id.clone(),
system_fingerprint: format!( system_fingerprint: system_fingerprint.clone(),
"{}-{}",
info.version,
info.docker_label.unwrap_or("native")
),
}) })
.map_or_else( .map_or_else(|_e| Event::default(), |data| data)
|e| {
println!("Failed to serialize CompletionCompleteChunk: {:?}", e);
Event::default()
},
|data| data,
)
}; };
let (headers, response_stream) = generate_stream_internal( let (header_tx, header_rx) = oneshot::channel();
infer, let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
compute_type,
tokio::spawn(async move {
let (header_map, sse) = generate_stream_internal(
infer_clone.clone(),
compute_type_clone.clone(),
Json(generate_request), Json(generate_request),
on_message_callback, on_message_callback,
span_clone.clone(),
) )
.await; .await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); // send and dont wait for response
let _ = header_tx.send(header_map);
// pin an emit messages to the sse_tx
let mut sse = Box::pin(sse);
while let Some(event) = sse.next().await {
if sse_tx.send(event).is_err() {
tracing::error!("Failed to send event. Receiver dropped.");
break;
}
}
});
(header_rx, sse_rx)
};
response_streams.push_back(generate_future);
}
let mut all_rxs = vec![];
while let Some((header_rx, sse_rx)) = response_streams.next().await {
all_rxs.push(sse_rx);
// get the headers from the first response of each stream
let headers = header_rx.await.map_err(|e| {
tracing::error!("Failed to get headers: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Failed to get headers".to_string(),
error_type: "headers".to_string(),
}),
)
})?;
if x_compute_type.is_none() {
x_compute_type = headers
.get("x-compute-type")
.and_then(|v| v.to_str().ok())
.map(|v| v.to_string());
x_accel_buffering = headers
.get("x-accel-buffering")
.and_then(|v| v.to_str().ok())
.map(|v| v.to_string());
}
x_compute_characters += headers
.get("x-compute-characters")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse().ok())
.unwrap_or(0);
}
let mut headers = HeaderMap::new();
if let Some(x_compute_type) = x_compute_type {
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
}
headers.insert("x-compute-characters", x_compute_characters.into());
if let Some(x_accel_buffering) = x_accel_buffering {
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
}
// now sink the sse streams into a single stream and remove the ones that are done
let stream: AsyncStream<Result<Event, Infallible>, _> = async_stream::stream! {
loop {
let mut i = 0;
while i < all_rxs.len() {
let rx = &mut all_rxs[i];
select! {
Some(event) = rx.recv() => {
yield event;
}
else => {
all_rxs.remove(i);
continue; // skip the increment to handle the next element at the same index
}
}
i += 1; // only increment when no element was removed
}
if all_rxs.is_empty() {
break;
}
}
};
let sse = Sse::new(stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response()) Ok((headers, sse).into_response())
} else { } else {
let (headers, Json(generation)) = generate(
Extension(infer),
Extension(compute_type),
Json(generate_request),
)
.await?;
let current_time = std::time::SystemTime::now() let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
.unwrap_or_else(|_| std::time::Duration::from_secs(0)) .unwrap_or_else(|_| std::time::Duration::from_secs(0))
.as_secs(); .as_secs();
let responses = FuturesUnordered::new();
for (index, generate_request) in generate_requests.into_iter().enumerate() {
let infer_clone = infer.clone();
let compute_type_clone = compute_type.clone();
let span_clone = span.clone();
let response_future = async move {
let result = generate_internal(
Extension(infer_clone),
compute_type_clone,
Json(generate_request),
span_clone,
)
.await;
result.map(|(headers, generation)| (index, headers, generation))
};
responses.push(response_future);
}
let generate_responses = responses.try_collect::<Vec<_>>().await?;
let mut prompt_tokens = 0u32;
let mut completion_tokens = 0u32;
let mut total_tokens = 0u32;
let mut x_compute_time = 0u32;
let mut x_total_time = 0u32;
let mut x_validation_time = 0u32;
let mut x_queue_time = 0u32;
let mut x_inference_time = 0u32;
let mut x_time_per_token = 0u32;
let mut x_prompt_tokens = 0u32;
let mut x_generated_tokens = 0u32;
let choices = generate_responses
.into_iter()
.map(|(index, headers, Json(generation))| {
let details = generation.details.ok_or(( let details = generation.details.ok_or((
// this should never happen but handle if details are missing unexpectedly // this should never happen but handle if details are missing unexpectedly
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
@ -695,6 +849,65 @@ async fn completions(
}), }),
))?; ))?;
if x_compute_type.is_none() {
x_compute_type = headers
.get("x-compute-type")
.and_then(|v| v.to_str().ok())
.map(|v| v.to_string());
}
// accumulate headers and usage from each response
x_compute_time += headers
.get("x-compute-time")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_compute_characters += headers
.get("x-compute-characters")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_total_time += headers
.get("x-total-time")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_validation_time += headers
.get("x-validation-time")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_queue_time += headers
.get("x-queue-time")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_inference_time += headers
.get("x-inference-time")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_time_per_token += headers
.get("x-time-per-token")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_prompt_tokens += headers
.get("x-prompt-tokens")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
x_generated_tokens += headers
.get("x-generated-tokens")
.and_then(|v| v.to_str().ok()?.parse().ok())
.unwrap_or(0);
prompt_tokens += details.prefill.len() as u32;
completion_tokens += details.generated_tokens;
total_tokens += details.prefill.len() as u32 + details.generated_tokens;
Ok(CompletionComplete {
finish_reason: details.finish_reason.to_string(),
index: index as u32,
logprobs: None,
text: generation.generated_text,
})
})
.collect::<Result<Vec<_>, _>>()
.map_err(|(status, Json(err))| (status, Json(err)))?;
let response = Completion { let response = Completion {
id: "".to_string(), id: "".to_string(),
object: "text_completion".to_string(), object: "text_completion".to_string(),
@ -705,19 +918,30 @@ async fn completions(
info.version, info.version,
info.docker_label.unwrap_or("native") info.docker_label.unwrap_or("native")
), ),
choices: vec![CompletionComplete { choices,
finish_reason: details.finish_reason.to_string(),
index: 0,
logprobs: None,
text: generation.generated_text,
}],
usage: Usage { usage: Usage {
prompt_tokens: details.prefill.len() as u32, prompt_tokens,
completion_tokens: details.generated_tokens, completion_tokens,
total_tokens: details.prefill.len() as u32 + details.generated_tokens, total_tokens,
}, },
}; };
// headers similar to `generate` but aggregated
let mut headers = HeaderMap::new();
if let Some(x_compute_type) = x_compute_type {
headers.insert("x-compute-type", x_compute_type.parse().unwrap());
}
headers.insert("x-compute-characters", x_compute_characters.into());
headers.insert("x-total-time", x_total_time.into());
headers.insert("x-validation-time", x_validation_time.into());
headers.insert("x-queue-time", x_queue_time.into());
headers.insert("x-inference-time", x_inference_time.into());
headers.insert("x-time-per-token", x_time_per_token.into());
headers.insert("x-prompt-tokens", x_prompt_tokens.into());
headers.insert("x-generated-tokens", x_generated_tokens.into());
if let Some(x_accel_buffering) = x_accel_buffering {
headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
}
Ok((headers, Json(response)).into_response()) Ok((headers, Json(response)).into_response())
} }
} }
@ -762,6 +986,7 @@ async fn chat_completions(
Extension(info): Extension<Info>, Extension(info): Extension<Info>,
Json(req): Json<ChatRequest>, Json(req): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
let ChatRequest { let ChatRequest {
@ -899,17 +1124,14 @@ async fn chat_completions(
compute_type, compute_type,
Json(generate_request), Json(generate_request),
on_message_callback, on_message_callback,
span,
) )
.await; .await;
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default()); let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response()) Ok((headers, sse).into_response())
} else { } else {
let (headers, Json(generation)) = generate( let (headers, Json(generation)) =
Extension(infer), generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
Extension(compute_type),
Json(generate_request),
)
.await?;
let current_time = std::time::SystemTime::now() let current_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
@ -1006,6 +1228,7 @@ async fn vertex_compatibility(
Extension(compute_type): Extension<ComputeType>, Extension(compute_type): Extension<ComputeType>,
Json(req): Json<VertexRequest>, Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count"); metrics::increment_counter!("tgi_request_count");
// check that theres at least one instance // check that theres at least one instance
@ -1037,10 +1260,11 @@ async fn vertex_compatibility(
}; };
async { async {
generate( generate_internal(
Extension(infer.clone()), Extension(infer.clone()),
Extension(compute_type.clone()), compute_type.clone(),
Json(generate_request), Json(generate_request),
span.clone(),
) )
.await .await
.map(|(_, Json(generation))| generation.generated_text) .map(|(_, Json(generation))| generation.generated_text)
@ -1152,6 +1376,7 @@ pub async fn run(
tokenizer_config: HubTokenizerConfig, tokenizer_config: HubTokenizerConfig,
messages_api_enabled: bool, messages_api_enabled: bool,
grammar_support: bool, grammar_support: bool,
max_client_batch_size: usize,
) -> Result<(), axum::BoxError> { ) -> Result<(), axum::BoxError> {
// OpenAPI documentation // OpenAPI documentation
#[derive(OpenApi)] #[derive(OpenApi)]
@ -1326,6 +1551,7 @@ pub async fn run(
max_waiting_tokens, max_waiting_tokens,
max_batch_size, max_batch_size,
validation_workers, validation_workers,
max_client_batch_size,
version: env!("CARGO_PKG_VERSION"), version: env!("CARGO_PKG_VERSION"),
sha: option_env!("VERGEN_GIT_SHA"), sha: option_env!("VERGEN_GIT_SHA"),
docker_label: option_env!("DOCKER_LABEL"), docker_label: option_env!("DOCKER_LABEL"),