Merge branch 'main' into ci_amd3

This commit is contained in:
fxmarty 2024-07-16 15:15:17 +02:00
commit 291453fe88
65 changed files with 3782 additions and 932 deletions

View File

@ -30,6 +30,10 @@ jobs:
id: install-router
run: cargo install --path router/
- uses: actions/setup-node@v4
with:
node-version: 22
- name: Set up Python
uses: actions/setup-python@v2
with:
@ -37,4 +41,5 @@ jobs:
- name: Check that documentation is up-to-date
run: |
npm install -g swagger-cli
python update_doc.py --check

View File

@ -30,7 +30,7 @@ jobs:
group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
# TODO see with @Glegendre to get CPU runner here instead
runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci]
runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci]
permissions:
contents: write
packages: write
@ -142,8 +142,8 @@ jobs:
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
cache-from: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
cache-to: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
- name: Final
id: final
run: |

36
Cargo.lock generated
View File

@ -1935,17 +1935,6 @@ version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "metrics"
version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fde3af1a009ed76a778cb84fdef9e7dbbdf5775ae3e4cc1f434a6a307f6f76c5"
dependencies = [
"ahash",
"metrics-macros",
"portable-atomic",
]
[[package]]
name = "metrics"
version = "0.23.0"
@ -1969,7 +1958,7 @@ dependencies = [
"hyper-util",
"indexmap 2.2.6",
"ipnet",
"metrics 0.23.0",
"metrics",
"metrics-util",
"quanta",
"thiserror",
@ -1977,17 +1966,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "metrics-macros"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.68",
]
[[package]]
name = "metrics-util"
version = "0.17.0"
@ -1997,7 +1975,7 @@ dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
"hashbrown 0.14.5",
"metrics 0.23.0",
"metrics",
"num_cpus",
"quanta",
"sketches-ddsketch",
@ -3762,7 +3740,7 @@ dependencies = [
[[package]]
name = "text-generation-benchmark"
version = "2.1.1-dev0"
version = "2.1.2-dev0"
dependencies = [
"average",
"clap",
@ -3783,7 +3761,7 @@ dependencies = [
[[package]]
name = "text-generation-client"
version = "2.1.1-dev0"
version = "2.1.2-dev0"
dependencies = [
"async-trait",
"base64 0.22.1",
@ -3801,7 +3779,7 @@ dependencies = [
[[package]]
name = "text-generation-launcher"
version = "2.1.1-dev0"
version = "2.1.2-dev0"
dependencies = [
"clap",
"ctrlc",
@ -3820,7 +3798,7 @@ dependencies = [
[[package]]
name = "text-generation-router"
version = "2.1.1-dev0"
version = "2.1.2-dev0"
dependencies = [
"async-stream",
"axum 0.7.5",
@ -3834,7 +3812,7 @@ dependencies = [
"init-tracing-opentelemetry",
"itertools 0.10.5",
"jsonschema",
"metrics 0.21.1",
"metrics",
"metrics-exporter-prometheus",
"minijinja",
"minijinja-contrib",

View File

@ -40,7 +40,9 @@ RUN cargo build --profile release-opt
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install
# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
ARG PYTORCH_VERSION=2.3.0
ARG PYTHON_VERSION=3.10
# Keep in sync with `server/pyproject.toml
ARG CUDA_VERSION=12.1
@ -241,7 +243,10 @@ COPY server/Makefile server/Makefile
RUN cd server && \
make gen-server && \
pip install -r requirements_cuda.txt && \
pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir
pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \
pip install nvidia-nccl-cu12==2.22.3
ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2
# Deps before the binaries
# The binaries change on every build given we burn the SHA into them

View File

@ -20,19 +20,20 @@ to power Hugging Chat, the Inference API and Inference Endpoint.
## Table of contents
- [Get Started](#get-started)
- [API Documentation](#api-documentation)
- [Using a private or gated model](#using-a-private-or-gated-model)
- [A note on Shared Memory](#a-note-on-shared-memory-shm)
- [Distributed Tracing](#distributed-tracing)
- [Local Install](#local-install)
- [CUDA Kernels](#cuda-kernels)
- [Optimized architectures](#optimized-architectures)
- [Run Mistral](#run-a-model)
- [Run](#run)
- [Quantization](#quantization)
- [Develop](#develop)
- [Testing](#testing)
- [Get Started](#get-started)
- [Docker](#docker)
- [API documentation](#api-documentation)
- [Using a private or gated model](#using-a-private-or-gated-model)
- [A note on Shared Memory (shm)](#a-note-on-shared-memory-shm)
- [Distributed Tracing](#distributed-tracing)
- [Architecture](#architecture)
- [Local install](#local-install)
- [Optimized architectures](#optimized-architectures)
- [Run locally](#run-locally)
- [Run](#run)
- [Quantization](#quantization)
- [Develop](#develop)
- [Testing](#testing)
Text Generation Inference (TGI) is a toolkit for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation for the most popular open-source LLMs, including Llama, Falcon, StarCoder, BLOOM, GPT-NeoX, and [more](https://huggingface.co/docs/text-generation-inference/supported_models). TGI implements many features, such as:

View File

@ -61,7 +61,7 @@ class ChoiceDeltaToolCall(BaseModel):
class ChoiceDelta(BaseModel):
role: str
content: Optional[str] = None
tool_calls: Optional[ChoiceDeltaToolCall]
tool_calls: Optional[ChoiceDeltaToolCall] = None
class Choice(BaseModel):

View File

@ -492,12 +492,12 @@
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Completion"
"$ref": "#/components/schemas/CompletionFinal"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/CompletionCompleteChunk"
"$ref": "#/components/schemas/Chunk"
}
}
}
@ -809,7 +809,6 @@
"ChatRequest": {
"type": "object",
"required": [
"model",
"messages"
],
"properties": {
@ -854,7 +853,8 @@
"model": {
"type": "string",
"description": "[UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
"example": "mistralai/Mistral-7B-Instruct-v0.2"
"example": "mistralai/Mistral-7B-Instruct-v0.2",
"nullable": true
},
"n": {
"type": "integer",
@ -1116,7 +1116,6 @@
"CompletionRequest": {
"type": "object",
"required": [
"model",
"prompt"
],
"properties": {
@ -1138,7 +1137,8 @@
"model": {
"type": "string",
"description": "UNUSED\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
"example": "mistralai/Mistral-7B-Instruct-v0.2"
"example": "mistralai/Mistral-7B-Instruct-v0.2",
"nullable": true
},
"prompt": {
"$ref": "#/components/schemas/Prompt"
@ -1324,6 +1324,17 @@
}
}
},
"FunctionName": {
"type": "object",
"required": [
"name"
],
"properties": {
"name": {
"type": "string"
}
}
},
"GenerateParameters": {
"type": "object",
"properties": {
@ -1708,6 +1719,72 @@
}
}
},
"MessageChunk": {
"oneOf": [
{
"type": "object",
"required": [
"text",
"type"
],
"properties": {
"text": {
"type": "string"
},
"type": {
"type": "string",
"enum": [
"text"
]
}
}
},
{
"type": "object",
"required": [
"image_url",
"type"
],
"properties": {
"image_url": {
"$ref": "#/components/schemas/Url"
},
"type": {
"type": "string",
"enum": [
"image_url"
]
}
}
}
],
"discriminator": {
"propertyName": "type"
}
},
"MessageContent": {
"oneOf": [
{
"type": "string"
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/MessageChunk"
}
}
]
},
"OutputMessage": {
"oneOf": [
{
"$ref": "#/components/schemas/TextMessage"
},
{
"$ref": "#/components/schemas/ToolCallMessage"
}
]
},
"PrefillToken": {
"type": "object",
"required": [
@ -1834,6 +1911,23 @@
}
}
},
"TextMessage": {
"type": "object",
"required": [
"role",
"content"
],
"properties": {
"content": {
"type": "string",
"example": "My name is David and I"
},
"role": {
"type": "string",
"example": "user"
}
}
},
"Token": {
"type": "object",
"required": [
@ -1906,6 +2000,41 @@
}
}
},
"ToolCallDelta": {
"type": "object",
"required": [
"role",
"tool_calls"
],
"properties": {
"role": {
"type": "string",
"example": "assistant"
},
"tool_calls": {
"$ref": "#/components/schemas/DeltaToolCall"
}
}
},
"ToolCallMessage": {
"type": "object",
"required": [
"role",
"tool_calls"
],
"properties": {
"role": {
"type": "string",
"example": "assistant"
},
"tool_calls": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolCall"
}
}
}
},
"ToolType": {
"oneOf": [
{
@ -1929,6 +2058,17 @@
}
]
},
"Url": {
"type": "object",
"required": [
"url"
],
"properties": {
"url": {
"type": "string"
}
}
},
"Usage": {
"type": "object",
"required": [

View File

@ -11,6 +11,8 @@
title: Using TGI with Intel Gaudi
- local: installation_inferentia
title: Using TGI with AWS Inferentia
- local: installation_intel
title: Using TGI with Intel GPUs
- local: installation
title: Installation from source
- local: supported_models

View File

@ -103,6 +103,7 @@ Several variants of the model server exist that are actively supported by Huggin
- By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference).
- A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ.
- A [version optimized for Intel GPUs](https://huggingface.co/docs/text-generation-inference/installation_intel) is hosted in the main TGI repository. Some model features differ.
- The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi).
- A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference).
- A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference).

View File

@ -0,0 +1,19 @@
# Using TGI with Intel GPUs
TGI optimized models are supported on Intel Data Center GPU [Max1100](https://www.intel.com/content/www/us/en/products/sku/232876/intel-data-center-gpu-max-1100/specifications.html), [Max1550](https://www.intel.com/content/www/us/en/products/sku/232873/intel-data-center-gpu-max-1550/specifications.html), the recommended usage is through Docker.
On a server powered by Intel GPUs, TGI can be launched with the following command:
```bash
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:latest-intel \
--model-id $model --cuda-graphs 0
```
The launched TGI server can then be queried from clients, make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide.

View File

@ -17,7 +17,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
### Supported hardware
TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.
TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.
## Consuming TGI

View File

@ -347,6 +347,8 @@ def launcher(event_loop):
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None,
cuda_graphs: Optional[List[int]] = None,
):
port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000)
@ -393,6 +395,14 @@ def launcher(event_loop):
if max_total_tokens:
args.append("--max-total-tokens")
args.append(str(max_total_tokens))
if lora_adapters:
args.append("--lora-adapters")
args.append(",".join(lora_adapters))
if cuda_graphs:
args.append("--cuda-graphs")
args.append(",".join(map(str, cuda_graphs)))
print(" ".join(args), file=sys.stderr)
env["LOG_LEVEL"] = "info,text_generation_router=debug"
@ -433,6 +443,8 @@ def launcher(event_loop):
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None,
cuda_graphs: Optional[List[int]] = None,
):
port = random.randint(8000, 10_000)
@ -462,6 +474,12 @@ def launcher(event_loop):
if max_total_tokens:
args.append("--max-total-tokens")
args.append(str(max_total_tokens))
if lora_adapters:
args.append("--lora-adapters")
args.append(",".join(lora_adapters))
if cuda_graphs:
args.append("--cuda-graphs")
args.append(",".join(map(str, cuda_graphs)))
client = docker.from_env()

View File

@ -0,0 +1,251 @@
{
"details": {
"finish_reason": "length",
"generated_tokens": 40,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.27416992,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.17016602,
"special": false,
"text": "\n"
},
{
"id": 28737,
"logprob": -2.7109375,
"special": false,
"text": "I"
},
{
"id": 28809,
"logprob": -1.5,
"special": false,
"text": ""
},
{
"id": 28719,
"logprob": -0.34204102,
"special": false,
"text": "m"
},
{
"id": 459,
"logprob": -1.6914062,
"special": false,
"text": " not"
},
{
"id": 1864,
"logprob": -0.69140625,
"special": false,
"text": " sure"
},
{
"id": 513,
"logprob": -1.6171875,
"special": false,
"text": " if"
},
{
"id": 315,
"logprob": -1.3837891,
"special": false,
"text": " I"
},
{
"id": 541,
"logprob": -1.2226562,
"special": false,
"text": " can"
},
{
"id": 1567,
"logprob": -1.8652344,
"special": false,
"text": " come"
},
{
"id": 582,
"logprob": -0.0070228577,
"special": false,
"text": " up"
},
{
"id": 395,
"logprob": -0.0054092407,
"special": false,
"text": " with"
},
{
"id": 28705,
"logprob": -0.62597656,
"special": false,
"text": " "
},
{
"id": 28770,
"logprob": -0.0035572052,
"special": false,
"text": "3"
},
{
"id": 4842,
"logprob": -0.93603516,
"special": false,
"text": " unique"
},
{
"id": 3085,
"logprob": -0.028411865,
"special": false,
"text": " words"
},
{
"id": 369,
"logprob": -1.0400391,
"special": false,
"text": " that"
},
{
"id": 6685,
"logprob": -0.09710693,
"special": false,
"text": " describe"
},
{
"id": 528,
"logprob": -0.066467285,
"special": false,
"text": " me"
},
{
"id": 28725,
"logprob": -1.0722656,
"special": false,
"text": ","
},
{
"id": 562,
"logprob": -0.33422852,
"special": false,
"text": " but"
},
{
"id": 315,
"logprob": -0.5136719,
"special": false,
"text": " I"
},
{
"id": 28809,
"logprob": -0.8989258,
"special": false,
"text": ""
},
{
"id": 584,
"logprob": -0.2076416,
"special": false,
"text": "ll"
},
{
"id": 1464,
"logprob": -0.8808594,
"special": false,
"text": " try"
},
{
"id": 28723,
"logprob": -0.88427734,
"special": false,
"text": "."
},
{
"id": 13,
"logprob": -0.91064453,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.08105469,
"special": false,
"text": "\n"
},
{
"id": 28740,
"logprob": -1.8486328,
"special": false,
"text": "1"
},
{
"id": 28723,
"logprob": -0.111572266,
"special": false,
"text": "."
},
{
"id": 23626,
"logprob": -3.15625,
"special": false,
"text": " Creative"
},
{
"id": 13,
"logprob": -0.9194336,
"special": false,
"text": "\n"
},
{
"id": 28750,
"logprob": -0.24841309,
"special": false,
"text": "2"
},
{
"id": 28723,
"logprob": -9.393692e-05,
"special": false,
"text": "."
},
{
"id": 6785,
"logprob": -3.1386719,
"special": false,
"text": " Fun"
},
{
"id": 1780,
"logprob": -0.53564453,
"special": false,
"text": "ny"
},
{
"id": 13,
"logprob": -0.09033203,
"special": false,
"text": "\n"
},
{
"id": 28770,
"logprob": -0.00466156,
"special": false,
"text": "3"
},
{
"id": 28723,
"logprob": -0.00016450882,
"special": false,
"text": "."
}
]
},
"generated_text": "\n\nIm not sure if I can come up with 3 unique words that describe me, but Ill try.\n\n1. Creative\n2. Funny\n3."
}

View File

@ -0,0 +1,53 @@
{
"details": {
"finish_reason": "eos_token",
"generated_tokens": 7,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 1,
"logprob": -0.49658203,
"special": true,
"text": "<s>"
},
{
"id": 28705,
"logprob": -0.0016384125,
"special": false,
"text": " "
},
{
"id": 1,
"logprob": -1.4931641,
"special": true,
"text": "<s>"
},
{
"id": 28705,
"logprob": -0.00075769424,
"special": false,
"text": " "
},
{
"id": 28740,
"logprob": -0.25024414,
"special": false,
"text": "1"
},
{
"id": 28740,
"logprob": -0.2631836,
"special": false,
"text": "1"
},
{
"id": 2,
"logprob": -0.0003285408,
"special": true,
"text": "</s>"
}
]
},
"generated_text": " 11"
}

View File

@ -0,0 +1,251 @@
{
"details": {
"finish_reason": "length",
"generated_tokens": 40,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -1.0488281,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -1.0800781,
"special": false,
"text": "\n"
},
{
"id": 27332,
"logprob": -2.1152344,
"special": false,
"text": "###"
},
{
"id": 28705,
"logprob": -1.6748047,
"special": false,
"text": " "
},
{
"id": 28740,
"logprob": -0.097229004,
"special": false,
"text": "1"
},
{
"id": 28723,
"logprob": -0.16467285,
"special": false,
"text": "."
},
{
"id": 7615,
"logprob": -2.2246094,
"special": false,
"text": " News"
},
{
"id": 13,
"logprob": -1.0488281,
"special": false,
"text": "\n"
},
{
"id": 27332,
"logprob": -0.69189453,
"special": false,
"text": "###"
},
{
"id": 28705,
"logprob": -0.013343811,
"special": false,
"text": " "
},
{
"id": 28750,
"logprob": -0.011230469,
"special": false,
"text": "2"
},
{
"id": 28723,
"logprob": -0.00096845627,
"special": false,
"text": "."
},
{
"id": 21095,
"logprob": -2.5605469,
"special": false,
"text": " Blog"
},
{
"id": 13,
"logprob": -0.19458008,
"special": false,
"text": "\n"
},
{
"id": 27332,
"logprob": -0.031280518,
"special": false,
"text": "###"
},
{
"id": 28705,
"logprob": -0.0030708313,
"special": false,
"text": " "
},
{
"id": 28770,
"logprob": -0.0029277802,
"special": false,
"text": "3"
},
{
"id": 28723,
"logprob": -0.0012350082,
"special": false,
"text": "."
},
{
"id": 20108,
"logprob": -2.1582031,
"special": false,
"text": " Article"
},
{
"id": 13,
"logprob": -0.05810547,
"special": false,
"text": "\n"
},
{
"id": 27332,
"logprob": -0.35083008,
"special": false,
"text": "###"
},
{
"id": 28705,
"logprob": -0.034332275,
"special": false,
"text": " "
},
{
"id": 28781,
"logprob": -0.009666443,
"special": false,
"text": "4"
},
{
"id": 28723,
"logprob": -0.0013113022,
"special": false,
"text": "."
},
{
"id": 8349,
"logprob": -2.6191406,
"special": false,
"text": " Review"
},
{
"id": 13,
"logprob": -0.04031372,
"special": false,
"text": "\n"
},
{
"id": 27332,
"logprob": -0.45239258,
"special": false,
"text": "###"
},
{
"id": 28705,
"logprob": -0.045410156,
"special": false,
"text": " "
},
{
"id": 28782,
"logprob": -0.0041236877,
"special": false,
"text": "5"
},
{
"id": 28723,
"logprob": -0.0010223389,
"special": false,
"text": "."
},
{
"id": 5299,
"logprob": -2.8066406,
"special": false,
"text": " Other"
},
{
"id": 13,
"logprob": -0.12054443,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.44580078,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -1.4921875,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -1.3574219,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -1.0039062,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.5859375,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.43481445,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.2783203,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.20410156,
"special": false,
"text": "\n"
}
]
},
"generated_text": "\n\n### 1. News\n### 2. Blog\n### 3. Article\n### 4. Review\n### 5. Other\n\n\n\n\n\n\n\n\n"
}

View File

@ -0,0 +1,251 @@
{
"details": {
"finish_reason": "length",
"generated_tokens": 40,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -0.31347656,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.27441406,
"special": false,
"text": "\n"
},
{
"id": 28737,
"logprob": -2.2285156,
"special": false,
"text": "I"
},
{
"id": 28809,
"logprob": -1.4677734,
"special": false,
"text": ""
},
{
"id": 28719,
"logprob": -0.31762695,
"special": false,
"text": "m"
},
{
"id": 264,
"logprob": -1.6865234,
"special": false,
"text": " a"
},
{
"id": 1215,
"logprob": -3.2695312,
"special": false,
"text": " very"
},
{
"id": 20640,
"logprob": -3.1230469,
"special": false,
"text": " passionate"
},
{
"id": 1338,
"logprob": -0.48339844,
"special": false,
"text": " person"
},
{
"id": 28723,
"logprob": -0.9970703,
"special": false,
"text": "."
},
{
"id": 315,
"logprob": -0.5498047,
"special": false,
"text": " I"
},
{
"id": 28809,
"logprob": -1.1923828,
"special": false,
"text": ""
},
{
"id": 28719,
"logprob": -0.080444336,
"special": false,
"text": "m"
},
{
"id": 1215,
"logprob": -1.8271484,
"special": false,
"text": " very"
},
{
"id": 12215,
"logprob": -2.8847656,
"special": false,
"text": " driven"
},
{
"id": 28723,
"logprob": -1.0927734,
"special": false,
"text": "."
},
{
"id": 315,
"logprob": -0.4584961,
"special": false,
"text": " I"
},
{
"id": 28809,
"logprob": -0.5019531,
"special": false,
"text": ""
},
{
"id": 28719,
"logprob": -0.030715942,
"special": false,
"text": "m"
},
{
"id": 1215,
"logprob": -0.96972656,
"special": false,
"text": " very"
},
{
"id": 7798,
"logprob": -2.8847656,
"special": false,
"text": " determined"
},
{
"id": 28723,
"logprob": -0.27319336,
"special": false,
"text": "."
},
{
"id": 13,
"logprob": -0.56396484,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.011016846,
"special": false,
"text": "\n"
},
{
"id": 3195,
"logprob": -0.7163086,
"special": false,
"text": "What"
},
{
"id": 349,
"logprob": -1.1611328,
"special": false,
"text": " is"
},
{
"id": 574,
"logprob": -0.515625,
"special": false,
"text": " your"
},
{
"id": 6656,
"logprob": -1.0253906,
"special": false,
"text": " favorite"
},
{
"id": 1970,
"logprob": -2.1738281,
"special": false,
"text": " thing"
},
{
"id": 684,
"logprob": -0.48364258,
"special": false,
"text": " about"
},
{
"id": 1250,
"logprob": -1.8876953,
"special": false,
"text": " being"
},
{
"id": 264,
"logprob": -0.41967773,
"special": false,
"text": " a"
},
{
"id": 8626,
"logprob": -2.9160156,
"special": false,
"text": " teacher"
},
{
"id": 28804,
"logprob": -0.11920166,
"special": false,
"text": "?"
},
{
"id": 13,
"logprob": -0.023727417,
"special": false,
"text": "\n"
},
{
"id": 13,
"logprob": -0.010848999,
"special": false,
"text": "\n"
},
{
"id": 28737,
"logprob": -1.0566406,
"special": false,
"text": "I"
},
{
"id": 2016,
"logprob": -0.7163086,
"special": false,
"text": " love"
},
{
"id": 272,
"logprob": -1.9169922,
"special": false,
"text": " the"
},
{
"id": 1639,
"logprob": -2.03125,
"special": false,
"text": " fact"
}
]
},
"generated_text": "\n\nIm a very passionate person. Im very driven. Im very determined.\n\nWhat is your favorite thing about being a teacher?\n\nI love the fact"
}

View File

@ -0,0 +1,134 @@
import pytest
import requests
@pytest.fixture(scope="module")
def lora_mistral_handle(launcher):
with launcher(
"mistralai/Mistral-7B-v0.1",
lora_adapters=[
"predibase/dbpedia",
"predibase/customer_support",
],
cuda_graphs=[0],
) as handle:
yield handle
@pytest.fixture(scope="module")
async def lora_mistral(lora_mistral_handle):
await lora_mistral_handle.health(300)
return lora_mistral_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_lora_mistral(lora_mistral, response_snapshot):
response = await lora_mistral.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)
assert response.details.generated_tokens == 10
classification_prompt = """You are given the title and the body of an article below. Please determine the type of the article.\n### Title: Great White Whale\n\n### Body: Great White Whale is the debut album by the Canadian rock band Secret and Whisper. The album was in the works for about a year and was released on February 12 2008. A music video was shot in Pittsburgh for the album's first single XOXOXO. The album reached number 17 on iTunes's top 100 albums in its first week on sale.\n\n### Article Type:"""
@pytest.mark.asyncio
@pytest.mark.private
async def test_lora_mistral_without_adapter(lora_mistral, response_snapshot):
response = requests.post(
f"{lora_mistral.base_url}/generate",
headers=lora_mistral.headers,
json={
"inputs": classification_prompt,
"parameters": {
"max_new_tokens": 40,
"details": True,
},
},
)
assert response.status_code == 200
data = response.json()
assert (
data["generated_text"]
== "\n\n### 1. News\n### 2. Blog\n### 3. Article\n### 4. Review\n### 5. Other\n\n\n\n\n\n\n\n\n"
)
assert data == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_lora_mistral_with_dbpedia_adapter(lora_mistral, response_snapshot):
response = requests.post(
f"{lora_mistral.base_url}/generate",
headers=lora_mistral.headers,
json={
"inputs": classification_prompt,
"parameters": {
"max_new_tokens": 40,
"adapter_id": "predibase/dbpedia",
"details": True,
},
},
)
assert response.status_code == 200
data = response.json()
assert data["generated_text"] == " 11"
assert data == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_lora_mistral_with_customer_support_adapter(
lora_mistral, response_snapshot
):
print(lora_mistral.base_url)
print(lora_mistral.headers)
response = requests.post(
f"{lora_mistral.base_url}/generate",
headers=lora_mistral.headers,
json={
"inputs": "What are 3 unique words that describe you?",
"parameters": {
"max_new_tokens": 40,
"adapter_id": "predibase/customer_support",
"details": True,
},
},
)
assert response.status_code == 200
data = response.json()
assert (
data["generated_text"]
== "\n\nIm not sure if I can come up with 3 unique words that describe me, but Ill try.\n\n1. Creative\n2. Funny\n3."
)
assert data == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_lora_mistral_without_customer_support_adapter(
lora_mistral, response_snapshot
):
response = requests.post(
f"{lora_mistral.base_url}/generate",
headers=lora_mistral.headers,
json={
"inputs": "What are 3 unique words that describe you?",
"parameters": {
"max_new_tokens": 40,
"details": True,
},
},
)
assert response.status_code == 200
data = response.json()
assert (
data["generated_text"]
== "\n\nIm a very passionate person. Im very driven. Im very determined.\n\nWhat is your favorite thing about being a teacher?\n\nI love the fact"
)
assert data == response_snapshot

View File

@ -24,7 +24,7 @@ futures = "0.3.28"
hf-hub = { workspace = true }
itertools = "0.10"
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = "0.21.1"
metrics = "0.23.0"
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }

View File

@ -91,14 +91,14 @@ impl Infer {
.limit_concurrent_requests
.try_acquire_owned()
.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
metrics::counter!("tgi_request_failure", "err" => "overloaded").increment(1);
tracing::error!("{err}");
err
})?;
// Validate request
let valid_request = self.validation.validate(request).await.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
err
})?;
@ -140,7 +140,7 @@ impl Infer {
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply(messages, grammar_with_prompt)
.map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template");
metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
tracing::error!("{e}");
e
})
@ -214,7 +214,7 @@ impl Infer {
})
} else {
let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
tracing::error!("{err}");
Err(err)
}

View File

@ -111,7 +111,7 @@ async fn queue_task(
match cmd {
QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(*entry));
metrics::increment_gauge!("tgi_queue_size", 1.0);
metrics::gauge!("tgi_queue_size").increment(1.0);
}
QueueCommand::NextBatch {
min_size,
@ -124,7 +124,7 @@ async fn queue_task(
let next_batch =
state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
}),
}
}
@ -226,7 +226,7 @@ impl State {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
tracing::debug!("Dropping entry");
continue;
}
@ -336,7 +336,7 @@ impl State {
// Increment batch id
self.next_batch_id += 1;
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
Some((batch_entries, batch, next_batch_span))
}

View File

@ -148,8 +148,8 @@ pub(crate) async fn batching_task(
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
@ -170,9 +170,11 @@ pub(crate) async fn batching_task(
{
// Tracking metrics
if min_size.is_some() {
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else {
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
}
entries.iter_mut().for_each(|(_, entry)| {
@ -219,8 +221,8 @@ pub(crate) async fn batching_task(
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size", 0.0);
metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}
}
@ -234,7 +236,7 @@ async fn prefill(
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
@ -248,11 +250,15 @@ async fn prefill(
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
@ -261,7 +267,7 @@ async fn prefill(
generation_health.store(false, Ordering::SeqCst);
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None
}
}
@ -276,7 +282,7 @@ async fn decode(
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
@ -291,13 +297,18 @@ async fn decode(
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
@ -307,7 +318,7 @@ async fn decode(
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None
}
}
@ -365,7 +376,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err
}).unwrap_or(true);
if stopped {
@ -381,7 +392,7 @@ fn send_responses(
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true);
}
@ -407,7 +418,7 @@ fn send_responses(
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
@ -472,7 +483,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.

View File

@ -126,7 +126,7 @@ async fn queue_task(
match cmd {
QueueCommand::Append(entry, span) => {
span.in_scope(|| state.append(*entry));
metrics::increment_gauge!("tgi_queue_size", 1.0);
metrics::gauge!("tgi_queue_size").increment(1.0);
}
QueueCommand::NextBatch {
min_size,
@ -141,7 +141,7 @@ async fn queue_task(
.instrument(span)
.await;
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
}
}
}
@ -248,7 +248,7 @@ impl State {
// Filter entries where the response receiver was dropped (== entries where the request
// was dropped by the client)
if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
tracing::debug!("Dropping entry");
continue;
}
@ -399,7 +399,7 @@ impl State {
// Increment batch id
self.next_batch_id += 1;
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
Some((batch_entries, batch, next_batch_span))
}

View File

@ -154,8 +154,8 @@ pub(crate) async fn batching_task(
let batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
let min_size = if waiting_tokens >= max_waiting_tokens {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
@ -176,9 +176,11 @@ pub(crate) async fn batching_task(
{
// Tracking metrics
if min_size.is_some() {
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
.increment(1);
} else {
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
.increment(1);
}
entries.iter_mut().for_each(|(_, entry)| {
@ -225,8 +227,8 @@ pub(crate) async fn batching_task(
.await;
waiting_tokens += 1;
}
metrics::gauge!("tgi_batch_current_size", 0.0);
metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
metrics::gauge!("tgi_batch_current_size").set(0.0);
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
}
}
}
@ -240,7 +242,7 @@ async fn prefill(
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_id = batch.id;
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
match client.prefill(batch).await {
Ok((generations, next_batch, timings)) => {
@ -254,11 +256,15 @@ async fn prefill(
// Filter next batch and remove requests that were stopped
let next_batch = filter_batch(client, next_batch, entries).await;
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
@ -267,7 +273,7 @@ async fn prefill(
generation_health.store(false, Ordering::SeqCst);
let _ = client.clear_cache(Some(batch_id)).await;
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
None
}
}
@ -282,7 +288,7 @@ async fn decode(
) -> Option<CachedBatch> {
let start_time = Instant::now();
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
match client.decode(batches).await {
Ok((generations, next_batch, timings)) => {
@ -297,13 +303,18 @@ async fn decode(
let next_batch = filter_batch(client, next_batch, entries).await;
if let Some(concat_duration) = timings.concat {
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
.record(concat_duration.as_secs_f64());
}
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
.record(timings.forward.as_secs_f64());
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
.record(timings.decode.as_secs_f64());
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
.record(start_filtering_time.elapsed().as_secs_f64());
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
.record(start_time.elapsed().as_secs_f64());
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
next_batch
}
// If we have an error, we discard the whole batch
@ -313,7 +324,7 @@ async fn decode(
let _ = client.clear_cache(Some(id)).await;
}
send_errors(err, entries);
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
None
}
}
@ -371,7 +382,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
// request and we need to stop generating hence why we unwrap_or(true)
let stopped = send_responses(generation, entry).map_err(|err| {
tracing::error!("Entry response channel error.");
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
err
}).unwrap_or(true);
if stopped {
@ -387,7 +398,7 @@ fn send_responses(
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
// Return directly if the channel is disconnected
if entry.response_tx.is_closed() {
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
return Ok(true);
}
@ -413,7 +424,7 @@ fn send_responses(
// Create last Token
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
let n = tokens_.ids.len();
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
let mut iterator = tokens_
.ids
.into_iter()
@ -478,7 +489,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
// Create and enter a span to link this function back to the entry
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
let err = InferError::GenerationError(error.to_string());
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
tracing::error!("{err}");
// unwrap_or is valid here as we don't care if the receiver is gone.

View File

@ -384,7 +384,7 @@ pub struct CompletionRequest {
/// UNUSED
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String,
pub model: Option<String>,
/// The prompt to generate completions for.
#[schema(example = "What is Deep Learning?")]
@ -731,7 +731,7 @@ impl ChatCompletionChunk {
pub(crate) struct ChatRequest {
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
pub model: String,
pub model: Option<String>,
/// A list of messages comprising the conversation so far.
#[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")]
@ -848,7 +848,7 @@ pub enum ToolType {
Function { function: FunctionName },
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
pub struct FunctionName {
pub name: String,
}

View File

@ -210,7 +210,11 @@ async fn main() -> Result<(), RouterError> {
}
let api = if use_api {
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
let cache = Cache::default();
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
.map_err(|_| ())
.map(|cache_dir| Cache::new(cache_dir.into()))
.unwrap_or_else(|_| Cache::default());
tracing::warn!("Offline mode active using cache defaults");
Type::Cache(cache)
} else {

View File

@ -11,10 +11,11 @@ use crate::kserve::{
};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
Usage, Validation,
BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters,
GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig,
HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, OutputMessage, PrefillToken,
SimpleToken, StreamDetails, StreamResponse, TextMessage, Token, TokenizeResponse,
ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
};
use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
@ -185,7 +186,7 @@ pub(crate) async fn generate_internal(
span: tracing::Span,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count");
metrics::counter!("tgi_request_count").increment(1);
// Do not long ultra long inputs, like image payloads.
tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
@ -301,25 +302,15 @@ pub(crate) async fn generate_internal(
);
// Metrics
metrics::increment_counter!("tgi_request_success");
metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
metrics::histogram!(
"tgi_request_validation_duration",
validation_time.as_secs_f64()
);
metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64());
metrics::histogram!(
"tgi_request_inference_duration",
inference_time.as_secs_f64()
);
metrics::histogram!(
"tgi_request_mean_time_per_token_duration",
time_per_token.as_secs_f64()
);
metrics::histogram!(
"tgi_request_generated_tokens",
response.generated_text.generated_tokens as f64
);
metrics::counter!("tgi_request_success").increment(1);
metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
metrics::histogram!("tgi_request_mean_time_per_token_duration")
.record(time_per_token.as_secs_f64());
metrics::histogram!("tgi_request_generated_tokens")
.record(response.generated_text.generated_tokens as f64);
// Send response
let mut output_text = response.generated_text.text;
@ -399,7 +390,7 @@ async fn generate_stream_internal(
span: tracing::Span,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count");
metrics::counter!("tgi_request_count").increment(1);
tracing::debug!("Input: {}", req.inputs);
@ -427,12 +418,12 @@ async fn generate_stream_internal(
let best_of = req.parameters.best_of.unwrap_or(1);
if best_of != 1 {
let err = InferError::from(ValidationError::BestOfStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
yield Ok(Event::from(err));
} else if req.parameters.decoder_input_details {
let err = InferError::from(ValidationError::PrefillDetailsStream);
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
yield Ok(Event::from(err));
} else {
@ -500,13 +491,13 @@ async fn generate_stream_internal(
span.record("seed", format!("{:?}", generated_text.seed));
// Metrics
metrics::increment_counter!("tgi_request_success");
metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
metrics::histogram!("tgi_request_validation_duration", validation_time.as_secs_f64());
metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64());
metrics::histogram!("tgi_request_inference_duration", inference_time.as_secs_f64());
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token.as_secs_f64());
metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
metrics::counter!("tgi_request_success").increment(1);
metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
metrics::histogram!("tgi_request_mean_time_per_token_duration").record(time_per_token.as_secs_f64());
metrics::histogram!("tgi_request_generated_tokens").record(generated_text.generated_tokens as f64);
// StreamResponse
end_reached = true;
@ -553,7 +544,7 @@ async fn generate_stream_internal(
// Skip if we already sent an error
if !end_reached && !error {
let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
tracing::error!("{err}");
yield Ok(Event::from(err));
}
@ -572,8 +563,8 @@ request_body = CompletionRequest,
responses(
(status = 200, description = "Generated Chat Completion",
content(
("application/json" = Completion),
("text/event-stream" = CompletionCompleteChunk),
("application/json" = CompletionFinal),
("text/event-stream" = Chunk),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
@ -604,9 +595,10 @@ async fn completions(
Json(req): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count");
metrics::counter!("tgi_request_count").increment(1);
let CompletionRequest {
model,
max_tokens,
seed,
stop,
@ -625,7 +617,7 @@ async fn completions(
// if suffix is present throw an error
if req.suffix.is_some() {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
@ -637,7 +629,7 @@ async fn completions(
}
if req.prompt.0.len() > info.max_client_batch_size {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
@ -675,7 +667,7 @@ async fn completions(
seed,
top_n_tokens: None,
grammar: None,
..Default::default()
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
},
})
.collect();
@ -820,6 +812,10 @@ async fn completions(
}
};
let stream = stream.chain(futures::stream::once(async {
Ok(Event::default().data("[DONE]"))
}));
let sse = Sse::new(stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
} else {
@ -1009,8 +1005,9 @@ async fn chat_completions(
Json(req): Json<ChatRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count");
metrics::counter!("tgi_request_count").increment(1);
let ChatRequest {
model,
logprobs,
max_tokens,
messages,
@ -1039,7 +1036,7 @@ async fn chat_completions(
// response_format and tools are mutually exclusive
if response_format.is_some() && tools.as_ref().is_some() {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
Json(ErrorResponse {
@ -1053,7 +1050,7 @@ async fn chat_completions(
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
Ok(grammar) => grammar,
Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
@ -1082,7 +1079,7 @@ async fn chat_completions(
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
Ok(inputs) => inputs,
Err(err) => {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
return Err((
StatusCode::UNPROCESSABLE_ENTITY,
@ -1116,7 +1113,7 @@ async fn chat_completions(
seed,
top_n_tokens: req.top_logprobs,
grammar,
..Default::default()
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
},
};
@ -1178,6 +1175,11 @@ async fn chat_completions(
span,
)
.await;
let response_stream = response_stream.chain(futures::stream::once(async {
Ok(Event::default().data("[DONE]"))
}));
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
Ok((headers, sse).into_response())
} else {
@ -1280,7 +1282,7 @@ async fn vertex_compatibility(
Json(req): Json<VertexRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
metrics::increment_counter!("tgi_request_count");
metrics::counter!("tgi_request_count").increment(1);
// check that theres at least one instance
if req.instances.is_empty() {
@ -1454,6 +1456,14 @@ pub async fn run(
GrammarType,
ChatRequest,
Message,
MessageContent,
MessageChunk,
Url,
FunctionName,
OutputMessage,
TextMessage,
ToolCallMessage,
ToolCallDelta,
ChatCompletionComplete,
ChatCompletionChoice,
ChatCompletionDelta,

View File

@ -157,7 +157,7 @@ impl Validation {
));
}
metrics::histogram!("tgi_request_input_length", input_length as f64);
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
Ok((inputs, input_length, max_new_tokens))
}
// Return inputs without validation
@ -384,7 +384,7 @@ impl Validation {
ignore_eos_token: false,
};
metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
metrics::histogram!("tgi_request_max_new_tokens").record(max_new_tokens as f64);
Ok(ValidGenerateRequest {
inputs,

View File

@ -21,13 +21,14 @@ gen-server:
install-server: gen-server
pip install pip --upgrade
pip install -r requirements_cuda.txt
pip install -e ".[bnb, accelerate, quantize, peft, outlines]"
pip install -e ".[accelerate, quantize, peft, outlines]"
install: install-cuda
echo "Installed server"
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
pip install -e ".[bnb]"
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
@ -35,5 +36,5 @@ run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
export-requirements:
poetry export -o requirements_cuda.txt --without-hashes
poetry export -o requirements_cuda.txt --without-hashes --with cuda
poetry export -o requirements_rocm.txt --without-hashes

View File

@ -59,3 +59,18 @@ def marlin_gemm(
Matrix multiplication using Marlin kernels.
"""
...
# fp8 marlin
def fp8_marlin_gemm(
a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
size_m: int,
size_n: int,
size_k: int,
) -> torch.Tensor:
return torch.ops._C.fp8_marlin_gemm(
a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
)

View File

@ -9,4 +9,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gptq_marlin_repack", &gptq_marlin_repack,
"Repack GPTQ parameters for Marlin");
m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
m.def("fp8_marlin_gemm", &fp8_marlin_gemm);
}

View File

@ -27,4 +27,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k);
#endif

File diff suppressed because it is too large Load Diff

View File

@ -9,6 +9,7 @@ setup(
CUDAExtension(
name="marlin_kernels",
sources=[
"marlin_kernels/fp8_marlin.cu",
"marlin_kernels/gptq_marlin.cu",
"marlin_kernels/gptq_marlin_repack.cu",
"marlin_kernels/marlin_cuda_kernel.cu",

View File

@ -2,6 +2,7 @@ import torch
from text_generation_server.layers import (
TensorParallelEmbedding,
)
from text_generation_server.utils.weights import DefaultWeightsLoader
class ProcessGroup:
@ -42,7 +43,12 @@ class Weights:
def test_weight_hub_files_offline_error():
vocab_size = 17
weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256)
weights = Weights(
rank=0,
world_size=1,
vocab_size=vocab_size,
hidden_dim=256,
)
embeddings = TensorParallelEmbedding("", weights)
input_ids = torch.arange(vocab_size)

View File

@ -1,13 +1,47 @@
import pytest
import torch
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.marlin import MarlinWeight
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
Weights,
WeightsLoader,
)
from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader
from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader
from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader
from types import SimpleNamespace
from typing import List, Optional, Dict, Union
from pathlib import Path
@pytest.fixture
def gptq_weights_loader():
return GPTQWeightsLoader(
bits=4,
groupsize=-1,
desc_act=False,
quant_method="gptq",
quantize="gptq",
sym=True,
)
@pytest.fixture
def gptq_weights_loader_awq():
return GPTQWeightsLoader(
bits=4,
groupsize=-1,
desc_act=False,
quant_method="awq",
quantize="awq",
sym=True,
)
@pytest.fixture
def marlin_weights_loader():
return MarlinWeightsLoader(bits=4, is_marlin_24=False)
dummy_file_system = {
"test_weights": {
"layer.0.weight": torch.tensor(
@ -58,7 +92,7 @@ dummy_file_system = {
dtype=torch.float32,
),
},
"test_get_multi_weights_row": {
"test_get_weights_row": {
"weight.weight": torch.tensor(
[
[1, 2],
@ -101,7 +135,7 @@ dummy_file_system = {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
},
"test_get_multi_weights_row_gptq": {
"test_get_weights_row_gptq": {
"weight.qweight": torch.tensor(
[
[1, 2],
@ -200,7 +234,7 @@ dummy_file_system = {
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
"test_get_multi_weights_row_exl2": {
"test_get_weights_row_exl2": {
"weight.q_weight": torch.tensor(
[
[1, 2],
@ -245,7 +279,7 @@ dummy_file_system = {
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
},
"test_get_multi_weights_row_marlin": {
"test_get_weights_row_marlin": {
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
},
@ -308,6 +342,7 @@ class MockWeights(Weights):
dummy_fs,
aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None,
weights_loader: Optional[WeightsLoader] = None,
):
routing = {}
self.dummy_fs = dummy_fs
@ -327,6 +362,9 @@ class MockWeights(Weights):
self.dtype = dtype
self.process_group = process_group
self.prefix = prefix
self.weights_loader = (
DefaultWeightsLoader() if weights_loader is None else weights_loader
)
self._handles = {}
def _get_handle(self, filename: Union[Path, str]):
@ -412,12 +450,10 @@ def test_get_weights_col_packed():
)
prefix = "weight"
quantize = None
block_sizes = 1
w = weights.get_weights_col_packed(
prefix=prefix,
quantize=quantize,
block_sizes=block_sizes,
)
@ -448,12 +484,10 @@ def test_get_weights_col_packed_block_size():
)
prefix = "weight"
quantize = None
block_sizes = 2
w = weights.get_weights_col_packed(
prefix=prefix,
quantize=quantize,
block_sizes=block_sizes,
)
@ -484,12 +518,10 @@ def test_get_weights_col_packed_block_size_arr():
)
prefix = "weight"
quantize = None
block_sizes = [1, 1]
w = weights.get_weights_col_packed(
prefix=prefix,
quantize=quantize,
block_sizes=block_sizes,
)
@ -519,11 +551,9 @@ def test_get_multi_weights_col():
)
prefixes = ["weight", "weight"]
quantize = None
w = weights.get_multi_weights_col(
prefixes=prefixes,
quantize=quantize,
dim=0,
)
@ -545,10 +575,10 @@ def test_get_multi_weights_col():
)
def test_get_multi_weights_row():
def test_get_weights_row():
weights = MockWeights(
[
"test_get_multi_weights_row",
"test_get_weights_row",
],
device="cpu",
dtype=torch.float32,
@ -557,11 +587,9 @@ def test_get_multi_weights_row():
)
prefix = "weight"
quantize = None
w = weights.get_multi_weights_row(
w = weights.get_weights_row(
prefix=prefix,
quantize=quantize,
)
assert torch.allclose(
@ -576,7 +604,7 @@ def test_get_multi_weights_row():
# test_get_weights_col
def test_get_weights_col_awq():
def test_get_weights_col_awq(gptq_weights_loader_awq):
weights = MockWeights(
[
"test_get_weights_col_gptq",
@ -585,14 +613,13 @@ def test_get_weights_col_awq():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
)
prefix = "weight"
quantize = "awq"
w = weights.get_weights_col(
prefix=prefix,
quantize=quantize,
)
expected_weight = GPTQWeight(
@ -617,7 +644,7 @@ def test_get_weights_col_awq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_weights_col_gtpq():
def test_get_weights_col_gtpq(gptq_weights_loader):
weights = MockWeights(
[
"test_get_weights_col_gptq",
@ -626,14 +653,13 @@ def test_get_weights_col_gtpq():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
)
prefix = "weight"
quantize = "gptq"
w = weights.get_weights_col(
prefix=prefix,
quantize=quantize,
)
expected_weight = GPTQWeight(
@ -664,14 +690,13 @@ def test_get_weights_col_exl2():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
)
prefix = "weight"
quantize = "exl2"
w = weights.get_weights_col(
prefix=prefix,
quantize=quantize,
)
scaled_scale_max = 0.3906 * 256
@ -692,7 +717,7 @@ def test_get_weights_col_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_weights_col_marlin():
def test_get_weights_col_marlin(marlin_weights_loader):
weights = MockWeights(
[
"test_get_weights_col_marlin",
@ -701,14 +726,13 @@ def test_get_weights_col_marlin():
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
)
prefix = "weight"
quantize = "marlin"
w = weights.get_weights_col(
prefix=prefix,
quantize=quantize,
)
expected_weight = MarlinWeight(
@ -723,7 +747,7 @@ def test_get_weights_col_marlin():
# test_get_weights_col_packed
def test_get_weights_col_packed_awq():
def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
weights = MockWeights(
[
"test_get_weights_col_packed_gptq",
@ -732,15 +756,14 @@ def test_get_weights_col_packed_awq():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
)
prefix = "weight"
quantize = "awq"
block_sizes = 1
w = weights.get_weights_col_packed(
prefix=prefix,
quantize=quantize,
block_sizes=block_sizes,
)
@ -773,15 +796,14 @@ def test_get_weights_col_packed_exl2():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
)
prefix = "weight"
quantize = "exl2"
block_sizes = 1
w = weights.get_weights_col_packed(
prefix=prefix,
quantize=quantize,
block_sizes=block_sizes,
)
@ -803,7 +825,7 @@ def test_get_weights_col_packed_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_weights_col_packed_gptq():
def test_get_weights_col_packed_gptq(gptq_weights_loader):
weights = MockWeights(
[
"test_get_weights_col_packed_gptq",
@ -812,14 +834,13 @@ def test_get_weights_col_packed_gptq():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
)
prefixes = ["weight"]
quantize = "gptq"
w = weights.get_multi_weights_col(
prefixes=prefixes,
quantize=quantize,
dim=0,
)
@ -842,7 +863,7 @@ def test_get_weights_col_packed_gptq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_weights_col_packed_marlin():
def test_get_weights_col_packed_marlin(marlin_weights_loader):
weights = MockWeights(
[
"test_get_weights_col_packed_marlin",
@ -851,14 +872,13 @@ def test_get_weights_col_packed_marlin():
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
)
prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_col(
prefixes=[prefix],
quantize=quantize,
dim=0,
)
@ -876,7 +896,7 @@ def test_get_weights_col_packed_marlin():
# test_get_multi_weights_col
def test_get_multi_weights_col_awq():
def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
weights = MockWeights(
[
"test_get_multi_weights_col_gptq",
@ -885,14 +905,13 @@ def test_get_multi_weights_col_awq():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
)
prefixes = ["weight"]
quantize = "awq"
w = weights.get_multi_weights_col(
prefixes=prefixes,
quantize=quantize,
dim=0,
)
@ -924,22 +943,21 @@ def test_get_multi_weights_col_exl2():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
)
prefix = "weight"
quantize = "exl2"
try:
w = weights.get_multi_weights_col(
prefixes=[prefix],
quantize=quantize,
dim=0,
)
except ValueError as e:
assert e.args[0] == "get_multi_weights_col is not supported for exl2"
def test_get_multi_weights_col_gptq():
def test_get_multi_weights_col_gptq(gptq_weights_loader):
weights = MockWeights(
[
"test_get_multi_weights_col_gptq",
@ -948,14 +966,13 @@ def test_get_multi_weights_col_gptq():
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
)
prefixes = ["weight"]
quantize = "gptq"
w = weights.get_multi_weights_col(
prefixes=prefixes,
quantize=quantize,
dim=0,
)
@ -978,7 +995,7 @@ def test_get_multi_weights_col_gptq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_col_marlin():
def test_get_multi_weights_col_marlin(marlin_weights_loader):
weights = MockWeights(
[
"test_get_multi_weights_col_marlin",
@ -987,14 +1004,13 @@ def test_get_multi_weights_col_marlin():
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
)
prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_col(
prefixes=[prefix],
quantize=quantize,
dim=0,
)
@ -1007,26 +1023,25 @@ def test_get_multi_weights_col_marlin():
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
# test_get_multi_weights_row
# test_get_weights_row
def test_get_multi_weights_row_awq():
def test_get_weights_row_awq(gptq_weights_loader_awq):
weights = MockWeights(
[
"test_get_multi_weights_row_gptq",
"test_get_weights_row_gptq",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader_awq,
)
prefix = "weight"
quantize = "awq"
w = weights.get_multi_weights_row(
w = weights.get_weights_row(
prefix=prefix,
quantize=quantize,
)
expected_weight = GPTQWeight(
@ -1048,23 +1063,22 @@ def test_get_multi_weights_row_awq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_row_exl2():
def test_get_weights_row_exl2():
weights = MockWeights(
[
"test_get_multi_weights_row_exl2",
"test_get_weights_row_exl2",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=Exl2WeightsLoader(),
)
prefix = "weight"
quantize = "exl2"
w = weights.get_multi_weights_row(
w = weights.get_weights_row(
prefix=prefix,
quantize=quantize,
)
print(w)
@ -1086,23 +1100,22 @@ def test_get_multi_weights_row_exl2():
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
def test_get_multi_weights_row_gptq():
def test_get_weights_row_gptq(gptq_weights_loader):
weights = MockWeights(
[
"test_get_multi_weights_row_gptq",
"test_get_weights_row_gptq",
],
device="cpu",
dtype=torch.float32,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=gptq_weights_loader,
)
prefix = "weight"
quantize = "gptq"
w = weights.get_multi_weights_row(
w = weights.get_weights_row(
prefix=prefix,
quantize=quantize,
)
expected_weight = GPTQWeight(
@ -1124,23 +1137,22 @@ def test_get_multi_weights_row_gptq():
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
def test_get_multi_weights_row_marlin():
def test_get_weights_row_marlin(marlin_weights_loader):
weights = MockWeights(
[
"test_get_multi_weights_row_marlin",
"test_get_weights_row_marlin",
],
device="cpu",
dtype=torch.float16,
process_group=dummy_process_group,
dummy_fs=dummy_file_system,
weights_loader=marlin_weights_loader,
)
prefix = "weight"
quantize = "marlin"
w = weights.get_multi_weights_row(
w = weights.get_weights_row(
prefix=prefix,
quantize=quantize,
)
expected_weight = MarlinWeight(

View File

@ -91,6 +91,15 @@ def serve(
f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
)
# TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled
# and warn the user
if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
logger.warning(
f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs."
)
global CUDA_GRAPHS
CUDA_GRAPHS = None
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
dtype = None if dtype is None else dtype.value
@ -332,6 +341,7 @@ def quantize(
upload_to_model_id: Optional[str] = None,
percdamp: float = 0.01,
act_order: bool = False,
groupsize: int = 128,
):
if revision is None:
revision = "main"
@ -346,13 +356,14 @@ def quantize(
quantize(
model_id=model_id,
bits=4,
groupsize=128,
groupsize=groupsize,
output_dir=output_dir,
revision=revision,
trust_remote_code=trust_remote_code,
upload_to_model_id=upload_to_model_id,
percdamp=percdamp,
act_order=act_order,
sym=True,
)

View File

@ -1,6 +1,9 @@
import torch
from typing import List, Union
from dataclasses import dataclass
from text_generation_server.utils.weights import WeightsLoader, Weights
@dataclass
class Exl2Weight:
@ -21,3 +24,60 @@ class Exl2Weight:
@property
def device(self) -> torch.device:
return self.q_weight.device
class Exl2WeightsLoader(WeightsLoader):
"""Loader for exl2-quantized weights."""
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
raise RuntimeError("Column-packed weights are not supported for exl")
def get_weights_col(self, weights: Weights, prefix: str):
try:
q_weight = weights.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = weights.get_tensor(f"{prefix}.q_scale")
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
q_groups = weights.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
raise ValueError("get_multi_weights_col is not supported for exl2")
def get_weights_row(self, weights: Weights, prefix: str):
try:
q_weight = weights.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = weights.get_tensor(f"{prefix}.q_scale")
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
q_groups = weights.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)

View File

@ -1,4 +1,23 @@
from enum import Enum, auto
import torch
from text_generation_server.utils.import_utils import SYSTEM
def get_fp8_linear() -> torch.nn.Module:
"""
Return an FP8 linear `Module` that is compatible with the current system.
"""
if SYSTEM == "cuda":
major, minor = torch.cuda.get_device_capability()
if major == 8 and minor < 9:
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
return GPTQMarlinFP8Linear
# On other systems let Torch decide if the hardware supports FP8.
return Fp8Linear
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):

View File

@ -1,20 +1,14 @@
from dataclasses import dataclass
from loguru import logger
import os
from typing import Optional
from typing import List, Optional, Union
from safetensors import SafetensorError
from text_generation_server.utils.weights import Weights, WeightsLoader
import torch
from text_generation_server.utils.import_utils import (
SYSTEM,
)
@dataclass
class GPTQParams:
bits: int
checkpoint_format: Optional[str]
groupsize: int
desc_act: bool
quant_method: str
sym: bool
from text_generation_server.utils.log import log_once
@dataclass
@ -69,3 +63,345 @@ elif CAN_EXLLAMA:
pass
from text_generation_server.layers.gptq.quant_linear import QuantLinear
class GPTQWeightsLoader(WeightsLoader):
"""
Loader for GPTQ- and AWQ-quantized weights.
"""
def __init__(
self,
*,
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
quantize: str,
sym: bool,
):
self.bits = bits
self.desc_act = desc_act
self.groupsize = groupsize
self.quant_method = quant_method
self.quantize = quantize
self.sym = sym
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
)
scales = weights.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=weights.dtype)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
g_idx = weights.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=False,
)
qzeros = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
if self.quantize == "gptq" and self.quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.g_idx")
elif self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_exllama=False,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try:
qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat(
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=False,
)
qzeros = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = (
self.bits == 4
and HAS_EXLLAMA
and self.quantize == "gptq"
and not self.desc_act
)
if self.quantize == "gptq" and self.quant_method == "gptq":
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_exllama=use_exllama,
)
def get_weights_row(self, weights: Weights, prefix: str):
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
self._get_gptq_params(weights)
if can_use_gptq_marlin(
bits=self.bits,
groupsize=self.groupsize,
quant_method=self.quant_method,
quantize=self.quantize,
sym=self.sym,
):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
if self.desc_act or self.groupsize == -1:
scales = weights.get_tensor(f"{prefix}.scales")
else:
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = weights.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
sym=self.sym,
sharded_infeatures=sharded_in_features,
)
use_exllama = True
if self.bits != 4:
use_exllama = False
if self.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if self.quantize == "gptq" and self.quant_method == "gptq":
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
else:
g_idx = None
if weights.process_group.size() > 1:
if g_idx is not None:
if (
not torch.equal(
g_idx.cpu(),
torch.tensor(
[i // self.groupsize for i in range(g_idx.shape[0])],
dtype=torch.int32,
),
)
and not (g_idx == 0).all()
):
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require to reorder input activations that are split unto several GPUs
use_exllama = False
from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA,
GPTQWeight,
)
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
if use_exllama and self.groupsize != -1:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
else:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // self.bits),
device=qweight.device,
)
// self.groupsize
).to(dtype=torch.int32)
return GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=self.bits,
groupsize=self.groupsize,
use_exllama=use_exllama,
)
def _get_gptq_params(self, weights: Weights):
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
# `server quantize` used asymmetric quantization unconditionally
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights._has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq"

View File

@ -16,6 +16,8 @@ from text_generation_server.layers.gptq.quant_linear import QuantLinear
from loguru import logger
from typing import Optional
from text_generation_server.utils.weights import DefaultWeightsLoader
DEV = torch.device("cuda:0")
@ -869,6 +871,7 @@ def quantize(
upload_to_model_id: Optional[str],
percdamp: float,
act_order: bool,
sym: bool,
):
print("loading model")
config = AutoConfig.from_pretrained(
@ -891,6 +894,7 @@ def quantize(
dtype=torch.float16,
process_group=process_group,
aliases={"embed_tokens.weight": ["lm_head.weight"]},
weights_loader=DefaultWeightsLoader(),
)
hooks = []
for name, module in model.named_modules():
@ -943,6 +947,7 @@ def quantize(
percdamp=percdamp,
act_order=act_order,
hooks=hooks,
sym=sym,
)
print(time.time() - tick)
@ -954,6 +959,7 @@ def quantize(
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
state_dict["gptq_bits"] = torch.LongTensor([bits])
state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
state_dict["gptq_sym"] = torch.BoolTensor([sym])
max_shard_size = "10GB"
shards, index = shard_checkpoint(

View File

@ -106,9 +106,9 @@ def get_linear(weight, bias, quantize):
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
)
elif quantize == "fp8":
from text_generation_server.layers.fp8 import Fp8Linear
from text_generation_server.layers.fp8 import get_fp8_linear
linear = Fp8Linear(weight, bias)
linear = get_fp8_linear()(weight, bias)
elif quantize == "bitsandbytes":
try:
from text_generation_server.layers.bnb import (

View File

@ -1,11 +1,13 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from text_generation_server.layers.gptq import GPTQParams
from loguru import logger
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights, WeightsLoader
try:
import marlin_kernels
@ -24,16 +26,132 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16
def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool:
class MarlinWeightsLoader(WeightsLoader):
"""Loader for Marlin-quantized weights."""
def __init__(self, *, bits: int, is_marlin_24: bool):
self.bits = bits
self.is_marlin_24 = is_marlin_24
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
if self.is_marlin_24:
B = weights.get_packed_sharded(
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
)
B_meta = weights.get_packed_sharded(
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
B = weights.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized"
)
B_meta = torch.cat(
[weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `marlin` weight, make sure the model is already quantized"
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_weights_row(self, weights: Weights, prefix: str):
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = weights.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
return weight
def can_use_gptq_marlin(
*, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
) -> bool:
return (
SYSTEM == "cuda"
and marlin_kernels is not None
and has_sm_8_0
and quantize == "gptq"
and gptq_params.quant_method == "gptq"
and gptq_params.bits in GPTQ_MARLIN_BITS
and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES
and gptq_params.sym
and quant_method == "gptq"
and bits in GPTQ_MARLIN_BITS
and groupsize in GPTQ_MARLIN_GROUP_SIZES
and sym
)
@ -339,6 +457,115 @@ class GPTQMarlin24Linear(nn.Module):
return C
class GPTQMarlinFP8Linear(nn.Module):
"""
FP8 GPTQ-Marlin linear layer.
"""
def __init__(
self,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
) -> None:
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
qweight, scale = fp8_quantize(weight)
scale = scale.to(torch.float16)
qweight, scales = repack_fp8_for_marlin(qweight, scale)
in_features = qweight.shape[0] * MARLIN_TILE_SIZE
out_features = scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features)
self.qweight = qweight
self.scales = scales
self.bias = bias if bias is not None else None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=qweight.device
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
A_flat = A.view(-1, A.shape[-1])
C = marlin_kernels.fp8_marlin_gemm(
A_flat,
self.qweight,
self.scales,
self.workspace,
8,
A_flat.shape[0],
self.scales.shape[1],
A_flat.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
if self.bias is not None:
C += self.bias
return C
def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements).
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
if fp8_tensor.shape[0] % 4 != 0:
raise ValueError(
f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}"
)
# Reshape to prepare for packing
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
# Convert fp8 to uint8 (byte) representation
byte_tensor = reshaped.view(torch.uint8)
# Pack 4 uint8 values into one int32
packed = torch.zeros(
fp8_tensor.shape[0] // 4,
fp8_tensor.shape[1],
dtype=torch.int32,
device=fp8_tensor.device,
)
for i in range(4):
packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)
return packed
def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
"""
Repack FP8 tensor for GPTQ-Marlin.
"""
out_features, in_features = weight.shape
# Torch linear layers weights with shape [out_features, in_features],
# GPTQ-quantized weights use [in_feateres/pack_factor, in_features],
# so transpose before packing.
qweight = pack_fp8_as_int32(weight.t())
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
repacked = marlin_kernels.gptq_marlin_repack(
qweight, perm, in_features, out_features, 8
)
scales = scale.reshape(1, 1).repeat(1, out_features)
scales = permute_scales(scales)
return repacked, scales
@dataclass
class MarlinWeight:
"""

View File

@ -102,7 +102,7 @@ class PositionRotaryEmbedding(nn.Module):
max_position_embeddings=rope_scaling[
"original_max_position_embeddings"
],
base=10000.0,
base=base,
device=inv_freq.device,
scaling_factor=scaling_factor,
extrapolation_factor=1,

View File

@ -52,7 +52,7 @@ class TensorParallelHead(SuperLayer):
weight = weights.get_tensor(f"{prefix}.weight")
except:
# ...otherwise they are quantized.
weight = weights.get_weights_col(prefix, config.quantize)
weight = weights.get_weights_col(prefix)
should_gather = weights.process_group.size() > 1
elif weights.process_group.size() > 1:
try:
@ -129,9 +129,7 @@ class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load_gate_up(cls, config, prefix: str, weights, bias: bool):
"""Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_gate_up(
prefix, quantize=config.quantize
)
weight = weights.get_weights_col_packed_gate_up(prefix)
if bias:
raise NotImplementedError("packed_gate_up only implemented without bias")
else:
@ -152,7 +150,6 @@ class TensorParallelColumnLinear(SuperLayer):
"""Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_qkv(
prefix,
quantize=config.quantize,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
)
@ -165,7 +162,7 @@ class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_weights_col(prefix, config.quantize)
weight = weights.get_weights_col(prefix)
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
@ -178,14 +175,12 @@ class TensorParallelColumnLinear(SuperLayer):
if config.quantize == "exl2":
linears = []
for prefix in prefixes:
weight = weights.get_weights_col(prefix, config.quantize)
weight = weights.get_weights_col(prefix)
b = weights.get_tensor(f"{prefix}.bias") if bias else None
linears.append(get_linear(weight, b, config.quantize))
linear = LayerConcat(linears)
else:
weight = weights.get_multi_weights_col(
prefixes, quantize=config.quantize, dim=dim
)
weight = weights.get_multi_weights_col(prefixes, dim=dim)
if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
bias = torch.cat(b, dim=dim)
@ -202,7 +197,7 @@ class TensorParallelRowLinear(SuperLayer):
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process

View File

@ -804,6 +804,10 @@ def get_model(
quantize=quantize,
speculator=speculator,
dtype=dtype,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,

View File

@ -20,6 +20,7 @@ from text_generation_server.utils import (
from text_generation_server.models import Model
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import (
Batch,
@ -546,12 +547,17 @@ class CausalLM(Model):
tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = model_class(prefix, config, weights)

View File

@ -163,7 +163,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)

View File

@ -105,6 +105,12 @@ class DbrxFFNConfig(PretrainedConfig):
class DbrxConfig(PretrainedConfig):
attribute_map = {
"hidden_size": "d_model",
"num_attention_heads": "n_heads",
"num_hidden_layers": "n_layers",
}
def __init__(
self,
d_model: int = 2048,
@ -157,6 +163,12 @@ class DbrxConfig(PretrainedConfig):
**kwargs,
)
@property
def num_key_value_heads(self):
# We can't use the attribute map, since this the number of KV
# heads is not top-level.
return self.attn_config.kv_n_heads
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
return x.view(1) if len(x.size()) == 0 else x

View File

@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)

View File

@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)

View File

@ -61,7 +61,6 @@ def _load_qkv_gptq(config, prefix: str, weights):
# Weights
weight = weights.get_weights_col_packed_qkv(
f"{prefix}.c_attn",
config.quantize,
config.num_attention_heads,
config.num_attention_heads,
)
@ -137,7 +136,7 @@ def load_row(config, prefix: str, weights, bias: bool):
"""load_row, but with transposed weight matrices."""
if config.quantize == "gptq":
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
weight = weights.get_weights_row(prefix)
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T.contiguous()
@ -155,9 +154,7 @@ def load_row(config, prefix: str, weights, bias: bool):
def load_col(config, prefix: str, weights, bias: bool):
"""load_col, but with transposed weight matrices."""
if config.quantize == "gptq":
weight = weights.get_multi_weights_col(
[prefix], quantize=config.quantize, dim=1
)
weight = weights.get_multi_weights_col([prefix], dim=1)
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T.contiguous()

View File

@ -135,7 +135,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)

View File

@ -48,7 +48,7 @@ from text_generation_server.layers.rotary import (
def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
@ -64,7 +64,7 @@ def load_row(config, prefix: str, weights, bias: bool):
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0)
weight = weights.get_multi_weights_col([prefix], dim=0)
if isinstance(weight, torch.Tensor):
# Only on non quantized versions
weight = (

View File

@ -85,7 +85,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)

View File

@ -23,7 +23,7 @@ from text_generation_server.layers.attention import (
def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
@ -42,6 +42,7 @@ class RWConfig(PretrainedConfig):
attribute_map = {
"num_hidden_layers": "n_layer",
"num_attention_heads": "n_head",
"num_key_value_heads": "n_head_kv",
}
def __init__(

View File

@ -17,6 +17,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
get_linear,
)
from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
@ -81,11 +82,13 @@ def _load_multi_mqa_gptq(
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
qzeros = qzeros.to(device=weights.device)
gptq_params = weights._get_gptq_params()
if gptq_params.quant_method == "gptq":
loader = weights.weights_loader
assert isinstance(loader, GPTQWeightsLoader)
loader._get_gptq_params(weights)
if loader.quant_method == "gptq":
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
g_idx = g_idx.to(device=weights.device)
elif gptq_params.quant_method == "awq":
elif loader.quant_method == "awq":
g_idx = None
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
@ -100,8 +103,8 @@ def _load_multi_mqa_gptq(
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
bits=loader.bits,
groupsize=loader.groupsize,
use_exllama=HAS_EXLLAMA,
)
@ -197,9 +200,7 @@ def load_col(config, prefix: str, weights, bias: bool):
if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
else:
weight = weights.get_multi_weights_col(
[prefix], quantize=config.quantize, dim=0
)
weight = weights.get_multi_weights_col([prefix], dim=0)
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
@ -212,7 +213,7 @@ def load_row(config, prefix: str, weights, bias: bool):
if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
else:
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
weight = weights.get_weights_row(prefix)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process

View File

@ -126,7 +126,6 @@ def _load_gqa(config, prefix: str, weights):
weight = weights.get_multi_weights_col(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
quantize=config.quantize,
dim=0,
)

View File

@ -34,6 +34,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.utils.weights import DefaultWeightsLoader
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@ -682,7 +683,7 @@ class Idefics2Connector(nn.Module):
class Idefics2ForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
config.vision_config.quantize = None
config.vision_config.speculator = config.speculator
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
@ -695,16 +696,28 @@ class Idefics2ForConditionalGeneration(nn.Module):
name="text_model",
)
self.dtype = weights.dtype
self.vision_model = Idefics2VisionTransformer(
prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
config=vision_config,
weights=weights,
)
self.connector = Idefics2Connector(
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
config=config,
weights=weights,
)
# The vision and connector models are not quantized.
with weights.use_loader(DefaultWeightsLoader()):
self.vision_model = Idefics2VisionTransformer(
prefix=(
f"{prefix}.model.vision_model" if prefix else "model.vision_model"
),
config=vision_config,
weights=weights,
)
quantize = config.quantize
try:
config.quantize = None
self.connector = Idefics2Connector(
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
config=config,
weights=weights,
)
finally:
config.quantize = quantize
self.config = config
self.image_seq_len = config.perceiver_config.resampler_n_latents
self.image_token_id = config.image_token_id

View File

@ -49,6 +49,7 @@ from text_generation_server.models.globals import (
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
from text_generation_server.utils.import_utils import (
@ -880,12 +881,16 @@ class FlashCausalLM(Model):
torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(quantize, model_id, revision)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device, dtype, process_group=self.process_group, aliases=aliases
filenames,
device,
dtype,
process_group=self.process_group,
aliases=aliases,
weights_loader=weights_loader,
)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = model_class(prefix, config, weights)
@ -904,13 +909,12 @@ class FlashCausalLM(Model):
self.num_layers = config.num_hidden_layers
# Validation is done in the model itself
if num_kv_heads is None:
# Order is important here.
for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
num_kv_heads = getattr(config, attr, None)
if num_kv_heads is not None:
break
num_kv_heads = getattr(config, "num_key_value_heads", None)
# GPT-2 workaround
if num_kv_heads is None:
raise ValueError("Cannot get the number of key/value heads")
num_kv_heads = getattr(config, "n_head", None)
if num_kv_heads is None:
raise ValueError("Cannot get the number of key/value heads")
self.num_kv_heads = (
num_kv_heads // self.process_group.size()
if num_kv_heads > 1

View File

@ -23,6 +23,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.quantization import get_loader
class IDEFICSSharded(IdeficsCausalLM):
@ -70,6 +71,9 @@ class IDEFICSSharded(IdeficsCausalLM):
trust_remote_code=trust_remote_code,
)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
@ -77,6 +81,7 @@ class IDEFICSSharded(IdeficsCausalLM):
device=device,
dtype=dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
model = IdeficsForVisionText2Text(config, weights)

View File

@ -28,6 +28,7 @@ from text_generation_server.models.types import (
GeneratedText,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
@ -448,8 +449,17 @@ class Mamba(Model):
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
weights = Weights(
filenames,
device,
dtype,
process_group=self.process_group,
weights_loader=weights_loader,
)
model = MambaModel(config, weights)
torch.distributed.barrier(group=self.process_group)
super(Mamba, self).__init__(

View File

@ -18,6 +18,7 @@ from text_generation_server.utils import (
Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.quantization import get_loader
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
from text_generation_server.models.types import (
@ -586,6 +587,9 @@ class Seq2SeqLM(Model):
)
tokenizer.bos_token_id = config.decoder_start_token_id
weights_loader = get_loader(
quantize=quantize, model_id=model_id, revision=revision
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
@ -594,6 +598,7 @@ class Seq2SeqLM(Model):
dtype=dtype,
process_group=self.process_group,
aliases=aliases,
weights_loader=weights_loader,
)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)

View File

@ -0,0 +1,119 @@
from typing import Optional
import os
import json
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import DefaultWeightsLoader, WeightsLoader
@dataclass
class _QuantizerConfig:
bits: int
checkpoint_format: Optional[str]
desc_act: bool
groupsize: int
quant_method: str
sym: bool
# We should probably do this with Pytantic JSON deserialization,
# but for now we'll stay close to the old _set_gptq_params.
def _get_quantizer_config(model_id, revision):
bits = 4
groupsize = -1
quant_method = "gptq"
checkpoint_format = None
sym = True
desc_act = False
filename = "config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(model_id, filename=filename, revision=revision)
with open(filename, "r") as f:
data = json.load(f)
bits = data["quantization_config"]["bits"]
groupsize = data["quantization_config"]["group_size"]
# Order is important here, desc_act is missing on some real models
quant_method = data["quantization_config"]["quant_method"]
checkpoint_format = data["quantization_config"].get("checkpoint_format")
sym = data["quantization_config"]["sym"]
desc_act = data["quantization_config"]["desc_act"]
except Exception:
filename = "quantize_config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
bits = data["bits"]
groupsize = data["group_size"]
sym = data["sym"]
desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
quant_method = "awq"
except Exception:
filename = "quant_config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
bits = data["w_bit"]
groupsize = data["q_group_size"]
desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
quant_method = "awq"
except Exception:
pass
return _QuantizerConfig(
bits=bits,
groupsize=groupsize,
quant_method=quant_method,
checkpoint_format=checkpoint_format,
sym=sym,
desc_act=desc_act,
)
def get_loader(
quantize: Optional[str], model_id: str, revision: Optional[str]
) -> WeightsLoader:
quantizer_config = _get_quantizer_config(model_id, revision)
if quantize in {"awq", "gptq"}:
from text_generation_server.layers.gptq import GPTQWeightsLoader
return GPTQWeightsLoader(
bits=quantizer_config.bits,
desc_act=quantizer_config.desc_act,
groupsize=quantizer_config.groupsize,
quant_method=quantizer_config.quant_method,
quantize=quantize,
sym=quantizer_config.sym,
)
elif quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2WeightsLoader
return Exl2WeightsLoader()
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinWeightsLoader
return MarlinWeightsLoader(
bits=quantizer_config.bits,
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
)
else:
return DefaultWeightsLoader()

View File

@ -1,13 +1,89 @@
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Optional, Union
from safetensors import safe_open, SafetensorError
from safetensors import safe_open
import torch
from loguru import logger
from huggingface_hub import hf_hub_download
import json
from text_generation_server.layers.gptq import GPTQParams
from text_generation_server.utils.log import log_once
class WeightsLoader(ABC):
"""
Instances of this type implement higher-level weight loading.
At a low-level, every weight is stored in the Safetensors format.
The interpretation of weights may be different however, for instance
could be packed, quantized weights. Loaders are responsible for
interpreting the raw tensors, sharding tensors in a manner compatible
with the format, etc.
"""
@abstractmethod
def get_weights_col_packed(
self,
weights: "Weights",
prefix: str,
block_sizes: Union[int, List[int]],
):
"""
Get the packed weights at the given prefix with column-splitting for
tensor parallelism. This method should be used when multiple different
weights are packed into a tensor, for instance, query/key/value
weights or a gate/up projection.
The `block_sizes` determines the proportions of the packed tensors.
The columns are split in equally sized blocks when `block_sizes` is an
`int`, or in blocks proportional given to the sizes. For instance
`[2, 1, 1]` will divide an input with dimensionality `1024` in
`[512, 256, 256]`.
"""
...
def get_weights_col(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply column-splitting for tensor
paralllism.
"""
return weights.get_multi_weights_col([prefix], 0)
@abstractmethod
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
"""
Get the weights at the given prefixes, column-split them for tensor
parallelim, and then concatenate the weights along the given dimension.
"""
...
@abstractmethod
def get_weights_row(self, weights: "Weights", prefix: str):
"""
Get the weights at the given prefix and apply row-splitting for tensor
parallism.
"""
...
class DefaultWeightsLoader(WeightsLoader):
"""
Loader that uses tensors as-is with the exception of applying sharding
and/or concatenation.
"""
def get_weights_col_packed(
self,
weights: "Weights",
prefix: str,
block_sizes: Union[int, List[int]],
):
return weights.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
return torch.cat(w, dim=dim)
def get_weights_row(self, weights: "Weights", prefix: str):
return weights.get_sharded(f"{prefix}.weight", dim=1)
class Weights:
@ -17,6 +93,7 @@ class Weights:
device,
dtype,
process_group,
weights_loader: WeightsLoader,
aliases: Optional[Dict[str, List[str]]] = None,
prefix: Optional[str] = None,
):
@ -37,6 +114,7 @@ class Weights:
self.dtype = dtype
self.process_group = process_group
self.prefix = prefix
self.weights_loader = weights_loader
self._handles = {}
def _get_handle(self, filename):
@ -69,6 +147,13 @@ class Weights:
slice_ = f.get_slice(tensor_name)
return slice_
def _has_tensor(self, tensor_name: str):
try:
self.get_filename(tensor_name)
except Exception:
return False
return True
def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape()
@ -176,300 +261,31 @@ class Weights:
def get_weights_col_packed_qkv(
self,
prefix: str,
quantize: str,
num_heads: int,
num_key_value_heads: int,
):
return self.get_weights_col_packed(
prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads]
prefix, [num_heads, num_key_value_heads, num_key_value_heads]
)
def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
return self.get_weights_col_packed(prefix, quantize, 2)
def get_weights_col_packed_gate_up(self, prefix: str):
return self.get_weights_col_packed(prefix, 2)
def get_weights_col_packed(
self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
):
def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):
"""
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
already alternating Q,K,V within the main tensor.
The columns are split in equally sized blocks when blocks is an `int`, or
in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
convenient for e.g. splitting QKV without knowing the storage details of
quantized weights.
"""
if quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)
try:
qweight = self.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
)
scales = self.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=self.dtype)
def get_weights_col(self, prefix: str):
return self.weights_loader.get_weights_col(self, prefix)
gptq_params = self._get_gptq_params()
if can_use_gptq_marlin(gptq_params, quantize):
g_idx = self.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
qzeros = self.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
if quantize == "gptq" and gptq_params.quant_method == "gptq":
g_idx = self.get_tensor(f"{prefix}.g_idx")
elif quantize == "gptq" and gptq_params.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
g_idx = (
torch.arange(
qweight.shape[0] * (32 // gptq_params.bits),
device=qweight.device,
)
// gptq_params.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=False,
)
elif quantize == "marlin":
from text_generation_server.layers.marlin import (
GPTQMarlin24Weight,
MarlinWeight,
repack_gptq_for_marlin,
)
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
B = self.get_packed_sharded(
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
)
B_meta = self.get_packed_sharded(
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
)
s = self.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
gptq_params = self._get_gptq_params()
weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
)
else:
B = self.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
)
s = self.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = MarlinWeight(B=B, s=s)
else:
weight = self.get_packed_sharded(
f"{prefix}.weight", dim=0, block_sizes=block_sizes
)
return weight
def get_weights_col(self, prefix: str, quantize: str):
if quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2Weight
try:
q_weight = self.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
q_scale = self.get_tensor(f"{prefix}.q_scale")
q_invperm = self.get_tensor(f"{prefix}.q_invperm")
q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
q_groups = self.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
return self.get_multi_weights_col([prefix], quantize, 0)
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
if quantize == "exl2":
raise ValueError("get_multi_weights_col is not supported for exl2")
elif quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
try:
qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
gptq_params = self._get_gptq_params()
if can_use_gptq_marlin(gptq_params, quantize):
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=False,
)
qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = (
gptq_params.bits == 4
and HAS_EXLLAMA
and quantize == "gptq"
and not gptq_params.desc_act
)
if quantize == "gptq" and gptq_params.quant_method == "gptq":
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif quantize == "gptq" and gptq_params.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // gptq_params.bits),
device=qweight.device,
)
// gptq_params.groupsize
).to(dtype=torch.int32)
else:
g_idx = None
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=use_exllama,
)
elif quantize == "marlin":
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
GPTQMarlin24Weight,
MarlinWeight,
)
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = torch.cat(
[self.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
)
B_meta = torch.cat(
[self.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
)
s = torch.cat(
[self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
gptq_params = self._get_gptq_params()
weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
)
else:
try:
B = torch.cat(
[self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
)
s = torch.cat(
[self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = MarlinWeight(B=B, s=s)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
return weight
def get_multi_weights_col(self, prefixes: List[str], dim: int):
return self.weights_loader.get_multi_weights_col(self, prefixes, dim)
def get_tensor_shard(self, var, dim):
world_size = self.process_group.size()
@ -487,324 +303,22 @@ class Weights:
tensor = tensor.to(device=self.device)
return tensor
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2Weight
def get_weights_row(self, prefix: str):
return self.weights_loader.get_weights_row(self, prefix)
try:
q_weight = self.get_tensor(f"{prefix}.q_weight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
)
@contextmanager
def use_loader(self, weights_loader: WeightsLoader):
"""
This method is a context manager that can be used to use `Weights` with
a different loader for the duration of the context.
"""
q_scale = self.get_tensor(f"{prefix}.q_scale")
q_invperm = self.get_tensor(f"{prefix}.q_invperm")
q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
q_groups = self.get_tensor(f"{prefix}.q_groups")
return Exl2Weight(
q_weight=q_weight,
q_scale=q_scale,
q_invperm=q_invperm,
q_scale_max=q_scale_max,
q_groups=q_groups,
)
elif quantize == "gptq":
from text_generation_server.layers.marlin import (
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
gptq_params = self._get_gptq_params()
if can_use_gptq_marlin(gptq_params, quantize):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
if gptq_params.desc_act or gptq_params.groupsize == -1:
scales = self.get_tensor(f"{prefix}.scales")
else:
scales = self.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = self.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
desc_act=gptq_params.desc_act,
groupsize=gptq_params.groupsize,
sym=gptq_params.sym,
sharded_infeatures=sharded_in_features,
)
use_exllama = True
if gptq_params.bits != 4:
use_exllama = False
if gptq_params.desc_act:
log_once(logger.warning, "Disabling exllama because desc_act=True")
use_exllama = False
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if gptq_params.quant_method == "gptq":
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
elif gptq_params.quant_method == "awq":
g_idx = None
if self.process_group.size() > 1:
if g_idx is not None:
if (
not torch.equal(
g_idx.cpu(),
torch.tensor(
[
i // gptq_params.groupsize
for i in range(g_idx.shape[0])
],
dtype=torch.int32,
),
)
and not (g_idx == 0).all()
):
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require to reorder input activations that are split unto several GPUs
use_exllama = False
from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA,
GPTQWeight,
)
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
if use_exllama and gptq_params.groupsize != -1:
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
else:
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if gptq_params.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(
qweight.shape[0] * (32 // gptq_params.bits),
device=qweight.device,
)
// gptq_params.groupsize
).to(dtype=torch.int32)
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=use_exllama,
)
elif quantize == "awq":
from text_generation_server.layers.gptq import GPTQWeight
gptq_params = self._get_gptq_params()
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `awq` weight, make sure the model is already quantized"
)
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
g_idx = None
use_exllama = False
weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=gptq_params.bits,
groupsize=gptq_params.groupsize,
use_exllama=use_exllama,
)
elif quantize == "marlin":
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.layers.marlin import (
GPTQMarlin24Weight,
MarlinWeight,
)
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = self.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = self.get_sharded(f"{prefix}.B_meta", dim=0)
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = self.get_tensor(f"{prefix}.s")
else:
s = self.get_sharded(f"{prefix}.s", dim=0)
gptq_params = self._get_gptq_params()
weight = GPTQMarlin24Weight(
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
)
else:
try:
B = self.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = self.get_tensor(f"{prefix}.s")
else:
s = self.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
def _get_gptq_params(self) -> GPTQParams:
old_loader = self.weights_loader
self.weights_loader = weights_loader
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
desc_act = False
sym = False
quant_method = "gptq"
except (SafetensorError, RuntimeError) as e:
try:
bits = self.gptq_bits
groupsize = self.gptq_groupsize
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
desc_act = getattr(self, "gptq_desc_act", False)
quant_method = getattr(self, "quant_method", "gptq")
sym = getattr(self, "sym", True)
except Exception:
raise e
return GPTQParams(
bits=bits,
checkpoint_format=checkpoint_format,
desc_act=desc_act,
groupsize=groupsize,
quant_method=quant_method,
sym=sym,
)
def _set_gptq_params(self, model_id, revision):
filename = "config.json"
self.quant_method = None
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
self.gptq_bits = data["quantization_config"]["bits"]
self.gptq_groupsize = data["quantization_config"]["group_size"]
# Order is important here, desc_act is missing on some real models
self.quant_method = data["quantization_config"]["quant_method"]
self.gptq_checkpoint_format = data["quantization_config"].get(
"checkpoint_format"
)
self.gptq_sym = data["quantization_config"]["sym"]
self.gptq_desc_act = data["quantization_config"]["desc_act"]
except Exception:
filename = "quantize_config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
self.gptq_bits = data["bits"]
self.gptq_groupsize = data["group_size"]
self.gptq_sym = data["sym"]
self.gptq_desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
self.quant_method = "awq"
except Exception:
filename = "quant_config.json"
try:
if os.path.exists(os.path.join(model_id, filename)):
filename = os.path.join(model_id, filename)
else:
filename = hf_hub_download(
model_id, filename=filename, revision=revision
)
with open(filename, "r") as f:
data = json.load(f)
self.gptq_bits = data["w_bit"]
self.gptq_groupsize = data["q_group_size"]
self.gptq_desc_act = data["desc_act"]
if "version" in data and data["version"] == "GEMM":
self.quant_method = "awq"
except Exception:
if self.quant_method is None:
if "awq" in model_id.lower():
self.quant_method = "awq"
elif "gptq" in model_id.lower():
self.quant_method = "gptq"
yield
finally:
self.weights_loader = old_loader
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:

View File

@ -155,7 +155,7 @@ def check_openapi(check: bool):
filename,
],
capture_output=True,
).stdout.decode()
).stdout.decode("utf-8")
os.remove(tmp_filename)
if diff:
@ -164,11 +164,27 @@ def check_openapi(check: bool):
"OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it"
)
return True
else:
os.rename(tmp_filename, filename)
print("OpenAPI documentation updated.")
return True
errors = subprocess.run(
[
"swagger-cli",
# allow for trailing whitespace since it's not significant
# and the precommit hook will remove it
"validate",
filename,
],
capture_output=True,
).stderr.decode("utf-8")
# The openapi specs fails on `exclusive_minimum` which is expected to be a boolean where
# utoipa outputs a value instead: https://github.com/juhaku/utoipa/issues/969
if not errors.startswith("Swagger schema validation failed."):
print(errors)
raise Exception(
f"OpenAPI documentation is invalid, `swagger-cli validate` showed some error:\n {errors}"
)
return True
def main():