Merge branch 'fix_rocm_fa' into rocm_6.2_fixes

Mohit Sharma 2024-09-24 10:53:50 +00:00
commit afe3fed1a4
19 changed files with 508 additions and 61 deletions

View File

@@ -38,4 +38,4 @@ jobs:
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
      - name: Rust tests.
-       run: nix build .#checks.$(nix eval --impure --raw --expr 'builtins.currentSystem').rust -L
+       run: nix develop .#test --command cargo test

View File

@@ -42,6 +42,7 @@ jobs:
          sudo rm -rf /usr/share/dotnet # will release about 20GB if you don't need .NET
      - name: Install
        run: |
+         sudo apt update
          sudo apt install python3.11-dev -y
          make install-cpu
      - name: Run server tests

View File

@@ -23,9 +23,11 @@ docs/openapi.json:
    - '#/components/schemas/GenerateResponse/properties/details/nullable'
    - '#/components/schemas/StreamResponse/properties/details/nullable'
    - '#/components/schemas/ChatRequest/properties/response_format/nullable'
+   - '#/components/schemas/ChatRequest/properties/stream_options/nullable'
    - '#/components/schemas/ChatRequest/properties/tool_choice/nullable'
    - '#/components/schemas/ToolChoice/nullable'
    - '#/components/schemas/ChatCompletionComplete/properties/logprobs/nullable'
+   - '#/components/schemas/ChatCompletionChunk/properties/usage/nullable'
    - '#/components/schemas/ChatCompletionChoice/properties/logprobs/nullable'
  no-invalid-media-type-examples:
    - '#/paths/~1/post/responses/422/content/application~1json/example'

View File

@@ -5,7 +5,8 @@ members = [
    "backends/grpc-metadata",
    "backends/trtllm",
    "backends/client",
-   "launcher"
+   "launcher",
+   "router"
]
default-members = [
    "benchmark",
@@ -13,7 +14,8 @@ default-members = [
    "backends/grpc-metadata",
    # "backends/trtllm",
    "backends/client",
-   "launcher"
+   "launcher",
+   "router"
]
resolver = "2"

View File

@@ -168,7 +168,7 @@ class ChatCompletionComplete(BaseModel):
    # Log probabilities for the chat completion
    logprobs: Optional[Any]
    # Reason for completion
-   finish_reason: str
+   finish_reason: Optional[str]
    # Usage details of the chat completion
    usage: Optional[Any] = None

@@ -191,6 +191,7 @@ class ChatCompletionChunk(BaseModel):
    model: str
    system_fingerprint: str
    choices: List[Choice]
+   usage: Optional[Any] = None


class Parameters(BaseModel):

View File

@@ -742,6 +742,14 @@
        },
        "system_fingerprint": {
          "type": "string"
+       },
+       "usage": {
+         "allOf": [
+           {
+             "$ref": "#/components/schemas/Usage"
+           }
+         ],
+         "nullable": true
        }
      }
    },
@@ -937,6 +945,14 @@
        "stream": {
          "type": "boolean"
        },
+       "stream_options": {
+         "allOf": [
+           {
+             "$ref": "#/components/schemas/StreamOptions"
+           }
+         ],
+         "nullable": true
+       },
        "temperature": {
          "type": "number",
          "format": "float",
@@ -1912,6 +1928,19 @@
          }
        }
      },
