Stream options. (#2533)

* Stream options.

* Fetch stuff from nix integration test for easier testing.

* Adding the assert.

* Only send the usage when asked for.

* Update the docs.

* Impure test because we need network.

* develop.

* Optional usage.

* Fixes.

* Workflow
Nicolas Patry, 2024-09-19 20:50:37 +02:00, committed by GitHub
parent ce85efa968
commit f512021e77
13 changed files with 475 additions and 38 deletions
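To illustrate the feature this change adds, here is a minimal client-side sketch of opting into usage reporting through the new stream_options field. It mirrors the request used in the integration test further down; the server address (http://localhost:8080) and the use of the requests library are assumptions for the example, not part of the change.

import json
import requests

# Assumed local TGI server; /v1/chat/completions is the route extended by this commit.
url = "http://localhost:8080/v1/chat/completions"
payload = {
    "model": "tgi",
    "messages": [{"role": "user", "content": "What is Deep Learning?"}],
    "max_tokens": 10,
    "stream": True,
    "stream_options": {"include_usage": True},
}

with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data:"):
            continue
        data = line.decode().removeprefix("data:").strip()
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        # With include_usage, every chunk carries a "usage" field; it stays null
        # until the final chunk, which reports totals for the whole request.
        if chunk["usage"] is not None:
            print(chunk["usage"])

Without stream_options, usage is never populated on any chunk, which the second half of the new integration test verifies.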


@@ -38,4 +38,4 @@ jobs:
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
- name: Rust tests.
run: nix build .#checks.$(nix eval --impure --raw --expr 'builtins.currentSystem').rust -L
run: nix develop .#test --command cargo test


@@ -42,6 +42,7 @@ jobs:
sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
- name: Install
run: |
sudo apt update
sudo apt install python3.11-dev -y
make install-cpu
- name: Run server tests


@@ -23,9 +23,11 @@ docs/openapi.json:
- '#/components/schemas/GenerateResponse/properties/details/nullable'
- '#/components/schemas/StreamResponse/properties/details/nullable'
- '#/components/schemas/ChatRequest/properties/response_format/nullable'
- '#/components/schemas/ChatRequest/properties/stream_options/nullable'
- '#/components/schemas/ChatRequest/properties/tool_choice/nullable'
- '#/components/schemas/ToolChoice/nullable'
- '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable'
- '#/components/schemas/ChatCompletionChunk/properties/usage/nullable'
- '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable'
no-invalid-media-type-examples:
- '#/paths/~1/post/responses/422/content/application~1json/example'


@@ -5,7 +5,8 @@ members = [
"backends/grpc-metadata",
"backends/trtllm",
"backends/client",
"launcher"
"launcher",
"router"
]
default-members = [
"benchmark",
@@ -13,7 +14,8 @@ default-members = [
"backends/grpc-metadata",
# "backends/trtllm",
"backends/client",
"launcher"
"launcher",
"router"
]
resolver = "2"


@@ -168,7 +168,7 @@ class ChatCompletionComplete(BaseModel):
# Log probabilities for the chat completion
logprobs: Optional[Any]
# Reason for completion
finish_reason: str
finish_reason: Optional[str]
# Usage details of the chat completion
usage: Optional[Any] = None
@@ -191,6 +191,7 @@ class ChatCompletionChunk(BaseModel):
model: str
system_fingerprint: str
choices: List[Choice]
usage: Optional[Any] = None
class Parameters(BaseModel):


@@ -742,6 +742,14 @@
},
"system_fingerprint": {
"type": "string"
},
"usage": {
"allOf": [
{
"$ref": "#/components/schemas/Usage"
}
],
"nullable": true
}
}
},
@@ -937,6 +945,14 @@
"stream": {
"type": "boolean"
},
"stream_options": {
"allOf": [
{
"$ref": "#/components/schemas/StreamOptions"
}
],
"nullable": true
},
"temperature": {
"type": "number",
"format": "float",
@@ -1912,6 +1928,19 @@
}
}
},
"StreamOptions": {
"type": "object",
"required": [
"include_usage"
],
"properties": {
"include_usage": {
"type": "boolean",
"description": "If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.",
"example": "true"
}
}
},
"StreamResponse": {
"type": "object",
"required": [


@@ -479,11 +479,11 @@
"systems": "systems_6"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"lastModified": 1726560853,
"narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
"type": "github"
},
"original": {
@@ -853,11 +853,11 @@
]
},
"locked": {
"lastModified": 1726280639,
"narHash": "sha256-YfLRPlFZWrT2oRLNAoqf7G3+NnUTDdlIJk6tmBU7kXM=",
"lastModified": 1726626348,
"narHash": "sha256-sYV7e1B1yLcxo8/h+/hTwzZYmaju2oObNiy5iRI0C30=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "e9f8641c92f26fd1e076e705edb12147c384171d",
"rev": "6fd52ad8bd88f39efb2c999cc971921c2fb9f3a2",
"type": "github"
},
"original": {


@@ -67,31 +67,38 @@
'';
};
server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
client = pkgs.python3.pkgs.callPackage ./nix/client.nix { };
in
{
checks = {
rust = with pkgs; rustPlatform.buildRustPackage {
name = "rust-checks";
src = ./.;
cargoLock = {
lockFile = ./Cargo.lock;
rust =
with pkgs;
rustPlatform.buildRustPackage {
name = "rust-checks";
src = ./.;
cargoLock = {
lockFile = ./Cargo.lock;
};
buildInputs = [ openssl.dev ];
nativeBuildInputs = [
clippy
pkg-config
protobuf
python3
rustfmt
];
buildPhase = ''
cargo check
'';
checkPhase = ''
cargo fmt -- --check
cargo test -j $NIX_BUILD_CORES
cargo clippy
'';
installPhase = "touch $out";
};
buildInputs = [ openssl.dev ];
nativeBuildInputs = [ clippy pkg-config protobuf python3 rustfmt ];
buildPhase = ''
cargo check
'';
checkPhase = ''
cargo fmt -- --check
cargo test -j $NIX_BUILD_CORES
cargo clippy
'';
installPhase = "touch $out";
} ;
};
formatter = pkgs.nixfmt-rfc-style;
devShells = with pkgs; rec {
default = pure;
@@ -106,10 +113,11 @@
test = mkShell {
buildInputs =
[
# benchmark
# launcher
# router
benchmark
launcher
router
server
client
openssl.dev
pkg-config
cargo


@@ -0,0 +1,206 @@
[
{
"choices": [
{
"delta": {
"content": "**",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "Deep",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " Learning",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": ":",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " An",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " Overview",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "**\n",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "================================",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "=====",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "\n\n",
"role": "assistant",
"tool_calls": null
},
"finish_reason": "length",
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 40,
"total_tokens": 50
}
}
]


@@ -3,9 +3,7 @@ import requests
import json
from aiohttp import ClientSession
from text_generation.types import (
Completion,
)
from text_generation.types import Completion, ChatCompletionChunk
@pytest.fixture(scope="module")
@@ -50,6 +48,114 @@ def test_flash_llama_completion_single_prompt(
assert response == response_snapshot
@pytest.mark.release
async def test_flash_llama_completion_stream_usage(
flash_llama_completion, response_snapshot
):
url = f"{flash_llama_completion.base_url}/v1/chat/completions"
request = {
"model": "tgi",
"messages": [
{
"role": "user",
"content": "What is Deep Learning?",
}
],
"max_tokens": 10,
"temperature": 0.0,
"stream_options": {"include_usage": True},
"stream": True,
}
string = ""
chunks = []
had_usage = False
async with ClientSession(headers=flash_llama_completion.headers) as session:
async with session.post(url, json=request) as response:
# iterate over the stream
async for chunk in response.content.iter_any():
# remove "data:"
chunk = chunk.decode().split("\n\n")
# remove "data:" if present
chunk = [c.replace("data:", "") for c in chunk]
# remove empty strings
chunk = [c for c in chunk if c]
# remove completion marking chunk
chunk = [c for c in chunk if c != " [DONE]"]
# parse json
chunk = [json.loads(c) for c in chunk]
for c in chunk:
chunks.append(ChatCompletionChunk(**c))
assert "choices" in c
if len(c["choices"]) == 1:
index = c["choices"][0]["index"]
assert index == 0
string += c["choices"][0]["delta"]["content"]
has_usage = c["usage"] is not None
assert not had_usage
if has_usage:
had_usage = True
else:
raise RuntimeError("Expected different payload")
assert had_usage
assert (
string
== "**Deep Learning: An Overview**\n=====================================\n\n"
)
assert chunks == response_snapshot
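# Second request without stream_options: usage must never be present on any chunk.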
request = {
"model": "tgi",
"messages": [
{
"role": "user",
"content": "What is Deep Learning?",
}
],
"max_tokens": 10,
"temperature": 0.0,
"stream": True,
}
string = ""
chunks = []
had_usage = False
async with ClientSession(headers=flash_llama_completion.headers) as session:
async with session.post(url, json=request) as response:
# iterate over the stream
async for chunk in response.content.iter_any():
# remove "data:"
chunk = chunk.decode().split("\n\n")
# remove "data:" if present
chunk = [c.replace("data:", "") for c in chunk]
# remove empty strings
chunk = [c for c in chunk if c]
# remove completion marking chunk
chunk = [c for c in chunk if c != " [DONE]"]
# parse json
chunk = [json.loads(c) for c in chunk]
for c in chunk:
chunks.append(ChatCompletionChunk(**c))
assert "choices" in c
if len(c["choices"]) == 1:
index = c["choices"][0]["index"]
assert index == 0
string += c["choices"][0]["delta"]["content"]
has_usage = c["usage"] is not None
assert not had_usage
if has_usage:
had_usage = True
else:
raise RuntimeError("Expected different payload")
assert not had_usage
assert (
string
== "**Deep Learning: An Overview**\n=====================================\n\n"
)
@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
response = requests.post(

nix/client.nix Normal file

@@ -0,0 +1,21 @@
{
buildPythonPackage,
poetry-core,
huggingface-hub,
pydantic,
}:
buildPythonPackage {
name = "text-generation";
src = ../clients/python;
pyproject = true;
build-system = [ poetry-core ];
dependencies = [
huggingface-hub
pydantic
];
}


@@ -684,6 +684,7 @@ pub(crate) struct ChatCompletionChunk {
pub model: String,
pub system_fingerprint: String,
pub choices: Vec<ChatCompletionChoice>,
pub usage: Option<Usage>,
}
#[derive(Clone, Serialize, ToSchema)]
@@ -732,6 +733,7 @@ impl ChatCompletionChunk {
created: u64,
logprobs: Option<ChatCompletionLogprobs>,
finish_reason: Option<String>,
usage: Option<Usage>,
) -> Self {
let delta = match (delta, tool_calls) {
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
@@ -766,6 +768,7 @@ impl ChatCompletionChunk {
logprobs,
finish_reason,
}],
usage,
}
}
}
@@ -880,6 +883,18 @@ pub(crate) struct ChatRequest {
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub guideline: Option<String>,
/// Options for streaming response. Only set this when you set stream: true.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub stream_options: Option<StreamOptions>,
}
#[derive(Clone, Deserialize, ToSchema, Serialize)]
struct StreamOptions {
/// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
#[schema(example = "true")]
include_usage: bool,
}
pub fn default_tool_prompt() -> String {
@@ -1472,6 +1487,27 @@ mod tests {
let textmsg: TextMessage = message.into();
assert_eq!(textmsg.content, "Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)");
}
#[test]
fn test_chat_stream_options() {
let json = json!({
"model": "",
"stream_options": {"include_usage": true},
"messages": [{
"role": "user",
"content": "Hello"
}]
});
let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
assert!(matches!(
request.stream_options,
Some(StreamOptions {
include_usage: true
})
));
}
#[test]
fn openai_output() {
let message = OutputMessage::ChatMessage(TextMessage {


@@ -13,8 +13,8 @@ use crate::{
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamResponse, TextMessage, Token,
TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
TextMessage, Token, TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
};
use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@@ -1175,6 +1175,7 @@ async fn chat_completions(
seed,
stop,
stream,
stream_options,
tools,
tool_choice,
tool_prompt,
@@ -1265,6 +1266,28 @@
(content, None)
};
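// Usage is only attached to the final chunk, and only when the client asked for it via stream_options.include_usage.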
let (usage, finish_reason) = match stream_token.details {
Some(details) => {
let usage = if stream_options
.as_ref()
.map(|s| s.include_usage)
.unwrap_or(false)
{
let completion_tokens = details.generated_tokens;
let prompt_tokens = details.input_length;
let total_tokens = prompt_tokens + completion_tokens;
Some(Usage {
completion_tokens,
prompt_tokens,
total_tokens,
})
} else {
None
};
(usage, Some(details.finish_reason.format(true)))
}
None => (None, None),
};
event
.json_data(CompletionType::ChatCompletionChunk(
ChatCompletionChunk::new(
@@ -1274,7 +1297,8 @@
tool_calls,
current_time,
logprobs,
stream_token.details.map(|d| d.finish_reason.format(true)),
finish_reason,
usage,
),
))
.unwrap_or_else(|e| {
@@ -1664,6 +1688,7 @@ StreamDetails,
ErrorResponse,
GrammarType,
Usage,
StreamOptions,
DeltaToolCall,
ToolType,
Tool,