feat: accept list as prompt and use first string (#1702)
This PR allows `CompletionRequest.prompt` to be sent either as a single string or as an array of strings. When an array is sent, each string is treated as a separate prompt; if the value is neither a string nor an array of strings, the corresponding validation error is returned.

Fixes: https://github.com/huggingface/text-generation-inference/issues/1690
Similar to: https://github.com/vllm-project/vllm/pull/323/files
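For illustration only (not part of the diff), a minimal sketch of the two request shapes the `v1/completions` route now accepts; the base URL, model name, and prompt values are placeholders:

```python
import requests

BASE_URL = "http://localhost:8080"  # assumed local TGI instance

# A single string prompt (previous behaviour, still supported)
single = requests.post(
    f"{BASE_URL}/v1/completions",
    json={"model": "tgi", "prompt": "Say this is a test", "max_tokens": 5},
)

# An array of string prompts (behaviour added by this PR)
batch = requests.post(
    f"{BASE_URL}/v1/completions",
    json={"model": "tgi", "prompt": ["Say", "this", "is", "a"], "max_tokens": 5},
)

print(single.json()["choices"][0]["text"])
print([choice["text"] for choice in batch.json()["choices"]])
```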
This commit is contained in:
parent e4d31a40db
commit 06c3d4b1ec
@@ -59,6 +59,17 @@ class ChatCompletionComplete(BaseModel):
     usage: Optional[Any] = None
 
 
+class CompletionComplete(BaseModel):
+    # Index of the chat completion
+    index: int
+    # Message associated with the chat completion
+    text: str
+    # Log probabilities for the chat completion
+    logprobs: Optional[Any]
+    # Reason for completion
+    finish_reason: str
+
+
 class Function(BaseModel):
     name: Optional[str]
     arguments: str
@@ -104,6 +115,16 @@ class ChatComplete(BaseModel):
     usage: Any
 
 
+class Completion(BaseModel):
+    # Completion details
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[CompletionComplete]
+
+
 class ChatRequest(BaseModel):
     # Model identifier
     model: str
@@ -398,6 +398,15 @@ Options:
   -e, --env
           Display a lot of information about your runtime environment
 
+```
+## MAX_CLIENT_BATCH_SIZE
+```shell
+      --max-client-batch-size <MAX_CLIENT_BATCH_SIZE>
+          Control the maximum number of inputs that a client can send in a single request
+
+          [env: MAX_CLIENT_BATCH_SIZE=]
+          [default: 4]
+
 ```
 ## HELP
 ```shell
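As a sketch of how the new limit behaves (assuming a server started with the default `--max-client-batch-size 4` and reachable at a placeholder URL), a request with more prompts than the limit should come back as a 422 validation error rather than being truncated:

```python
import requests

BASE_URL = "http://localhost:8080"  # placeholder; adjust for your deployment

resp = requests.post(
    f"{BASE_URL}/v1/completions",
    json={
        "model": "tgi",
        # five prompts against the default limit of four
        "prompt": ["a", "b", "c", "d", "e"],
        "max_tokens": 5,
    },
)

# The router rejects the whole request instead of silently dropping prompts.
assert resp.status_code == 422
print(resp.json()["error"])
```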
@@ -9,6 +9,7 @@ import json
 import math
 import time
 import random
+import re
 
 from docker.errors import NotFound
 from typing import Optional, List, Dict
@@ -26,6 +27,7 @@ from text_generation.types import (
     ChatComplete,
     ChatCompletionChunk,
     ChatCompletionComplete,
+    Completion,
 )
 
 DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
@@ -69,17 +71,22 @@ class ResponseComparator(JSONSnapshotExtension):
         data = json.loads(data)
         if isinstance(data, Dict) and "choices" in data:
             choices = data["choices"]
-            if (
-                isinstance(choices, List)
-                and len(choices) >= 1
-                and "delta" in choices[0]
-            ):
-                return ChatCompletionChunk(**data)
+            if isinstance(choices, List) and len(choices) >= 1:
+                if "delta" in choices[0]:
+                    return ChatCompletionChunk(**data)
+                if "text" in choices[0]:
+                    return Completion(**data)
             return ChatComplete(**data)
 
         if isinstance(data, Dict):
             return Response(**data)
         if isinstance(data, List):
+            if (
+                len(data) > 0
+                and "object" in data[0]
+                and data[0]["object"] == "text_completion"
+            ):
+                return [Completion(**d) for d in data]
             return [Response(**d) for d in data]
         raise NotImplementedError
@@ -161,6 +168,9 @@ class ResponseComparator(JSONSnapshotExtension):
                 )
             )
 
+        def eq_completion(response: Completion, other: Completion) -> bool:
+            return response.choices[0].text == other.choices[0].text
+
         def eq_chat_complete(response: ChatComplete, other: ChatComplete) -> bool:
             return (
                 response.choices[0].message.content == other.choices[0].message.content
@@ -184,6 +194,11 @@ class ResponseComparator(JSONSnapshotExtension):
         if not isinstance(snapshot_data, List):
             snapshot_data = [snapshot_data]
 
+        if isinstance(serialized_data[0], Completion):
+            return len(snapshot_data) == len(serialized_data) and all(
+                [eq_completion(r, o) for r, o in zip(serialized_data, snapshot_data)]
+            )
+
         if isinstance(serialized_data[0], ChatComplete):
             return len(snapshot_data) == len(serialized_data) and all(
                 [eq_chat_complete(r, o) for r, o in zip(serialized_data, snapshot_data)]
@@ -0,0 +1,38 @@ (new snapshot file)
{
  "choices": [
    { "finish_reason": "eos_token", "index": 1, "logprobs": null, "text": " PR for more information?" },
    { "finish_reason": "length", "index": 0, "logprobs": null, "text": "le Business Incubator is providing a workspace" },
    { "finish_reason": "length", "index": 2, "logprobs": null, "text": " severely flawed and often has a substandard" },
    { "finish_reason": "length", "index": 3, "logprobs": null, "text": "hd20220811-" }
  ],
  "created": 1713284455,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "2.0.0-native",
  "usage": {
    "completion_tokens": 36,
    "prompt_tokens": 8,
    "total_tokens": 44
  }
}
@@ -0,0 +1,602 @@ (new streaming snapshot file, shown one chunk per line for brevity)
[
  { "choices": [{ "finish_reason": "", "index": 0, "logprobs": null, "text": "\n" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 1, "logprobs": null, "text": "\n" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 2, "logprobs": null, "text": "\n" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 3, "logprobs": null, "text": "hd" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 0, "logprobs": null, "text": "\n" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 1, "logprobs": null, "text": "\n" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 2, "logprobs": null, "text": "\n" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 3, "logprobs": null, "text": "aho" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 0, "logprobs": null, "text": "2" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 1, "logprobs": null, "text": "2" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 2, "logprobs": null, "text": "2" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 3, "logprobs": null, "text": "ima" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 0, "logprobs": null, "text": "." }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 1, "logprobs": null, "text": "." }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 2, "logprobs": null, "text": "." }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 3, "logprobs": null, "text": "\n" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 0, "logprobs": null, "text": " Sarah" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 1, "logprobs": null, "text": " Yes" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 2, "logprobs": null, "text": " And" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 3, "logprobs": null, "text": "i" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 0, "logprobs": null, "text": "'" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 1, "logprobs": null, "text": "," }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 2, "logprobs": null, "text": " what" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 3, "logprobs": null, "text": "'" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 0, "logprobs": null, "text": "s" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 1, "logprobs": null, "text": " Moh" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 2, "logprobs": null, "text": " is" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 3, "logprobs": null, "text": "m" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 0, "logprobs": null, "text": " Room" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 1, "logprobs": null, "text": "s" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 2, "logprobs": null, "text": " the" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 3, "logprobs": null, "text": " tired" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 0, "logprobs": null, "text": ":" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 1, "logprobs": null, "text": "'" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 2, "logprobs": null, "text": " capital" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 3, "logprobs": null, "text": " of" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 0, "logprobs": null, "text": " She" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 1, "logprobs": null, "text": " scale" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 2, "logprobs": null, "text": " of" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" },
  { "choices": [{ "finish_reason": "", "index": 3, "logprobs": null, "text": " being" }], "created": 1713284431, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", "system_fingerprint": "2.0.0-native" }
]
@@ -0,0 +1,20 @@ (new snapshot file)
{
  "choices": [
    { "finish_reason": "length", "index": 0, "logprobs": null, "text": " PR for flake8" }
  ],
  "created": 1713284454,
  "id": "",
  "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "object": "text_completion",
  "system_fingerprint": "2.0.0-native",
  "usage": {
    "completion_tokens": 5,
    "prompt_tokens": 6,
    "total_tokens": 11
  }
}
@@ -0,0 +1,109 @@ (new integration test file)
import pytest
import requests
import json
from aiohttp import ClientSession

from text_generation.types import (
    Completion,
)


@pytest.fixture(scope="module")
def flash_llama_completion_handle(launcher):
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_completion(flash_llama_completion_handle):
    await flash_llama_completion_handle.health(300)
    return flash_llama_completion_handle.client


# NOTE: since `v1/completions` is a deprecated interface/endpoint, we do not provide a convenience
# method for it. Instead, we use the `requests` library to make the HTTP request directly.


def test_flash_llama_completion_single_prompt(
    flash_llama_completion, response_snapshot
):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": "Say this is a test",
            "max_tokens": 5,
            "seed": 0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()
    assert len(response["choices"]) == 1

    assert response == response_snapshot


def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
    response = requests.post(
        f"{flash_llama_completion.base_url}/v1/completions",
        json={
            "model": "tgi",
            "prompt": ["Say", "this", "is", "a"],
            "max_tokens": 10,
            "seed": 0,
        },
        headers=flash_llama_completion.headers,
        stream=False,
    )
    response = response.json()
    assert len(response["choices"]) == 4

    all_indexes = [choice["index"] for choice in response["choices"]]
    all_indexes.sort()
    assert all_indexes == [0, 1, 2, 3]

    assert response == response_snapshot


async def test_flash_llama_completion_many_prompts_stream(
    flash_llama_completion, response_snapshot
):
    request = {
        "model": "tgi",
        "prompt": [
            "What color is the sky?",
            "Is water wet?",
            "What is the capital of France?",
            "def mai",
        ],
        "max_tokens": 10,
        "seed": 0,
        "stream": True,
    }

    url = f"{flash_llama_completion.base_url}/v1/completions"

    chunks = []
    async with ClientSession(headers=flash_llama_completion.headers) as session:
        async with session.post(url, json=request) as response:
            # iterate over the stream
            async for chunk in response.content.iter_any():
                # split the buffer into individual SSE messages
                chunk = chunk.decode().split("\n\n")
                # remove the "data:" prefix if present
                chunk = [c.replace("data:", "") for c in chunk]
                # remove empty strings
                chunk = [c for c in chunk if c]
                # parse json
                chunk = [json.loads(c) for c in chunk]

                for c in chunk:
                    chunks.append(Completion(**c))
                    assert "choices" in c
                    assert 0 <= c["choices"][0]["index"] <= 4

    assert response.status == 200
    assert chunks == response_snapshot
@@ -414,6 +414,10 @@ struct Args {
     /// Display a lot of information about your runtime environment
     #[clap(long, short, action)]
     env: bool,
+
+    /// Control the maximum number of inputs that a client can send in a single request
+    #[clap(default_value = "4", long, env)]
+    max_client_batch_size: usize,
 }
 
 #[derive(Debug)]
@@ -1078,6 +1082,8 @@ fn spawn_webserver(
     // Start webserver
     tracing::info!("Starting Webserver");
     let mut router_args = vec![
+        "--max-client-batch-size".to_string(),
+        args.max_client_batch_size.to_string(),
         "--max-concurrent-requests".to_string(),
         args.max_concurrent_requests.to_string(),
         "--max-best-of".to_string(),
@@ -155,6 +155,8 @@ pub struct Info {
     pub max_batch_size: Option<usize>,
     #[schema(example = "2")]
     pub validation_workers: usize,
+    #[schema(example = "32")]
+    pub max_client_batch_size: usize,
     /// Router Info
     #[schema(example = "0.5.0")]
     pub version: &'static str,
@@ -280,6 +282,34 @@ fn default_parameters() -> GenerateParameters {
     }
 }
 
+mod prompt_serde {
+    use serde::{self, Deserialize, Deserializer};
+    use serde_json::Value;
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+        match value {
+            Value::String(s) => Ok(vec![s]),
+            Value::Array(arr) if arr.is_empty() => Err(serde::de::Error::custom(
+                "Empty array detected. Do not use an empty array for the prompt.",
+            )),
+            Value::Array(arr) => arr
+                .iter()
+                .map(|v| match v {
+                    Value::String(s) => Ok(s.to_owned()),
+                    _ => Err(serde::de::Error::custom("Expected a string")),
+                })
+                .collect(),
+            _ => Err(serde::de::Error::custom(
+                "Expected a string or an array of strings",
+            )),
+        }
+    }
+}
+
 #[derive(Clone, Deserialize, Serialize, ToSchema, Debug)]
 pub struct CompletionRequest {
     /// UNUSED
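As a quick illustration of what this deserializer accepts and rejects (a sketch only; the error messages in the Rust code above are the authoritative source, and the base URL is a placeholder):

```python
import requests

BASE_URL = "http://localhost:8080"  # placeholder TGI instance


def post_prompt(prompt):
    """Send a /v1/completions request with the given prompt value."""
    return requests.post(
        f"{BASE_URL}/v1/completions",
        json={"model": "tgi", "prompt": prompt, "max_tokens": 5},
    )


post_prompt("What is Deep Learning?")     # string: accepted, wrapped as a one-element list
post_prompt(["What is Deep Learning?"])   # array of strings: accepted as-is
post_prompt([])                           # empty array: rejected by the deserializer
post_prompt(["ok", 42])                   # non-string element: rejected ("Expected a string")
post_prompt({"text": "nope"})             # other JSON types: rejected
```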
@@ -289,7 +319,8 @@ pub struct CompletionRequest {
 
     /// The prompt to generate completions for.
     #[schema(example = "What is Deep Learning?")]
-    pub prompt: String,
+    #[serde(deserialize_with = "prompt_serde::deserialize")]
+    pub prompt: Vec<String>,
 
     /// The maximum number of tokens that can be generated in the chat completion.
     #[serde(default)]
@@ -78,6 +78,8 @@ struct Args {
     messages_api_enabled: bool,
     #[clap(long, env, default_value_t = false)]
     disable_grammar_support: bool,
+    #[clap(default_value = "4", long, env)]
+    max_client_batch_size: usize,
 }
 
 #[tokio::main]
@@ -112,6 +114,7 @@ async fn main() -> Result<(), RouterError> {
         ngrok_edge,
         messages_api_enabled,
         disable_grammar_support,
+        max_client_batch_size,
     } = args;
 
     // Launch Tokio runtime
@@ -393,6 +396,7 @@ async fn main() -> Result<(), RouterError> {
         tokenizer_config,
         messages_api_enabled,
         disable_grammar_support,
+        max_client_batch_size,
     )
     .await?;
     Ok(())
@@ -16,6 +16,7 @@ use crate::{
     CompletionRequest, DeltaToolCall, Function, Tool, VertexRequest, VertexResponse,
 };
 use crate::{FunctionDefinition, ToolCall, ToolType};
+use async_stream::__private::AsyncStream;
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
@@ -23,8 +24,8 @@ use axum::response::{IntoResponse, Response};
 use axum::routing::{get, post};
 use axum::{http, Json, Router};
 use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
-use futures::stream::FuturesUnordered;
 use futures::stream::StreamExt;
+use futures::stream::{FuturesOrdered, FuturesUnordered};
 use futures::Stream;
 use futures::TryStreamExt;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
@@ -35,7 +36,9 @@ use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
 use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
+use tokio::select;
 use tokio::signal;
+use tokio::sync::oneshot;
 use tokio::time::Instant;
 use tower_http::cors::{AllowOrigin, CorsLayer};
 use tracing::{info_span, instrument, Instrument};
@@ -161,6 +164,15 @@ async fn generate(
     Json(req): Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
+    generate_internal(infer, ComputeType(compute_type), Json(req), span).await
+}
+
+async fn generate_internal(
+    infer: Extension<Infer>,
+    ComputeType(compute_type): ComputeType,
+    Json(req): Json<GenerateRequest>,
+    span: tracing::Span,
+) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");
 
@@ -359,12 +371,13 @@ async fn generate_stream(
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
+    let span = tracing::Span::current();
     let on_message_callback = |stream_token: StreamResponse| {
         let event = Event::default();
         event.json_data(stream_token).unwrap()
     };
     let (headers, response_stream) =
-        generate_stream_internal(infer, compute_type, Json(req), on_message_callback).await;
+        generate_stream_internal(infer, compute_type, Json(req), on_message_callback, span).await;
     let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
     (headers, sse)
 }
@@ -374,8 +387,8 @@ async fn generate_stream_internal(
     ComputeType(compute_type): ComputeType,
     Json(req): Json<GenerateRequest>,
     on_message_callback: impl Fn(StreamResponse) -> Event,
+    span: tracing::Span,
 ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
-    let span = tracing::Span::current();
     let start_time = Instant::now();
     metrics::increment_counter!("tgi_request_count");
 
@@ -581,6 +594,7 @@ async fn completions(
     Extension(info): Extension<Info>,
     Json(req): Json<CompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let span = tracing::Span::current();
     metrics::increment_counter!("tgi_request_count");
 
     let stream = req.stream;
@@ -600,9 +614,25 @@ async fn completions(
         ));
     }
 
-    // build the request passing some parameters
-    let generate_request = GenerateRequest {
-        inputs: req.prompt.to_string(),
+    if req.prompt.len() > info.max_client_batch_size {
+        metrics::increment_counter!("tgi_request_failure", "err" => "validation");
+        return Err((
+            StatusCode::UNPROCESSABLE_ENTITY,
+            Json(ErrorResponse {
+                error: format!(
+                    "Number of prompts exceeds the maximum allowed batch size of {}",
+                    info.max_client_batch_size
+                ),
+                error_type: "batch size exceeded".to_string(),
+            }),
+        ));
+    }
+
+    let generate_requests: Vec<GenerateRequest> = req
+        .prompt
+        .iter()
+        .map(|prompt| GenerateRequest {
+            inputs: prompt.to_string(),
         parameters: GenerateParameters {
             best_of: None,
             temperature: req.temperature,
@@ -623,9 +653,25 @@ async fn completions(
             top_n_tokens: None,
             grammar: None,
         },
-    };
+    })
+    .collect();
+
+    let mut x_compute_type = None;
+    let mut x_compute_characters = 0u32;
+    let mut x_accel_buffering = None;
 
     if stream {
+        let mut response_streams = FuturesOrdered::new();
+        for (index, generate_request) in generate_requests.into_iter().enumerate() {
+            let model_id = info.model_id.clone();
+            let system_fingerprint =
+                format!("{}-{}", info.version, info.docker_label.unwrap_or("native"));
+            let infer_clone = infer.clone();
+            let compute_type_clone = compute_type.clone();
+            let span_clone = span.clone();
+
+            // Create a future for each generate_stream_internal call.
+            let generate_future = async move {
         let on_message_callback = move |stream_token: StreamResponse| {
             let event = Event::default();
 
@@ -642,50 +688,158 @@ async fn completions(
 
                 choices: vec![CompletionComplete {
                     finish_reason: "".to_string(),
-                    index: 0,
+                    index: index as u32,
                     logprobs: None,
                     text: stream_token.token.text,
                 }],
 
-                model: info.model_id.clone(),
-                system_fingerprint: format!(
-                    "{}-{}",
-                    info.version,
-                    info.docker_label.unwrap_or("native")
-                ),
+                model: model_id.clone(),
+                system_fingerprint: system_fingerprint.clone(),
             })
-            .map_or_else(
-                |e| {
-                    println!("Failed to serialize CompletionCompleteChunk: {:?}", e);
-                    Event::default()
-                },
-                |data| data,
-            )
+            .map_or_else(|_e| Event::default(), |data| data)
         };
 
-        let (headers, response_stream) = generate_stream_internal(
-            infer,
-            compute_type,
-            Json(generate_request),
-            on_message_callback,
-        )
-        .await;
+                let (header_tx, header_rx) = oneshot::channel();
+                let (sse_tx, sse_rx) = tokio::sync::mpsc::unbounded_channel();
+
+                tokio::spawn(async move {
+                    let (header_map, sse) = generate_stream_internal(
+                        infer_clone.clone(),
+                        compute_type_clone.clone(),
+                        Json(generate_request),
+                        on_message_callback,
+                        span_clone.clone(),
+                    )
+                    .await;
 
-        let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+                    // send and don't wait for the response
+                    let _ = header_tx.send(header_map);
+
+                    // pin and emit messages to the sse_tx
+                    let mut sse = Box::pin(sse);
+                    while let Some(event) = sse.next().await {
+                        if sse_tx.send(event).is_err() {
+                            tracing::error!("Failed to send event. Receiver dropped.");
+                            break;
+                        }
+                    }
+                });
+
+                (header_rx, sse_rx)
+            };
+            response_streams.push_back(generate_future);
+        }
+
+        let mut all_rxs = vec![];
+
+        while let Some((header_rx, sse_rx)) = response_streams.next().await {
+            all_rxs.push(sse_rx);
+
+            // get the headers from the first response of each stream
+            let headers = header_rx.await.map_err(|e| {
+                tracing::error!("Failed to get headers: {:?}", e);
+                (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(ErrorResponse {
+                        error: "Failed to get headers".to_string(),
+                        error_type: "headers".to_string(),
+                    }),
+                )
+            })?;
+            if x_compute_type.is_none() {
+                x_compute_type = headers
+                    .get("x-compute-type")
+                    .and_then(|v| v.to_str().ok())
+                    .map(|v| v.to_string());
+
+                x_accel_buffering = headers
+                    .get("x-accel-buffering")
+                    .and_then(|v| v.to_str().ok())
+                    .map(|v| v.to_string());
+            }
+            x_compute_characters += headers
+                .get("x-compute-characters")
+                .and_then(|v| v.to_str().ok())
+                .and_then(|v| v.parse().ok())
+                .unwrap_or(0);
+        }
+
+        let mut headers = HeaderMap::new();
+        if let Some(x_compute_type) = x_compute_type {
+            headers.insert("x-compute-type", x_compute_type.parse().unwrap());
+        }
+        headers.insert("x-compute-characters", x_compute_characters.into());
+        if let Some(x_accel_buffering) = x_accel_buffering {
+            headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
+        }
+
+        // now sink the sse streams into a single stream and remove the ones that are done
+        let stream: AsyncStream<Result<Event, Infallible>, _> = async_stream::stream! {
+            loop {
+                let mut i = 0;
+                while i < all_rxs.len() {
+                    let rx = &mut all_rxs[i];
+                    select! {
+                        Some(event) = rx.recv() => {
+                            yield event;
+                        }
+                        else => {
+                            all_rxs.remove(i);
+                            continue; // skip the increment to handle the next element at the same index
+                        }
+                    }
+                    i += 1; // only increment when no element was removed
+                }
+
+                if all_rxs.is_empty() {
+                    break;
+                }
+            }
+        };
+
+        let sse = Sse::new(stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
-        let (headers, Json(generation)) = generate(
-            Extension(infer),
-            Extension(compute_type),
-            Json(generate_request),
-        )
-        .await?;
-
         let current_time = std::time::SystemTime::now()
             .duration_since(std::time::UNIX_EPOCH)
             .unwrap_or_else(|_| std::time::Duration::from_secs(0))
             .as_secs();
 
+        let responses = FuturesUnordered::new();
+        for (index, generate_request) in generate_requests.into_iter().enumerate() {
+            let infer_clone = infer.clone();
+            let compute_type_clone = compute_type.clone();
+            let span_clone = span.clone();
+            let response_future = async move {
+                let result = generate_internal(
+                    Extension(infer_clone),
+                    compute_type_clone,
+                    Json(generate_request),
+                    span_clone,
+                )
+                .await;
+                result.map(|(headers, generation)| (index, headers, generation))
+            };
+            responses.push(response_future);
+        }
+        let generate_responses = responses.try_collect::<Vec<_>>().await?;
+
+        let mut prompt_tokens = 0u32;
+        let mut completion_tokens = 0u32;
+        let mut total_tokens = 0u32;
+
+        let mut x_compute_time = 0u32;
+        let mut x_total_time = 0u32;
+        let mut x_validation_time = 0u32;
+        let mut x_queue_time = 0u32;
+        let mut x_inference_time = 0u32;
+        let mut x_time_per_token = 0u32;
+        let mut x_prompt_tokens = 0u32;
+        let mut x_generated_tokens = 0u32;
+
+        let choices = generate_responses
+            .into_iter()
+            .map(|(index, headers, Json(generation))| {
         let details = generation.details.ok_or((
             // this should never happen but handle if details are missing unexpectedly
             StatusCode::INTERNAL_SERVER_ERROR,
@@ -695,6 +849,65 @@ async fn completions(
                 }),
             ))?;
 
+                if x_compute_type.is_none() {
+                    x_compute_type = headers
+                        .get("x-compute-type")
+                        .and_then(|v| v.to_str().ok())
+                        .map(|v| v.to_string());
+                }
+
+                // accumulate headers and usage from each response
+                x_compute_time += headers
+                    .get("x-compute-time")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_compute_characters += headers
+                    .get("x-compute-characters")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_total_time += headers
+                    .get("x-total-time")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_validation_time += headers
+                    .get("x-validation-time")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_queue_time += headers
+                    .get("x-queue-time")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_inference_time += headers
+                    .get("x-inference-time")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_time_per_token += headers
+                    .get("x-time-per-token")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_prompt_tokens += headers
+                    .get("x-prompt-tokens")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+                x_generated_tokens += headers
+                    .get("x-generated-tokens")
+                    .and_then(|v| v.to_str().ok()?.parse().ok())
+                    .unwrap_or(0);
+
+                prompt_tokens += details.prefill.len() as u32;
+                completion_tokens += details.generated_tokens;
+                total_tokens += details.prefill.len() as u32 + details.generated_tokens;
+
+                Ok(CompletionComplete {
+                    finish_reason: details.finish_reason.to_string(),
+                    index: index as u32,
+                    logprobs: None,
+                    text: generation.generated_text,
+                })
+            })
+            .collect::<Result<Vec<_>, _>>()
+            .map_err(|(status, Json(err))| (status, Json(err)))?;
+
         let response = Completion {
             id: "".to_string(),
             object: "text_completion".to_string(),
@@ -705,19 +918,30 @@ async fn completions(
                 info.version,
                 info.docker_label.unwrap_or("native")
             ),
-            choices: vec![CompletionComplete {
-                finish_reason: details.finish_reason.to_string(),
-                index: 0,
-                logprobs: None,
-                text: generation.generated_text,
-            }],
+            choices,
             usage: Usage {
-                prompt_tokens: details.prefill.len() as u32,
-                completion_tokens: details.generated_tokens,
-                total_tokens: details.prefill.len() as u32 + details.generated_tokens,
+                prompt_tokens,
+                completion_tokens,
+                total_tokens,
             },
         };
 
+        // headers similar to `generate` but aggregated
+        let mut headers = HeaderMap::new();
+        if let Some(x_compute_type) = x_compute_type {
+            headers.insert("x-compute-type", x_compute_type.parse().unwrap());
+        }
+        headers.insert("x-compute-characters", x_compute_characters.into());
+        headers.insert("x-total-time", x_total_time.into());
+        headers.insert("x-validation-time", x_validation_time.into());
+        headers.insert("x-queue-time", x_queue_time.into());
+        headers.insert("x-inference-time", x_inference_time.into());
+        headers.insert("x-time-per-token", x_time_per_token.into());
+        headers.insert("x-prompt-tokens", x_prompt_tokens.into());
+        headers.insert("x-generated-tokens", x_generated_tokens.into());
+        if let Some(x_accel_buffering) = x_accel_buffering {
+            headers.insert("x-accel-buffering", x_accel_buffering.parse().unwrap());
+        }
         Ok((headers, Json(response)).into_response())
     }
 }
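To make the aggregation above concrete, a small sketch (placeholder URL, not part of the diff) of reading the combined result for a multi-prompt request; the exact token counts and generated text depend on the model and prompts:

```python
import requests

BASE_URL = "http://localhost:8080"  # placeholder TGI instance

resp = requests.post(
    f"{BASE_URL}/v1/completions",
    json={"model": "tgi", "prompt": ["Say", "this", "is", "a"], "max_tokens": 10},
).json()

# One choice per prompt, each carrying its own index and finish_reason.
for choice in sorted(resp["choices"], key=lambda c: c["index"]):
    print(choice["index"], repr(choice["text"]), choice["finish_reason"])

# usage is summed over all prompts in the batch.
print(resp["usage"])  # {'prompt_tokens': ..., 'completion_tokens': ..., 'total_tokens': ...}
```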
@@ -762,6 +986,7 @@ async fn chat_completions(
     Extension(info): Extension<Info>,
     Json(req): Json<ChatRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let span = tracing::Span::current();
     metrics::increment_counter!("tgi_request_count");
 
     let ChatRequest {
@@ -899,17 +1124,14 @@ async fn chat_completions(
             compute_type,
             Json(generate_request),
             on_message_callback,
+            span,
         )
         .await;
         let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
         Ok((headers, sse).into_response())
     } else {
-        let (headers, Json(generation)) = generate(
-            Extension(infer),
-            Extension(compute_type),
-            Json(generate_request),
-        )
-        .await?;
+        let (headers, Json(generation)) =
+            generate_internal(Extension(infer), compute_type, Json(generate_request), span).await?;
 
         let current_time = std::time::SystemTime::now()
             .duration_since(std::time::UNIX_EPOCH)
@@ -1006,6 +1228,7 @@ async fn vertex_compatibility(
     Extension(compute_type): Extension<ComputeType>,
     Json(req): Json<VertexRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let span = tracing::Span::current();
     metrics::increment_counter!("tgi_request_count");
 
     // check that theres at least one instance
@@ -1037,10 +1260,11 @@ async fn vertex_compatibility(
         };
 
         async {
-            generate(
+            generate_internal(
                 Extension(infer.clone()),
-                Extension(compute_type.clone()),
+                compute_type.clone(),
                 Json(generate_request),
+                span.clone(),
             )
             .await
             .map(|(_, Json(generation))| generation.generated_text)
@@ -1152,6 +1376,7 @@ pub async fn run(
     tokenizer_config: HubTokenizerConfig,
     messages_api_enabled: bool,
     grammar_support: bool,
+    max_client_batch_size: usize,
 ) -> Result<(), axum::BoxError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -1326,6 +1551,7 @@ pub async fn run(
         max_waiting_tokens,
         max_batch_size,
         validation_workers,
+        max_client_batch_size,
         version: env!("CARGO_PKG_VERSION"),
         sha: option_env!("VERGEN_GIT_SHA"),
         docker_label: option_env!("DOCKER_LABEL"),