Merge branch 'fix_rocm_fa' into rocm_6.2_fixes
This commit is contained in:
commit
afe3fed1a4
|
@ -38,4 +38,4 @@ jobs:
|
||||||
env:
|
env:
|
||||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
- name: Rust tests.
|
- name: Rust tests.
|
||||||
run: nix build .#checks.$(nix eval --impure --raw --expr 'builtins.currentSystem').rust -L
|
run: nix develop .#test --command cargo test
|
||||||
|
|
|
@ -42,6 +42,7 @@ jobs:
|
||||||
sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
|
sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
|
||||||
- name: Install
|
- name: Install
|
||||||
run: |
|
run: |
|
||||||
|
sudo apt update
|
||||||
sudo apt install python3.11-dev -y
|
sudo apt install python3.11-dev -y
|
||||||
make install-cpu
|
make install-cpu
|
||||||
- name: Run server tests
|
- name: Run server tests
|
||||||
|
|
|
@ -23,9 +23,11 @@ docs/openapi.json:
|
||||||
- '#/components/schemas/GenerateResponse/properties/details/nullable'
|
- '#/components/schemas/GenerateResponse/properties/details/nullable'
|
||||||
- '#/components/schemas/StreamResponse/properties/details/nullable'
|
- '#/components/schemas/StreamResponse/properties/details/nullable'
|
||||||
- '#/components/schemas/ChatRequest/properties/response_format/nullable'
|
- '#/components/schemas/ChatRequest/properties/response_format/nullable'
|
||||||
|
- '#/components/schemas/ChatRequest/properties/stream_options/nullable'
|
||||||
- '#/components/schemas/ChatRequest/properties/tool_choice/nullable'
|
- '#/components/schemas/ChatRequest/properties/tool_choice/nullable'
|
||||||
- '#/components/schemas/ToolChoice/nullable'
|
- '#/components/schemas/ToolChoice/nullable'
|
||||||
- '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable'
|
- '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable'
|
||||||
|
- '#/components/schemas/ChatCompletionChunk/properties/usage/nullable'
|
||||||
- '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable'
|
- '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable'
|
||||||
no-invalid-media-type-examples:
|
no-invalid-media-type-examples:
|
||||||
- '#/paths/~1/post/responses/422/content/application~1json/example'
|
- '#/paths/~1/post/responses/422/content/application~1json/example'
|
||||||
|
|
|
@ -5,7 +5,8 @@ members = [
|
||||||
"backends/grpc-metadata",
|
"backends/grpc-metadata",
|
||||||
"backends/trtllm",
|
"backends/trtllm",
|
||||||
"backends/client",
|
"backends/client",
|
||||||
"launcher"
|
"launcher",
|
||||||
|
"router"
|
||||||
]
|
]
|
||||||
default-members = [
|
default-members = [
|
||||||
"benchmark",
|
"benchmark",
|
||||||
|
@ -13,7 +14,8 @@ default-members = [
|
||||||
"backends/grpc-metadata",
|
"backends/grpc-metadata",
|
||||||
# "backends/trtllm",
|
# "backends/trtllm",
|
||||||
"backends/client",
|
"backends/client",
|
||||||
"launcher"
|
"launcher",
|
||||||
|
"router"
|
||||||
]
|
]
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
|
|
|
@ -168,7 +168,7 @@ class ChatCompletionComplete(BaseModel):
|
||||||
# Log probabilities for the chat completion
|
# Log probabilities for the chat completion
|
||||||
logprobs: Optional[Any]
|
logprobs: Optional[Any]
|
||||||
# Reason for completion
|
# Reason for completion
|
||||||
finish_reason: str
|
finish_reason: Optional[str]
|
||||||
# Usage details of the chat completion
|
# Usage details of the chat completion
|
||||||
usage: Optional[Any] = None
|
usage: Optional[Any] = None
|
||||||
|
|
||||||
|
@ -191,6 +191,7 @@ class ChatCompletionChunk(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
system_fingerprint: str
|
system_fingerprint: str
|
||||||
choices: List[Choice]
|
choices: List[Choice]
|
||||||
|
usage: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
class Parameters(BaseModel):
|
class Parameters(BaseModel):
|
||||||
|
|
|
@ -742,6 +742,14 @@
|
||||||
},
|
},
|
||||||
"system_fingerprint": {
|
"system_fingerprint": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
|
},
|
||||||
|
"usage": {
|
||||||
|
"allOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/Usage"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nullable": true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -937,6 +945,14 @@
|
||||||
"stream": {
|
"stream": {
|
||||||
"type": "boolean"
|
"type": "boolean"
|
||||||
},
|
},
|
||||||
|
"stream_options": {
|
||||||
|
"allOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/StreamOptions"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nullable": true
|
||||||
|
},
|
||||||
"temperature": {
|
"temperature": {
|
||||||
"type": "number",
|
"type": "number",
|
||||||
"format": "float",
|
"format": "float",
|
||||||
|
@ -1912,6 +1928,19 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"StreamOptions": {
|
||||||
|
"type": "object",
|
||||||
|
"required": [
|
||||||
|
"include_usage"
|
||||||
|
],
|
||||||
|
"properties": {
|
||||||
|
"include_usage": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.",
|
||||||
|
"example": "true"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"StreamResponse": {
|
"StreamResponse": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": [
|
"required": [
|
||||||
|
|
|
@ -55,7 +55,9 @@ Options:
|
||||||
## QUANTIZE
|
## QUANTIZE
|
||||||
```shell
|
```shell
|
||||||
--quantize <QUANTIZE>
|
--quantize <QUANTIZE>
|
||||||
Whether you want the model to be quantized
|
Quantization method to use for the model. It is not necessary to specify this option for pre-quantized models, since the quantization method is read from the model configuration.
|
||||||
|
|
||||||
|
Marlin kernels will be used automatically for GPTQ/AWQ models.
|
||||||
|
|
||||||
[env: QUANTIZE=]
|
[env: QUANTIZE=]
|
||||||
|
|
||||||
|
|
18
flake.lock
18
flake.lock
|
@ -479,11 +479,11 @@
|
||||||
"systems": "systems_6"
|
"systems": "systems_6"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1710146030,
|
"lastModified": 1726560853,
|
||||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
"narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
|
||||||
"owner": "numtide",
|
"owner": "numtide",
|
||||||
"repo": "flake-utils",
|
"repo": "flake-utils",
|
||||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
"rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
@ -853,11 +853,11 @@
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1726280639,
|
"lastModified": 1726626348,
|
||||||
"narHash": "sha256-YfLRPlFZWrT2oRLNAoqf7G3+NnUTDdlIJk6tmBU7kXM=",
|
"narHash": "sha256-sYV7e1B1yLcxo8/h+/hTwzZYmaju2oObNiy5iRI0C30=",
|
||||||
"owner": "oxalica",
|
"owner": "oxalica",
|
||||||
"repo": "rust-overlay",
|
"repo": "rust-overlay",
|
||||||
"rev": "e9f8641c92f26fd1e076e705edb12147c384171d",
|
"rev": "6fd52ad8bd88f39efb2c999cc971921c2fb9f3a2",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
@ -978,11 +978,11 @@
|
||||||
"nixpkgs": "nixpkgs_6"
|
"nixpkgs": "nixpkgs_6"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1726229792,
|
"lastModified": 1726743157,
|
||||||
"narHash": "sha256-9xsLmjc9nr7a4PTddKv2DOi82ompTtJNyjO6R67y5tE=",
|
"narHash": "sha256-7OczwJsA47o+aUftMwkoh8R31DlNSl2FgRjqE8zAggk=",
|
||||||
"owner": "danieldk",
|
"owner": "danieldk",
|
||||||
"repo": "tgi-nix",
|
"repo": "tgi-nix",
|
||||||
"rev": "1a902f4818e94c3f8d95f6000db17bc3fadd0ce7",
|
"rev": "bcc9fd01cf81bc42cebb999a736a377adfa8942f",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
|
23
flake.nix
23
flake.nix
|
@ -67,17 +67,26 @@
|
||||||
'';
|
'';
|
||||||
};
|
};
|
||||||
server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
|
server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
|
||||||
|
client = pkgs.python3.pkgs.callPackage ./nix/client.nix { };
|
||||||
in
|
in
|
||||||
{
|
{
|
||||||
checks = {
|
checks = {
|
||||||
rust = with pkgs; rustPlatform.buildRustPackage {
|
rust =
|
||||||
|
with pkgs;
|
||||||
|
rustPlatform.buildRustPackage {
|
||||||
name = "rust-checks";
|
name = "rust-checks";
|
||||||
src = ./.;
|
src = ./.;
|
||||||
cargoLock = {
|
cargoLock = {
|
||||||
lockFile = ./Cargo.lock;
|
lockFile = ./Cargo.lock;
|
||||||
};
|
};
|
||||||
buildInputs = [ openssl.dev ];
|
buildInputs = [ openssl.dev ];
|
||||||
nativeBuildInputs = [ clippy pkg-config protobuf python3 rustfmt ];
|
nativeBuildInputs = [
|
||||||
|
clippy
|
||||||
|
pkg-config
|
||||||
|
protobuf
|
||||||
|
python3
|
||||||
|
rustfmt
|
||||||
|
];
|
||||||
buildPhase = ''
|
buildPhase = ''
|
||||||
cargo check
|
cargo check
|
||||||
'';
|
'';
|
||||||
|
@ -89,9 +98,7 @@
|
||||||
installPhase = "touch $out";
|
installPhase = "touch $out";
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
formatter = pkgs.nixfmt-rfc-style;
|
formatter = pkgs.nixfmt-rfc-style;
|
||||||
|
|
||||||
devShells = with pkgs; rec {
|
devShells = with pkgs; rec {
|
||||||
default = pure;
|
default = pure;
|
||||||
|
|
||||||
|
@ -106,10 +113,11 @@
|
||||||
test = mkShell {
|
test = mkShell {
|
||||||
buildInputs =
|
buildInputs =
|
||||||
[
|
[
|
||||||
# benchmark
|
benchmark
|
||||||
# launcher
|
launcher
|
||||||
# router
|
router
|
||||||
server
|
server
|
||||||
|
client
|
||||||
openssl.dev
|
openssl.dev
|
||||||
pkg-config
|
pkg-config
|
||||||
cargo
|
cargo
|
||||||
|
@ -149,6 +157,7 @@
|
||||||
pyright
|
pyright
|
||||||
pytest
|
pytest
|
||||||
pytest-asyncio
|
pytest-asyncio
|
||||||
|
redocly
|
||||||
ruff
|
ruff
|
||||||
syrupy
|
syrupy
|
||||||
]);
|
]);
|
||||||
|
|
|
@ -0,0 +1,206 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "**",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656043,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "Deep",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656043,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " Learning",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656043,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": ":",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656043,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " An",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656043,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": " Overview",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656043,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "**\n",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656044,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "================================",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656044,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "=====",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": null,
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656044,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"delta": {
|
||||||
|
"content": "\n\n",
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1726656044,
|
||||||
|
"id": "",
|
||||||
|
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"prompt_tokens": 40,
|
||||||
|
"total_tokens": 50
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
|
@ -3,9 +3,7 @@ import requests
|
||||||
import json
|
import json
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
from text_generation.types import (
|
from text_generation.types import Completion, ChatCompletionChunk
|
||||||
Completion,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@ -50,6 +48,114 @@ def test_flash_llama_completion_single_prompt(
|
||||||
assert response == response_snapshot
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
|
async def test_flash_llama_completion_stream_usage(
|
||||||
|
flash_llama_completion, response_snapshot
|
||||||
|
):
|
||||||
|
url = f"{flash_llama_completion.base_url}/v1/chat/completions"
|
||||||
|
request = {
|
||||||
|
"model": "tgi",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What is Deep Learning?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"max_tokens": 10,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"stream_options": {"include_usage": True},
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
string = ""
|
||||||
|
chunks = []
|
||||||
|
had_usage = False
|
||||||
|
async with ClientSession(headers=flash_llama_completion.headers) as session:
|
||||||
|
async with session.post(url, json=request) as response:
|
||||||
|
# iterate over the stream
|
||||||
|
async for chunk in response.content.iter_any():
|
||||||
|
# remove "data:"
|
||||||
|
chunk = chunk.decode().split("\n\n")
|
||||||
|
# remove "data:" if present
|
||||||
|
chunk = [c.replace("data:", "") for c in chunk]
|
||||||
|
# remove empty strings
|
||||||
|
chunk = [c for c in chunk if c]
|
||||||
|
# remove completion marking chunk
|
||||||
|
chunk = [c for c in chunk if c != " [DONE]"]
|
||||||
|
# parse json
|
||||||
|
chunk = [json.loads(c) for c in chunk]
|
||||||
|
|
||||||
|
for c in chunk:
|
||||||
|
chunks.append(ChatCompletionChunk(**c))
|
||||||
|
assert "choices" in c
|
||||||
|
if len(c["choices"]) == 1:
|
||||||
|
index = c["choices"][0]["index"]
|
||||||
|
assert index == 0
|
||||||
|
string += c["choices"][0]["delta"]["content"]
|
||||||
|
|
||||||
|
has_usage = c["usage"] is not None
|
||||||
|
assert not had_usage
|
||||||
|
if has_usage:
|
||||||
|
had_usage = True
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Expected different payload")
|
||||||
|
assert had_usage
|
||||||
|
assert (
|
||||||
|
string
|
||||||
|
== "**Deep Learning: An Overview**\n=====================================\n\n"
|
||||||
|
)
|
||||||
|
assert chunks == response_snapshot
|
||||||
|
|
||||||
|
request = {
|
||||||
|
"model": "tgi",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What is Deep Learning?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"max_tokens": 10,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"stream": True,
|
||||||
|
}
|
||||||
|
string = ""
|
||||||
|
chunks = []
|
||||||
|
had_usage = False
|
||||||
|
async with ClientSession(headers=flash_llama_completion.headers) as session:
|
||||||
|
async with session.post(url, json=request) as response:
|
||||||
|
# iterate over the stream
|
||||||
|
async for chunk in response.content.iter_any():
|
||||||
|
# remove "data:"
|
||||||
|
chunk = chunk.decode().split("\n\n")
|
||||||
|
# remove "data:" if present
|
||||||
|
chunk = [c.replace("data:", "") for c in chunk]
|
||||||
|
# remove empty strings
|
||||||
|
chunk = [c for c in chunk if c]
|
||||||
|
# remove completion marking chunk
|
||||||
|
chunk = [c for c in chunk if c != " [DONE]"]
|
||||||
|
# parse json
|
||||||
|
chunk = [json.loads(c) for c in chunk]
|
||||||
|
|
||||||
|
for c in chunk:
|
||||||
|
chunks.append(ChatCompletionChunk(**c))
|
||||||
|
assert "choices" in c
|
||||||
|
if len(c["choices"]) == 1:
|
||||||
|
index = c["choices"][0]["index"]
|
||||||
|
assert index == 0
|
||||||
|
string += c["choices"][0]["delta"]["content"]
|
||||||
|
|
||||||
|
has_usage = c["usage"] is not None
|
||||||
|
assert not had_usage
|
||||||
|
if has_usage:
|
||||||
|
had_usage = True
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Expected different payload")
|
||||||
|
assert not had_usage
|
||||||
|
assert (
|
||||||
|
string
|
||||||
|
== "**Deep Learning: An Overview**\n=====================================\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.release
|
@pytest.mark.release
|
||||||
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
|
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
|
|
|
@ -367,7 +367,11 @@ struct Args {
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
num_shard: Option<usize>,
|
num_shard: Option<usize>,
|
||||||
|
|
||||||
/// Whether you want the model to be quantized.
|
/// Quantization method to use for the model. It is not necessary to specify this option
|
||||||
|
/// for pre-quantized models, since the quantization method is read from the model
|
||||||
|
/// configuration.
|
||||||
|
///
|
||||||
|
/// Marlin kernels will be used automatically for GPTQ/AWQ models.
|
||||||
#[clap(long, env, value_enum)]
|
#[clap(long, env, value_enum)]
|
||||||
quantize: Option<Quantization>,
|
quantize: Option<Quantization>,
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
{
|
||||||
|
buildPythonPackage,
|
||||||
|
poetry-core,
|
||||||
|
huggingface-hub,
|
||||||
|
pydantic,
|
||||||
|
}:
|
||||||
|
|
||||||
|
buildPythonPackage {
|
||||||
|
name = "text-generation";
|
||||||
|
|
||||||
|
src = ../clients/python;
|
||||||
|
|
||||||
|
pyproject = true;
|
||||||
|
|
||||||
|
build-system = [ poetry-core ];
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
huggingface-hub
|
||||||
|
pydantic
|
||||||
|
];
|
||||||
|
}
|
|
@ -684,6 +684,7 @@ pub(crate) struct ChatCompletionChunk {
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub system_fingerprint: String,
|
pub system_fingerprint: String,
|
||||||
pub choices: Vec<ChatCompletionChoice>,
|
pub choices: Vec<ChatCompletionChoice>,
|
||||||
|
pub usage: Option<Usage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Serialize, ToSchema)]
|
#[derive(Clone, Serialize, ToSchema)]
|
||||||
|
@ -732,6 +733,7 @@ impl ChatCompletionChunk {
|
||||||
created: u64,
|
created: u64,
|
||||||
logprobs: Option<ChatCompletionLogprobs>,
|
logprobs: Option<ChatCompletionLogprobs>,
|
||||||
finish_reason: Option<String>,
|
finish_reason: Option<String>,
|
||||||
|
usage: Option<Usage>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let delta = match (delta, tool_calls) {
|
let delta = match (delta, tool_calls) {
|
||||||
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
|
(Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
|
||||||
|
@ -766,6 +768,7 @@ impl ChatCompletionChunk {
|
||||||
logprobs,
|
logprobs,
|
||||||
finish_reason,
|
finish_reason,
|
||||||
}],
|
}],
|
||||||
|
usage,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -880,6 +883,18 @@ pub(crate) struct ChatRequest {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
#[schema(nullable = true, default = "null", example = "null")]
|
#[schema(nullable = true, default = "null", example = "null")]
|
||||||
pub guideline: Option<String>,
|
pub guideline: Option<String>,
|
||||||
|
|
||||||
|
/// Options for streaming response. Only set this when you set stream: true.
|
||||||
|
#[serde(default)]
|
||||||
|
#[schema(nullable = true, example = "null")]
|
||||||
|
pub stream_options: Option<StreamOptions>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Deserialize, ToSchema, Serialize)]
|
||||||
|
struct StreamOptions {
|
||||||
|
/// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
|
||||||
|
#[schema(example = "true")]
|
||||||
|
include_usage: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn default_tool_prompt() -> String {
|
pub fn default_tool_prompt() -> String {
|
||||||
|
@ -1472,6 +1487,27 @@ mod tests {
|
||||||
let textmsg: TextMessage = message.into();
|
let textmsg: TextMessage = message.into();
|
||||||
assert_eq!(textmsg.content, "Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)");
|
assert_eq!(textmsg.content, "Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_stream_options() {
|
||||||
|
let json = json!({
|
||||||
|
"model": "",
|
||||||
|
"stream_options": {"include_usage": true},
|
||||||
|
"messages": [{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello"
|
||||||
|
}]
|
||||||
|
});
|
||||||
|
let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
request.stream_options,
|
||||||
|
Some(StreamOptions {
|
||||||
|
include_usage: true
|
||||||
|
})
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn openai_output() {
|
fn openai_output() {
|
||||||
let message = OutputMessage::ChatMessage(TextMessage {
|
let message = OutputMessage::ChatMessage(TextMessage {
|
||||||
|
|
|
@ -13,8 +13,8 @@ use crate::{
|
||||||
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
|
||||||
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
|
||||||
HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
|
HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
|
||||||
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamResponse, TextMessage, Token,
|
OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
|
||||||
TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
|
TextMessage, Token, TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
||||||
|
@ -1175,6 +1175,7 @@ async fn chat_completions(
|
||||||
seed,
|
seed,
|
||||||
stop,
|
stop,
|
||||||
stream,
|
stream,
|
||||||
|
stream_options,
|
||||||
tools,
|
tools,
|
||||||
tool_choice,
|
tool_choice,
|
||||||
tool_prompt,
|
tool_prompt,
|
||||||
|
@ -1265,6 +1266,28 @@ async fn chat_completions(
|
||||||
(content, None)
|
(content, None)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let (usage, finish_reason) = match stream_token.details {
|
||||||
|
Some(details) => {
|
||||||
|
let usage = if stream_options
|
||||||
|
.as_ref()
|
||||||
|
.map(|s| s.include_usage)
|
||||||
|
.unwrap_or(false)
|
||||||
|
{
|
||||||
|
let completion_tokens = details.generated_tokens;
|
||||||
|
let prompt_tokens = details.input_length;
|
||||||
|
let total_tokens = prompt_tokens + completion_tokens;
|
||||||
|
Some(Usage {
|
||||||
|
completion_tokens,
|
||||||
|
prompt_tokens,
|
||||||
|
total_tokens,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
(usage, Some(details.finish_reason.format(true)))
|
||||||
|
}
|
||||||
|
None => (None, None),
|
||||||
|
};
|
||||||
event
|
event
|
||||||
.json_data(CompletionType::ChatCompletionChunk(
|
.json_data(CompletionType::ChatCompletionChunk(
|
||||||
ChatCompletionChunk::new(
|
ChatCompletionChunk::new(
|
||||||
|
@ -1274,7 +1297,8 @@ async fn chat_completions(
|
||||||
tool_calls,
|
tool_calls,
|
||||||
current_time,
|
current_time,
|
||||||
logprobs,
|
logprobs,
|
||||||
stream_token.details.map(|d| d.finish_reason.format(true)),
|
finish_reason,
|
||||||
|
usage,
|
||||||
),
|
),
|
||||||
))
|
))
|
||||||
.unwrap_or_else(|e| {
|
.unwrap_or_else(|e| {
|
||||||
|
@ -1664,6 +1688,7 @@ StreamDetails,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
GrammarType,
|
GrammarType,
|
||||||
Usage,
|
Usage,
|
||||||
|
StreamOptions,
|
||||||
DeltaToolCall,
|
DeltaToolCall,
|
||||||
ToolType,
|
ToolType,
|
||||||
Tool,
|
Tool,
|
||||||
|
|
|
@ -1244,12 +1244,12 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "moe-kernels"
|
name = "moe-kernels"
|
||||||
version = "0.2.2"
|
version = "0.3.1"
|
||||||
description = "MoE kernels"
|
description = "MoE kernels"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:d268d818932ddcbca9bc71021dc63b008aae832827a7c0484cf206bd59cfc9ab"},
|
{file = "moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:b679984a53807127f25af053ec0a2c07dec97ec196f76363a8bfdc3fbb3d1a9a"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -1259,16 +1259,16 @@ triton = "*"
|
||||||
|
|
||||||
[package.source]
|
[package.source]
|
||||||
type = "url"
|
type = "url"
|
||||||
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
|
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "moe-kernels"
|
name = "moe-kernels"
|
||||||
version = "0.2.2"
|
version = "0.3.1"
|
||||||
description = "MoE kernels"
|
description = "MoE kernels"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:614bbc3f41b707b0c40372f0bb00e218ad0842d306f90bef28ce8e98e7fcb7cb"},
|
{file = "moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:29684f81495f6e032085295c86d160022f03d5d9a9981446f311ca94fbbbc2cd"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -1278,16 +1278,16 @@ triton = "*"
|
||||||
|
|
||||||
[package.source]
|
[package.source]
|
||||||
type = "url"
|
type = "url"
|
||||||
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
|
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "moe-kernels"
|
name = "moe-kernels"
|
||||||
version = "0.2.2"
|
version = "0.3.1"
|
||||||
description = "MoE kernels"
|
description = "MoE kernels"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:c2f48ed541353be03157d4015270dff797f7b7b8a664babdcbdf7414867d5abd"},
|
{file = "moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:9dfdbef48b5b7e97912aaa7420b1b694876a3281f5edfe7d4ca9a69e1f48bff2"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -1297,16 +1297,16 @@ triton = "*"
|
||||||
|
|
||||||
[package.source]
|
[package.source]
|
||||||
type = "url"
|
type = "url"
|
||||||
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
|
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "moe-kernels"
|
name = "moe-kernels"
|
||||||
version = "0.2.2"
|
version = "0.3.1"
|
||||||
description = "MoE kernels"
|
description = "MoE kernels"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:d5f0339b73426c422872f7ff060433df6cd8e881451baf85ee7454e0e905f9d8"},
|
{file = "moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:f7d0fc8f191c905a668f3d2eb889999ee988048d08bfd7062d64bca3876588ae"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -1316,7 +1316,7 @@ triton = "*"
|
||||||
|
|
||||||
[package.source]
|
[package.source]
|
||||||
type = "url"
|
type = "url"
|
||||||
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
|
url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mpmath"
|
name = "mpmath"
|
||||||
|
|
|
@ -47,10 +47,10 @@ marlin-kernels = [
|
||||||
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
||||||
]
|
]
|
||||||
moe-kernels = [
|
moe-kernels = [
|
||||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
|
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
|
||||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
|
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
|
||||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
|
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
|
||||||
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
|
||||||
]
|
]
|
||||||
rich = "^13.7.1"
|
rich = "^13.7.1"
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from text_generation_server.utils.import_utils import SYSTEM
|
||||||
from text_generation_server.models.globals import ATTENTION
|
from text_generation_server.models.globals import ATTENTION
|
||||||
import torch
|
import torch
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
@ -65,5 +66,7 @@ else:
|
||||||
max_k: int
|
max_k: int
|
||||||
|
|
||||||
def clamp(self, max):
|
def clamp(self, max):
|
||||||
|
if SYSTEM == "rocm":
|
||||||
|
return self
|
||||||
raise NotImplementedError("Not implemented seqlen for paged")
|
raise NotImplementedError("Not implemented seqlen for paged")
|
||||||
return Seqlen(torch.clamp(self.input_lengths, max=max))
|
return Seqlen(torch.clamp(self.input_lengths, max=max))
|
||||||
|
|
|
@ -20,8 +20,8 @@ from text_generation_server.utils.import_utils import SYSTEM
|
||||||
|
|
||||||
if SYSTEM == "rocm":
|
if SYSTEM == "rocm":
|
||||||
from text_generation_server.layers import grouped_topk
|
from text_generation_server.layers import grouped_topk
|
||||||
else:
|
elif SYSTEM != "ipex":
|
||||||
from vllm.model_executor.layers.fused_moe import grouped_topk
|
from moe_kernels.fused_moe import grouped_topk
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
Loading…
Reference in New Issue