+     "StreamOptions": {
+       "type": "object",
+       "required": [
+         "include_usage"
+       ],
+       "properties": {
+         "include_usage": {
+           "type": "boolean",
+           "description": "If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.",
+           "example": "true"
+         }
+       }
+     },
      "StreamResponse": {
        "type": "object",
        "required": [

View File

@ -55,7 +55,9 @@ Options:
## QUANTIZE ## QUANTIZE
```shell ```shell
--quantize <QUANTIZE> --quantize <QUANTIZE>
Whether you want the model to be quantized Quantization method to use for the model. It is not necessary to specify this option for pre-quantized models, since the quantization method is read from the model configuration.
Marlin kernels will be used automatically for GPTQ/AWQ models.
[env: QUANTIZE=] [env: QUANTIZE=]

View File

@@ -479,11 +479,11 @@
        "systems": "systems_6"
      },
      "locked": {
-       "lastModified": 1710146030,
-       "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
+       "lastModified": 1726560853,
+       "narHash": "sha256-X6rJYSESBVr3hBoH0WbKE5KvhPU5bloyZ2L4K60/fPQ=",
        "owner": "numtide",
        "repo": "flake-utils",
-       "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
+       "rev": "c1dfcf08411b08f6b8615f7d8971a2bfa81d5e8a",
        "type": "github"
      },
      "original": {
@@ -853,11 +853,11 @@
        ]
      },
      "locked": {
-       "lastModified": 1726280639,
-       "narHash": "sha256-YfLRPlFZWrT2oRLNAoqf7G3+NnUTDdlIJk6tmBU7kXM=",
+       "lastModified": 1726626348,
+       "narHash": "sha256-sYV7e1B1yLcxo8/h+/hTwzZYmaju2oObNiy5iRI0C30=",
        "owner": "oxalica",
        "repo": "rust-overlay",
-       "rev": "e9f8641c92f26fd1e076e705edb12147c384171d",
+       "rev": "6fd52ad8bd88f39efb2c999cc971921c2fb9f3a2",
        "type": "github"
      },
      "original": {
@@ -978,11 +978,11 @@
        "nixpkgs": "nixpkgs_6"
      },
      "locked": {
-       "lastModified": 1726229792,
-       "narHash": "sha256-9xsLmjc9nr7a4PTddKv2DOi82ompTtJNyjO6R67y5tE=",
+       "lastModified": 1726743157,
+       "narHash": "sha256-7OczwJsA47o+aUftMwkoh8R31DlNSl2FgRjqE8zAggk=",
        "owner": "danieldk",
        "repo": "tgi-nix",
-       "rev": "1a902f4818e94c3f8d95f6000db17bc3fadd0ce7",
+       "rev": "bcc9fd01cf81bc42cebb999a736a377adfa8942f",
        "type": "github"
      },
      "original": {

View File

@ -67,17 +67,26 @@
''; '';
}; };
server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; }; server = pkgs.python3.pkgs.callPackage ./nix/server.nix { inherit nix-filter; };
client = pkgs.python3.pkgs.callPackage ./nix/client.nix { };
in in
{ {
checks = { checks = {
rust = with pkgs; rustPlatform.buildRustPackage { rust =
with pkgs;
rustPlatform.buildRustPackage {
name = "rust-checks"; name = "rust-checks";
src = ./.; src = ./.;
cargoLock = { cargoLock = {
lockFile = ./Cargo.lock; lockFile = ./Cargo.lock;
}; };
buildInputs = [ openssl.dev ]; buildInputs = [ openssl.dev ];
nativeBuildInputs = [ clippy pkg-config protobuf python3 rustfmt ]; nativeBuildInputs = [
clippy
pkg-config
protobuf
python3
rustfmt
];
buildPhase = '' buildPhase = ''
cargo check cargo check
''; '';
@ -89,9 +98,7 @@
installPhase = "touch $out"; installPhase = "touch $out";
}; };
}; };
formatter = pkgs.nixfmt-rfc-style; formatter = pkgs.nixfmt-rfc-style;
devShells = with pkgs; rec { devShells = with pkgs; rec {
default = pure; default = pure;
@ -106,10 +113,11 @@
test = mkShell { test = mkShell {
buildInputs = buildInputs =
[ [
# benchmark benchmark
# launcher launcher
# router router
server server
client
openssl.dev openssl.dev
pkg-config pkg-config
cargo cargo
@ -149,6 +157,7 @@
pyright pyright
pytest pytest
pytest-asyncio pytest-asyncio
redocly
ruff ruff
syrupy syrupy
]); ]);

View File

@@ -0,0 +1,206 @@
[
{
"choices": [
{
"delta": {
"content": "**",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "Deep",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " Learning",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": ":",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " An",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": " Overview",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656043,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "**\n",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "================================",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "=====",
"role": "assistant",
"tool_calls": null
},
"finish_reason": null,
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": null
},
{
"choices": [
{
"delta": {
"content": "\n\n",
"role": "assistant",
"tool_calls": null
},
"finish_reason": "length",
"index": 0,
"logprobs": null
}
],
"created": 1726656044,
"id": "",
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
"system_fingerprint": "2.2.1-dev0-native",
"usage": {
"completion_tokens": 10,
"prompt_tokens": 40,
"total_tokens": 50
}
}
]

View File

@@ -3,9 +3,7 @@ import requests
import json
from aiohttp import ClientSession

-from text_generation.types import (
-    Completion,
-)
+from text_generation.types import Completion, ChatCompletionChunk


@pytest.fixture(scope="module")
@@ -50,6 +48,114 @@ def test_flash_llama_completion_single_prompt(
    assert response == response_snapshot


@pytest.mark.release
async def test_flash_llama_completion_stream_usage(
    flash_llama_completion, response_snapshot
):
    url = f"{flash_llama_completion.base_url}/v1/chat/completions"
    request = {
        "model": "tgi",
        "messages": [
            {
                "role": "user",
                "content": "What is Deep Learning?",
            }
        ],
        "max_tokens": 10,
        "temperature": 0.0,
        "stream_options": {"include_usage": True},
        "stream": True,
    }
    string = ""
    chunks = []
    had_usage = False
    async with ClientSession(headers=flash_llama_completion.headers) as session:
        async with session.post(url, json=request) as response:
            # iterate over the stream
            async for chunk in response.content.iter_any():
                # remove "data:"
                chunk = chunk.decode().split("\n\n")
                # remove "data:" if present
                chunk = [c.replace("data:", "") for c in chunk]
                # remove empty strings
                chunk = [c for c in chunk if c]
                # remove completion marking chunk
                chunk = [c for c in chunk if c != " [DONE]"]
                # parse json
                chunk = [json.loads(c) for c in chunk]

                for c in chunk:
                    chunks.append(ChatCompletionChunk(**c))
                    assert "choices" in c
                    if len(c["choices"]) == 1:
                        index = c["choices"][0]["index"]
                        assert index == 0
                        string += c["choices"][0]["delta"]["content"]

                        has_usage = c["usage"] is not None
                        assert not had_usage
                        if has_usage:
                            had_usage = True
                    else:
                        raise RuntimeError("Expected different payload")

    assert had_usage
    assert (
        string
        == "**Deep Learning: An Overview**\n=====================================\n\n"
    )
    assert chunks == response_snapshot

    request = {
        "model": "tgi",
        "messages": [
            {
                "role": "user",
                "content": "What is Deep Learning?",
            }
        ],
        "max_tokens": 10,
        "temperature": 0.0,
        "stream": True,
    }
    string = ""
    chunks = []
    had_usage = False
    async with ClientSession(headers=flash_llama_completion.headers) as session:
        async with session.post(url, json=request) as response:
            # iterate over the stream
            async for chunk in response.content.iter_any():
                # remove "data:"
                chunk = chunk.decode().split("\n\n")
                # remove "data:" if present
                chunk = [c.replace("data:", "") for c in chunk]
                # remove empty strings
                chunk = [c for c in chunk if c]
                # remove completion marking chunk
                chunk = [c for c in chunk if c != " [DONE]"]
                # parse json
                chunk = [json.loads(c) for c in chunk]

                for c in chunk:
                    chunks.append(ChatCompletionChunk(**c))
                    assert "choices" in c
                    if len(c["choices"]) == 1:
                        index = c["choices"][0]["index"]
                        assert index == 0
                        string += c["choices"][0]["delta"]["content"]

                        has_usage = c["usage"] is not None
                        assert not had_usage
                        if has_usage:
                            had_usage = True
                    else:
                        raise RuntimeError("Expected different payload")

    assert not had_usage
    assert (
        string
        == "**Deep Learning: An Overview**\n=====================================\n\n"
    )


@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
    response = requests.post(

View File

@@ -367,7 +367,11 @@ struct Args {
    #[clap(long, env)]
    num_shard: Option<usize>,

-   /// Whether you want the model to be quantized.
+   /// Quantization method to use for the model. It is not necessary to specify this option
+   /// for pre-quantized models, since the quantization method is read from the model
+   /// configuration.
+   ///
+   /// Marlin kernels will be used automatically for GPTQ/AWQ models.
    #[clap(long, env, value_enum)]
    quantize: Option<Quantization>,

nix/client.nix (new file, 21 lines)
View File

@@ -0,0 +1,21 @@
{
  buildPythonPackage,
  poetry-core,
  huggingface-hub,
  pydantic,
}:

buildPythonPackage {
  name = "text-generation";

  src = ../clients/python;

  pyproject = true;

  build-system = [ poetry-core ];

  dependencies = [
    huggingface-hub
    pydantic
  ];
}

View File

@@ -684,6 +684,7 @@ pub(crate) struct ChatCompletionChunk {
    pub model: String,
    pub system_fingerprint: String,
    pub choices: Vec<ChatCompletionChoice>,
+   pub usage: Option<Usage>,
}

#[derive(Clone, Serialize, ToSchema)]
@@ -732,6 +733,7 @@ impl ChatCompletionChunk {
        created: u64,
        logprobs: Option<ChatCompletionLogprobs>,
        finish_reason: Option<String>,
+       usage: Option<Usage>,
    ) -> Self {
        let delta = match (delta, tool_calls) {
            (Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
@@ -766,6 +768,7 @@ impl ChatCompletionChunk {
                logprobs,
                finish_reason,
            }],
+           usage,
        }
    }
}
@@ -880,6 +883,18 @@ pub(crate) struct ChatRequest {
    #[serde(default)]
    #[schema(nullable = true, default = "null", example = "null")]
    pub guideline: Option<String>,
+   /// Options for streaming response. Only set this when you set stream: true.
+   #[serde(default)]
+   #[schema(nullable = true, example = "null")]
+   pub stream_options: Option<StreamOptions>,
+}
+
+#[derive(Clone, Deserialize, ToSchema, Serialize)]
+struct StreamOptions {
+   /// If set, an additional chunk will be streamed before the data: [DONE] message. The usage field on this chunk shows the token usage statistics for the entire request, and the choices field will always be an empty array. All other chunks will also include a usage field, but with a null value.
+   #[schema(example = "true")]
+   include_usage: bool,
}

pub fn default_tool_prompt() -> String {
@@ -1472,6 +1487,27 @@ mod tests {
        let textmsg: TextMessage = message.into();
        assert_eq!(textmsg.content, "Whats in this image?![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png)");
    }
+
+   #[test]
+   fn test_chat_stream_options() {
+       let json = json!({
+           "model": "",
+           "stream_options": {"include_usage": true},
+           "messages": [{
+               "role": "user",
+               "content": "Hello"
+           }]
+       });
+
+       let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
+       assert!(matches!(
+           request.stream_options,
+           Some(StreamOptions {
+               include_usage: true
+           })
+       ));
+   }
+
    #[test]
    fn openai_output() {
        let message = OutputMessage::ChatMessage(TextMessage {

View File

@@ -13,8 +13,8 @@ use crate::{
    usage_stats, BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName,
    GenerateParameters, GenerateRequest, GenerateResponse, GrammarType, HubModelInfo,
    HubProcessorConfig, HubTokenizerConfig, Info, Message, MessageChunk, MessageContent,
-   OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamResponse, TextMessage, Token,
-   TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
+   OutputMessage, PrefillToken, SimpleToken, StreamDetails, StreamOptions, StreamResponse,
+   TextMessage, Token, TokenizeResponse, ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
};
use crate::{
    ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@@ -1175,6 +1175,7 @@ async fn chat_completions(
        seed,
        stop,
        stream,
+       stream_options,
        tools,
        tool_choice,
        tool_prompt,
@@ -1265,6 +1266,28 @@ async fn chat_completions(
                            (content, None)
                        };

+                       let (usage, finish_reason) = match stream_token.details {
+                           Some(details) => {
+                               let usage = if stream_options
+                                   .as_ref()
+                                   .map(|s| s.include_usage)
+                                   .unwrap_or(false)
+                               {
+                                   let completion_tokens = details.generated_tokens;
+                                   let prompt_tokens = details.input_length;
+                                   let total_tokens = prompt_tokens + completion_tokens;
+                                   Some(Usage {
+                                       completion_tokens,
+                                       prompt_tokens,
+                                       total_tokens,
+                                   })
+                               } else {
+                                   None
+                               };
+                               (usage, Some(details.finish_reason.format(true)))
+                           }
+                           None => (None, None),
+                       };
                        event
                            .json_data(CompletionType::ChatCompletionChunk(
                                ChatCompletionChunk::new(
@@ -1274,7 +1297,8 @@ async fn chat_completions(
                                    tool_calls,
                                    current_time,
                                    logprobs,
-                                   stream_token.details.map(|d| d.finish_reason.format(true)),
+                                   finish_reason,
+                                   usage,
                                ),
                            ))
                            .unwrap_or_else(|e| {
@@ -1664,6 +1688,7 @@ StreamDetails,
        ErrorResponse,
        GrammarType,
        Usage,
+       StreamOptions,
        DeltaToolCall,
        ToolType,
        Tool,

server/poetry.lock (generated, 24 changed lines)
View File

@@ -1244,12 +1244,12 @@ files = [

[[package]]
name = "moe-kernels"
-version = "0.2.2"
+version = "0.3.1"
description = "MoE kernels"
optional = true
python-versions = ">=3.7"
files = [
-    {file = "moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:d268d818932ddcbca9bc71021dc63b008aae832827a7c0484cf206bd59cfc9ab"},
+    {file = "moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:b679984a53807127f25af053ec0a2c07dec97ec196f76363a8bfdc3fbb3d1a9a"},
]

[package.dependencies]
@@ -1259,16 +1259,16 @@ triton = "*"

[package.source]
type = "url"
-url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl"

[[package]]
name = "moe-kernels"
-version = "0.2.2"
+version = "0.3.1"
description = "MoE kernels"
optional = true
python-versions = ">=3.7"
files = [
-    {file = "moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:614bbc3f41b707b0c40372f0bb00e218ad0842d306f90bef28ce8e98e7fcb7cb"},
+    {file = "moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:29684f81495f6e032085295c86d160022f03d5d9a9981446f311ca94fbbbc2cd"},
]

[package.dependencies]
@@ -1278,16 +1278,16 @@ triton = "*"

[package.source]
type = "url"
-url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl"

[[package]]
name = "moe-kernels"
-version = "0.2.2"
+version = "0.3.1"
description = "MoE kernels"
optional = true
python-versions = ">=3.7"
files = [
-    {file = "moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:c2f48ed541353be03157d4015270dff797f7b7b8a664babdcbdf7414867d5abd"},
+    {file = "moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:9dfdbef48b5b7e97912aaa7420b1b694876a3281f5edfe7d4ca9a69e1f48bff2"},
]

[package.dependencies]
@@ -1297,16 +1297,16 @@ triton = "*"

[package.source]
type = "url"
-url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl"

[[package]]
name = "moe-kernels"
-version = "0.2.2"
+version = "0.3.1"
description = "MoE kernels"
optional = true
python-versions = ">=3.7"
files = [
-    {file = "moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:d5f0339b73426c422872f7ff060433df6cd8e881451baf85ee7454e0e905f9d8"},
+    {file = "moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:f7d0fc8f191c905a668f3d2eb889999ee988048d08bfd7062d64bca3876588ae"},
]

[package.dependencies]
@@ -1316,7 +1316,7 @@ triton = "*"

[package.source]
type = "url"
-url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl"
+url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl"

[[package]]
name = "mpmath"

View File

@@ -47,10 +47,10 @@ marlin-kernels = [
    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
moe-kernels = [
-   { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
-   { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
-   { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
-   { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.2.2/moe_kernels-0.2.2+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
+   { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
+   { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
+   { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
+   { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
rich = "^13.7.1"

View File

@@ -1,4 +1,5 @@
from dataclasses import dataclass
+from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional
@@ -65,5 +66,7 @@ else:
        max_k: int

        def clamp(self, max):
+           if SYSTEM == "rocm":
+               return self
            raise NotImplementedError("Not implemented seqlen for paged")
            return Seqlen(torch.clamp(self.input_lengths, max=max))

View File

@@ -20,8 +20,8 @@ from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "rocm":
    from text_generation_server.layers import grouped_topk
-else:
-    from vllm.model_executor.layers.fused_moe import grouped_topk
+elif SYSTEM != "ipex":
+    from moe_kernels.fused_moe import grouped_topk

import torch
import torch.distributed