Merge branch 'main' into ci_amd3

commit 291453fe88
@@ -30,6 +30,10 @@ jobs:
id: install-router
run: cargo install --path router/

- uses: actions/setup-node@v4
with:
node-version: 22

- name: Set up Python
uses: actions/setup-python@v2
with:
@@ -37,4 +41,5 @@ jobs:

- name: Check that documentation is up-to-date
run: |
npm install -g swagger-cli
python update_doc.py --check
@@ -30,7 +30,7 @@ jobs:
group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
# TODO see with @Glegendre to get CPU runner here instead
runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci]
runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci]
permissions:
contents: write
packages: write
@@ -142,8 +142,8 @@ jobs:
DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }}
tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }}
labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }}
cache-from: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
cache-to: type=registry,ref=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:cache${{ env.LABEL }},mode=min
cache-from: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
cache-to: type=s3,region=us-east-1,bucket=ci-docker-buildx-cache,name=text-generation-inference-cache${{ env.LABEL }},mode=min,access_key_id=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_ACCESS_KEY_ID }},secret_access_key=${{ secrets.S3_CI_DOCKER_BUILDX_CACHE_SECRET_ACCESS_KEY }},mode=min
- name: Final
id: final
run: |
@@ -1935,17 +1935,6 @@ version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"

[[package]]
name = "metrics"
version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fde3af1a009ed76a778cb84fdef9e7dbbdf5775ae3e4cc1f434a6a307f6f76c5"
dependencies = [
"ahash",
"metrics-macros",
"portable-atomic",
]

[[package]]
name = "metrics"
version = "0.23.0"
@@ -1969,7 +1958,7 @@ dependencies = [
"hyper-util",
"indexmap 2.2.6",
"ipnet",
"metrics 0.23.0",
"metrics",
"metrics-util",
"quanta",
"thiserror",
@@ -1977,17 +1966,6 @@ dependencies = [
"tracing",
]

[[package]]
name = "metrics-macros"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38b4faf00617defe497754acde3024865bc143d44a86799b24e191ecff91354f"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.68",
]

[[package]]
name = "metrics-util"
version = "0.17.0"
@@ -1997,7 +1975,7 @@ dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
"hashbrown 0.14.5",
"metrics 0.23.0",
"metrics",
"num_cpus",
"quanta",
"sketches-ddsketch",
@@ -3762,7 +3740,7 @@ dependencies = [

[[package]]
name = "text-generation-benchmark"
version = "2.1.1-dev0"
version = "2.1.2-dev0"
dependencies = [
"average",
"clap",
@@ -3783,7 +3761,7 @@ dependencies = [

[[package]]
name = "text-generation-client"
version = "2.1.1-dev0"
version = "2.1.2-dev0"
dependencies = [
"async-trait",
"base64 0.22.1",
@@ -3801,7 +3779,7 @@ dependencies = [

[[package]]
name = "text-generation-launcher"
version = "2.1.1-dev0"
version = "2.1.2-dev0"
dependencies = [
"clap",
"ctrlc",
@@ -3820,7 +3798,7 @@ dependencies = [

[[package]]
name = "text-generation-router"
version = "2.1.1-dev0"
version = "2.1.2-dev0"
dependencies = [
"async-stream",
"axum 0.7.5",
@@ -3834,7 +3812,7 @@ dependencies = [
"init-tracing-opentelemetry",
"itertools 0.10.5",
"jsonschema",
"metrics 0.21.1",
"metrics",
"metrics-exporter-prometheus",
"minijinja",
"minijinja-contrib",
@@ -40,7 +40,9 @@ RUN cargo build --profile release-opt
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install

# NOTE: When updating PyTorch version, beware to remove `pip install nvidia-nccl-cu12==2.22.3` below in the Dockerfile. Context: https://github.com/huggingface/text-generation-inference/pull/2099
ARG PYTORCH_VERSION=2.3.0

ARG PYTHON_VERSION=3.10
# Keep in sync with `server/pyproject.toml
ARG CUDA_VERSION=12.1
@@ -241,7 +243,10 @@ COPY server/Makefile server/Makefile
RUN cd server && \
make gen-server && \
pip install -r requirements_cuda.txt && \
pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir
pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir && \
pip install nvidia-nccl-cu12==2.22.3

ENV LD_PRELOAD=/opt/conda/lib/python3.10/site-packages/nvidia/nccl/lib/libnccl.so.2

# Deps before the binaries
# The binaries change on every build given we burn the SHA into them
README.md
@@ -20,19 +20,20 @@ to power Hugging Chat, the Inference API and Inference Endpoint.

## Table of contents

- [Get Started](#get-started)
- [API Documentation](#api-documentation)
- [Using a private or gated model](#using-a-private-or-gated-model)
- [A note on Shared Memory](#a-note-on-shared-memory-shm)
- [Distributed Tracing](#distributed-tracing)
- [Local Install](#local-install)
- [CUDA Kernels](#cuda-kernels)
- [Optimized architectures](#optimized-architectures)
- [Run Mistral](#run-a-model)
- [Run](#run)
- [Quantization](#quantization)
- [Develop](#develop)
- [Testing](#testing)
- [Get Started](#get-started)
- [Docker](#docker)
- [API documentation](#api-documentation)
- [Using a private or gated model](#using-a-private-or-gated-model)
- [A note on Shared Memory (shm)](#a-note-on-shared-memory-shm)
- [Distributed Tracing](#distributed-tracing)
- [Architecture](#architecture)
- [Local install](#local-install)
- [Optimized architectures](#optimized-architectures)
- [Run locally](#run-locally)
- [Run](#run)
- [Quantization](#quantization)
- [Develop](#develop)
- [Testing](#testing)

Text Generation Inference (TGI) is a toolkit for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation for the most popular open-source LLMs, including Llama, Falcon, StarCoder, BLOOM, GPT-NeoX, and [more](https://huggingface.co/docs/text-generation-inference/supported_models). TGI implements many features, such as:

@@ -61,7 +61,7 @@ class ChoiceDeltaToolCall(BaseModel):
class ChoiceDelta(BaseModel):
role: str
content: Optional[str] = None
tool_calls: Optional[ChoiceDeltaToolCall]
tool_calls: Optional[ChoiceDeltaToolCall] = None


class Choice(BaseModel):
@@ -492,12 +492,12 @@
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Completion"
"$ref": "#/components/schemas/CompletionFinal"
}
},
"text/event-stream": {
"schema": {
"$ref": "#/components/schemas/CompletionCompleteChunk"
"$ref": "#/components/schemas/Chunk"
}
}
}
@ -809,7 +809,6 @@
|
|||
"ChatRequest": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"model",
|
||||
"messages"
|
||||
],
|
||||
"properties": {
|
||||
|
@ -854,7 +853,8 @@
|
|||
"model": {
|
||||
"type": "string",
|
||||
"description": "[UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
|
||||
"example": "mistralai/Mistral-7B-Instruct-v0.2"
|
||||
"example": "mistralai/Mistral-7B-Instruct-v0.2",
|
||||
"nullable": true
|
||||
},
|
||||
"n": {
|
||||
"type": "integer",
|
||||
|
@ -1116,7 +1116,6 @@
|
|||
"CompletionRequest": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"model",
|
||||
"prompt"
|
||||
],
|
||||
"properties": {
|
||||
|
@ -1138,7 +1137,8 @@
|
|||
"model": {
|
||||
"type": "string",
|
||||
"description": "UNUSED\nID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.",
|
||||
"example": "mistralai/Mistral-7B-Instruct-v0.2"
|
||||
"example": "mistralai/Mistral-7B-Instruct-v0.2",
|
||||
"nullable": true
|
||||
},
|
||||
"prompt": {
|
||||
"$ref": "#/components/schemas/Prompt"
|
||||
|
@ -1324,6 +1324,17 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"FunctionName": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"name"
|
||||
],
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"GenerateParameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -1708,6 +1719,72 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"MessageChunk": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"text",
|
||||
"type"
|
||||
],
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"text"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"image_url",
|
||||
"type"
|
||||
],
|
||||
"properties": {
|
||||
"image_url": {
|
||||
"$ref": "#/components/schemas/Url"
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"image_url"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type"
|
||||
}
|
||||
},
|
||||
"MessageContent": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/MessageChunk"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"OutputMessage": {
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/TextMessage"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ToolCallMessage"
|
||||
}
|
||||
]
|
||||
},
|
||||
"PrefillToken": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
|
@ -1834,6 +1911,23 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"TextMessage": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"role",
|
||||
"content"
|
||||
],
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"example": "My name is David and I"
|
||||
},
|
||||
"role": {
|
||||
"type": "string",
|
||||
"example": "user"
|
||||
}
|
||||
}
|
||||
},
|
||||
"Token": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
|
@ -1906,6 +2000,41 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"ToolCallDelta": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"role",
|
||||
"tool_calls"
|
||||
],
|
||||
"properties": {
|
||||
"role": {
|
||||
"type": "string",
|
||||
"example": "assistant"
|
||||
},
|
||||
"tool_calls": {
|
||||
"$ref": "#/components/schemas/DeltaToolCall"
|
||||
}
|
||||
}
|
||||
},
|
||||
"ToolCallMessage": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"role",
|
||||
"tool_calls"
|
||||
],
|
||||
"properties": {
|
||||
"role": {
|
||||
"type": "string",
|
||||
"example": "assistant"
|
||||
},
|
||||
"tool_calls": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/ToolCall"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"ToolType": {
|
||||
"oneOf": [
|
||||
{
|
||||
|
@ -1929,6 +2058,17 @@
|
|||
}
|
||||
]
|
||||
},
|
||||
"Url": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"url"
|
||||
],
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"Usage": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
|
|
|
@@ -11,6 +11,8 @@
title: Using TGI with Intel Gaudi
- local: installation_inferentia
title: Using TGI with AWS Inferentia
- local: installation_intel
title: Using TGI with Intel GPUs
- local: installation
title: Installation from source
- local: supported_models
@@ -103,6 +103,7 @@ Several variants of the model server exist that are actively supported by Huggin

- By default, the model server will attempt building [a server optimized for Nvidia GPUs with CUDA](https://huggingface.co/docs/text-generation-inference/installation_nvidia). The code for this version is hosted in the [main TGI repository](https://github.com/huggingface/text-generation-inference).
- A [version optimized for AMD with ROCm](https://huggingface.co/docs/text-generation-inference/installation_amd) is hosted in the main TGI repository. Some model features differ.
- A [version optimized for Intel GPUs](https://huggingface.co/docs/text-generation-inference/installation_intel) is hosted in the main TGI repository. Some model features differ.
- The [version for Intel Gaudi](https://huggingface.co/docs/text-generation-inference/installation_gaudi) is maintained on a forked repository, often resynchronized with the main [TGI repository](https://github.com/huggingface/tgi-gaudi).
- A [version for Neuron (AWS Inferentia2)](https://huggingface.co/docs/text-generation-inference/installation_inferentia) is maintained as part of [Optimum Neuron](https://github.com/huggingface/optimum-neuron/tree/main/text-generation-inference).
- A version for Google TPUs is maintained as part of [Optimum TPU](https://github.com/huggingface/optimum-tpu/tree/main/text-generation-inference).
@@ -0,0 +1,19 @@
# Using TGI with Intel GPUs

TGI-optimized models are supported on Intel Data Center GPU [Max1100](https://www.intel.com/content/www/us/en/products/sku/232876/intel-data-center-gpu-max-1100/specifications.html) and [Max1550](https://www.intel.com/content/www/us/en/products/sku/232873/intel-data-center-gpu-max-1550/specifications.html); the recommended usage is through Docker.


On a server powered by Intel GPUs, TGI can be launched with the following command:

```bash
model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --rm --privileged --cap-add=sys_nice \
--device=/dev/dri \
--ipc=host --shm-size 1g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:latest-intel \
--model-id $model --cuda-graphs 0
```

The launched TGI server can then be queried from clients; make sure to check out the [Consuming TGI](./basic_tutorials/consuming_tgi) guide.
@@ -17,7 +17,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \

### Supported hardware

TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.
TGI supports various hardware. Make sure to check the [Using TGI with Nvidia GPUs](./installation_nvidia), [Using TGI with AMD GPUs](./installation_amd), [Using TGI with Intel GPUs](./installation_intel), [Using TGI with Gaudi](./installation_gaudi), [Using TGI with Inferentia](./installation_inferentia) guides depending on which hardware you would like to deploy TGI on.

## Consuming TGI

@@ -347,6 +347,8 @@ def launcher(event_loop):
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None,
cuda_graphs: Optional[List[int]] = None,
):
port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000)
@@ -393,6 +395,14 @@ def launcher(event_loop):
if max_total_tokens:
args.append("--max-total-tokens")
args.append(str(max_total_tokens))
if lora_adapters:
args.append("--lora-adapters")
args.append(",".join(lora_adapters))
if cuda_graphs:
args.append("--cuda-graphs")
args.append(",".join(map(str, cuda_graphs)))

print(" ".join(args), file=sys.stderr)

env["LOG_LEVEL"] = "info,text_generation_router=debug"

@@ -433,6 +443,8 @@ def launcher(event_loop):
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
max_total_tokens: Optional[int] = None,
lora_adapters: Optional[List[str]] = None,
cuda_graphs: Optional[List[int]] = None,
):
port = random.randint(8000, 10_000)

@@ -462,6 +474,12 @@ def launcher(event_loop):
if max_total_tokens:
args.append("--max-total-tokens")
args.append(str(max_total_tokens))
if lora_adapters:
args.append("--lora-adapters")
args.append(",".join(lora_adapters))
if cuda_graphs:
args.append("--cuda-graphs")
args.append(",".join(map(str, cuda_graphs)))

client = docker.from_env()

@ -0,0 +1,251 @@
|
|||
{
|
||||
"details": {
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 40,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.27416992,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.17016602,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 28737,
|
||||
"logprob": -2.7109375,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 28809,
|
||||
"logprob": -1.5,
|
||||
"special": false,
|
||||
"text": "’"
|
||||
},
|
||||
{
|
||||
"id": 28719,
|
||||
"logprob": -0.34204102,
|
||||
"special": false,
|
||||
"text": "m"
|
||||
},
|
||||
{
|
||||
"id": 459,
|
||||
"logprob": -1.6914062,
|
||||
"special": false,
|
||||
"text": " not"
|
||||
},
|
||||
{
|
||||
"id": 1864,
|
||||
"logprob": -0.69140625,
|
||||
"special": false,
|
||||
"text": " sure"
|
||||
},
|
||||
{
|
||||
"id": 513,
|
||||
"logprob": -1.6171875,
|
||||
"special": false,
|
||||
"text": " if"
|
||||
},
|
||||
{
|
||||
"id": 315,
|
||||
"logprob": -1.3837891,
|
||||
"special": false,
|
||||
"text": " I"
|
||||
},
|
||||
{
|
||||
"id": 541,
|
||||
"logprob": -1.2226562,
|
||||
"special": false,
|
||||
"text": " can"
|
||||
},
|
||||
{
|
||||
"id": 1567,
|
||||
"logprob": -1.8652344,
|
||||
"special": false,
|
||||
"text": " come"
|
||||
},
|
||||
{
|
||||
"id": 582,
|
||||
"logprob": -0.0070228577,
|
||||
"special": false,
|
||||
"text": " up"
|
||||
},
|
||||
{
|
||||
"id": 395,
|
||||
"logprob": -0.0054092407,
|
||||
"special": false,
|
||||
"text": " with"
|
||||
},
|
||||
{
|
||||
"id": 28705,
|
||||
"logprob": -0.62597656,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 28770,
|
||||
"logprob": -0.0035572052,
|
||||
"special": false,
|
||||
"text": "3"
|
||||
},
|
||||
{
|
||||
"id": 4842,
|
||||
"logprob": -0.93603516,
|
||||
"special": false,
|
||||
"text": " unique"
|
||||
},
|
||||
{
|
||||
"id": 3085,
|
||||
"logprob": -0.028411865,
|
||||
"special": false,
|
||||
"text": " words"
|
||||
},
|
||||
{
|
||||
"id": 369,
|
||||
"logprob": -1.0400391,
|
||||
"special": false,
|
||||
"text": " that"
|
||||
},
|
||||
{
|
||||
"id": 6685,
|
||||
"logprob": -0.09710693,
|
||||
"special": false,
|
||||
"text": " describe"
|
||||
},
|
||||
{
|
||||
"id": 528,
|
||||
"logprob": -0.066467285,
|
||||
"special": false,
|
||||
"text": " me"
|
||||
},
|
||||
{
|
||||
"id": 28725,
|
||||
"logprob": -1.0722656,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 562,
|
||||
"logprob": -0.33422852,
|
||||
"special": false,
|
||||
"text": " but"
|
||||
},
|
||||
{
|
||||
"id": 315,
|
||||
"logprob": -0.5136719,
|
||||
"special": false,
|
||||
"text": " I"
|
||||
},
|
||||
{
|
||||
"id": 28809,
|
||||
"logprob": -0.8989258,
|
||||
"special": false,
|
||||
"text": "’"
|
||||
},
|
||||
{
|
||||
"id": 584,
|
||||
"logprob": -0.2076416,
|
||||
"special": false,
|
||||
"text": "ll"
|
||||
},
|
||||
{
|
||||
"id": 1464,
|
||||
"logprob": -0.8808594,
|
||||
"special": false,
|
||||
"text": " try"
|
||||
},
|
||||
{
|
||||
"id": 28723,
|
||||
"logprob": -0.88427734,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.91064453,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.08105469,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 28740,
|
||||
"logprob": -1.8486328,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 28723,
|
||||
"logprob": -0.111572266,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 23626,
|
||||
"logprob": -3.15625,
|
||||
"special": false,
|
||||
"text": " Creative"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.9194336,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 28750,
|
||||
"logprob": -0.24841309,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
{
|
||||
"id": 28723,
|
||||
"logprob": -9.393692e-05,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 6785,
|
||||
"logprob": -3.1386719,
|
||||
"special": false,
|
||||
"text": " Fun"
|
||||
},
|
||||
{
|
||||
"id": 1780,
|
||||
"logprob": -0.53564453,
|
||||
"special": false,
|
||||
"text": "ny"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.09033203,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 28770,
|
||||
"logprob": -0.00466156,
|
||||
"special": false,
|
||||
"text": "3"
|
||||
},
|
||||
{
|
||||
"id": 28723,
|
||||
"logprob": -0.00016450882,
|
||||
"special": false,
|
||||
"text": "."
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "\n\nI’m not sure if I can come up with 3 unique words that describe me, but I’ll try.\n\n1. Creative\n2. Funny\n3."
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
{
|
||||
"details": {
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 7,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": -0.49658203,
|
||||
"special": true,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 28705,
|
||||
"logprob": -0.0016384125,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": -1.4931641,
|
||||
"special": true,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 28705,
|
||||
"logprob": -0.00075769424,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 28740,
|
||||
"logprob": -0.25024414,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 28740,
|
||||
"logprob": -0.2631836,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": -0.0003285408,
|
||||
"special": true,
|
||||
"text": "</s>"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": " 11"
|
||||
}
|
|
@ -0,0 +1,251 @@
|
|||
{
|
||||
"details": {
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 40,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.0488281,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.0800781,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 27332,
|
||||
"logprob": -2.1152344,
|
||||
"special": false,
|
||||
"text": "###"
|
||||
},
|
||||
{
|
||||
"id": 28705,
|
||||
"logprob": -1.6748047,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 28740,
|
||||
"logprob": -0.097229004,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 28723,
|
||||
"logprob": -0.16467285,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 7615,
|
||||
"logprob": -2.2246094,
|
||||
"special": false,
|
||||
"text": " News"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.0488281,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 27332,
|
||||
"logprob": -0.69189453,
|
||||
"special": false,
|
||||
"text": "###"
|
||||
},
|
||||
{
|
||||
"id": 28705,
|
||||
"logprob": -0.013343811,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 28750,
|
||||
"logprob": -0.011230469,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
{
|
||||
"id": 28723,
|
||||
"logprob": -0.00096845627,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 21095,
|
||||
"logprob": -2.5605469,
|
||||
"special": false,
|
||||
"text": " Blog"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.19458008,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 27332,
|
||||
"logprob": -0.031280518,
|
||||
"special": false,
|
||||
"text": "###"
|
||||
},
|
||||
{
|
||||
"id": 28705,
|
||||
"logprob": -0.0030708313,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 28770,
|
||||
"logprob": -0.0029277802,
|
||||
"special": false,
|
||||
"text": "3"
|
||||
},
|
||||
{
|
||||
"id": 28723,
|
||||
"logprob": -0.0012350082,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 20108,
|
||||
"logprob": -2.1582031,
|
||||
"special": false,
|
||||
"text": " Article"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.05810547,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 27332,
|
||||
"logprob": -0.35083008,
|
||||
"special": false,
|
||||
"text": "###"
|
||||
},
|
||||
{
|
||||
"id": 28705,
|
||||
"logprob": -0.034332275,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 28781,
|
||||
"logprob": -0.009666443,
|
||||
"special": false,
|
||||
"text": "4"
|
||||
},
|
||||
{
|
||||
"id": 28723,
|
||||
"logprob": -0.0013113022,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 8349,
|
||||
"logprob": -2.6191406,
|
||||
"special": false,
|
||||
"text": " Review"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.04031372,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 27332,
|
||||
"logprob": -0.45239258,
|
||||
"special": false,
|
||||
"text": "###"
|
||||
},
|
||||
{
|
||||
"id": 28705,
|
||||
"logprob": -0.045410156,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 28782,
|
||||
"logprob": -0.0041236877,
|
||||
"special": false,
|
||||
"text": "5"
|
||||
},
|
||||
{
|
||||
"id": 28723,
|
||||
"logprob": -0.0010223389,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 5299,
|
||||
"logprob": -2.8066406,
|
||||
"special": false,
|
||||
"text": " Other"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.12054443,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.44580078,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.4921875,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.3574219,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.0039062,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.5859375,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.43481445,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.2783203,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.20410156,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "\n\n### 1. News\n### 2. Blog\n### 3. Article\n### 4. Review\n### 5. Other\n\n\n\n\n\n\n\n\n"
|
||||
}
|
|
@ -0,0 +1,251 @@
|
|||
{
|
||||
"details": {
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 40,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.31347656,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.27441406,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 28737,
|
||||
"logprob": -2.2285156,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 28809,
|
||||
"logprob": -1.4677734,
|
||||
"special": false,
|
||||
"text": "’"
|
||||
},
|
||||
{
|
||||
"id": 28719,
|
||||
"logprob": -0.31762695,
|
||||
"special": false,
|
||||
"text": "m"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -1.6865234,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1215,
|
||||
"logprob": -3.2695312,
|
||||
"special": false,
|
||||
"text": " very"
|
||||
},
|
||||
{
|
||||
"id": 20640,
|
||||
"logprob": -3.1230469,
|
||||
"special": false,
|
||||
"text": " passionate"
|
||||
},
|
||||
{
|
||||
"id": 1338,
|
||||
"logprob": -0.48339844,
|
||||
"special": false,
|
||||
"text": " person"
|
||||
},
|
||||
{
|
||||
"id": 28723,
|
||||
"logprob": -0.9970703,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 315,
|
||||
"logprob": -0.5498047,
|
||||
"special": false,
|
||||
"text": " I"
|
||||
},
|
||||
{
|
||||
"id": 28809,
|
||||
"logprob": -1.1923828,
|
||||
"special": false,
|
||||
"text": "’"
|
||||
},
|
||||
{
|
||||
"id": 28719,
|
||||
"logprob": -0.080444336,
|
||||
"special": false,
|
||||
"text": "m"
|
||||
},
|
||||
{
|
||||
"id": 1215,
|
||||
"logprob": -1.8271484,
|
||||
"special": false,
|
||||
"text": " very"
|
||||
},
|
||||
{
|
||||
"id": 12215,
|
||||
"logprob": -2.8847656,
|
||||
"special": false,
|
||||
"text": " driven"
|
||||
},
|
||||
{
|
||||
"id": 28723,
|
||||
"logprob": -1.0927734,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 315,
|
||||
"logprob": -0.4584961,
|
||||
"special": false,
|
||||
"text": " I"
|
||||
},
|
||||
{
|
||||
"id": 28809,
|
||||
"logprob": -0.5019531,
|
||||
"special": false,
|
||||
"text": "’"
|
||||
},
|
||||
{
|
||||
"id": 28719,
|
||||
"logprob": -0.030715942,
|
||||
"special": false,
|
||||
"text": "m"
|
||||
},
|
||||
{
|
||||
"id": 1215,
|
||||
"logprob": -0.96972656,
|
||||
"special": false,
|
||||
"text": " very"
|
||||
},
|
||||
{
|
||||
"id": 7798,
|
||||
"logprob": -2.8847656,
|
||||
"special": false,
|
||||
"text": " determined"
|
||||
},
|
||||
{
|
||||
"id": 28723,
|
||||
"logprob": -0.27319336,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.56396484,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.011016846,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3195,
|
||||
"logprob": -0.7163086,
|
||||
"special": false,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 349,
|
||||
"logprob": -1.1611328,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 574,
|
||||
"logprob": -0.515625,
|
||||
"special": false,
|
||||
"text": " your"
|
||||
},
|
||||
{
|
||||
"id": 6656,
|
||||
"logprob": -1.0253906,
|
||||
"special": false,
|
||||
"text": " favorite"
|
||||
},
|
||||
{
|
||||
"id": 1970,
|
||||
"logprob": -2.1738281,
|
||||
"special": false,
|
||||
"text": " thing"
|
||||
},
|
||||
{
|
||||
"id": 684,
|
||||
"logprob": -0.48364258,
|
||||
"special": false,
|
||||
"text": " about"
|
||||
},
|
||||
{
|
||||
"id": 1250,
|
||||
"logprob": -1.8876953,
|
||||
"special": false,
|
||||
"text": " being"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.41967773,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 8626,
|
||||
"logprob": -2.9160156,
|
||||
"special": false,
|
||||
"text": " teacher"
|
||||
},
|
||||
{
|
||||
"id": 28804,
|
||||
"logprob": -0.11920166,
|
||||
"special": false,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.023727417,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.010848999,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 28737,
|
||||
"logprob": -1.0566406,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 2016,
|
||||
"logprob": -0.7163086,
|
||||
"special": false,
|
||||
"text": " love"
|
||||
},
|
||||
{
|
||||
"id": 272,
|
||||
"logprob": -1.9169922,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 1639,
|
||||
"logprob": -2.03125,
|
||||
"special": false,
|
||||
"text": " fact"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "\n\nI’m a very passionate person. I’m very driven. I’m very determined.\n\nWhat is your favorite thing about being a teacher?\n\nI love the fact"
|
||||
}
|
|
@ -0,0 +1,134 @@
|
|||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def lora_mistral_handle(launcher):
|
||||
with launcher(
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
lora_adapters=[
|
||||
"predibase/dbpedia",
|
||||
"predibase/customer_support",
|
||||
],
|
||||
cuda_graphs=[0],
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def lora_mistral(lora_mistral_handle):
|
||||
await lora_mistral_handle.health(300)
|
||||
return lora_mistral_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_lora_mistral(lora_mistral, response_snapshot):
|
||||
response = await lora_mistral.generate(
|
||||
"Test request", max_new_tokens=10, decoder_input_details=True
|
||||
)
|
||||
assert response.details.generated_tokens == 10
|
||||
|
||||
|
||||
classification_prompt = """You are given the title and the body of an article below. Please determine the type of the article.\n### Title: Great White Whale\n\n### Body: Great White Whale is the debut album by the Canadian rock band Secret and Whisper. The album was in the works for about a year and was released on February 12 2008. A music video was shot in Pittsburgh for the album's first single XOXOXO. The album reached number 17 on iTunes's top 100 albums in its first week on sale.\n\n### Article Type:"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_lora_mistral_without_adapter(lora_mistral, response_snapshot):
|
||||
response = requests.post(
|
||||
f"{lora_mistral.base_url}/generate",
|
||||
headers=lora_mistral.headers,
|
||||
json={
|
||||
"inputs": classification_prompt,
|
||||
"parameters": {
|
||||
"max_new_tokens": 40,
|
||||
"details": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert (
|
||||
data["generated_text"]
|
||||
== "\n\n### 1. News\n### 2. Blog\n### 3. Article\n### 4. Review\n### 5. Other\n\n\n\n\n\n\n\n\n"
|
||||
)
|
||||
assert data == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_lora_mistral_with_dbpedia_adapter(lora_mistral, response_snapshot):
|
||||
response = requests.post(
|
||||
f"{lora_mistral.base_url}/generate",
|
||||
headers=lora_mistral.headers,
|
||||
json={
|
||||
"inputs": classification_prompt,
|
||||
"parameters": {
|
||||
"max_new_tokens": 40,
|
||||
"adapter_id": "predibase/dbpedia",
|
||||
"details": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["generated_text"] == " 11"
|
||||
assert data == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_lora_mistral_with_customer_support_adapter(
|
||||
lora_mistral, response_snapshot
|
||||
):
|
||||
print(lora_mistral.base_url)
|
||||
print(lora_mistral.headers)
|
||||
response = requests.post(
|
||||
f"{lora_mistral.base_url}/generate",
|
||||
headers=lora_mistral.headers,
|
||||
json={
|
||||
"inputs": "What are 3 unique words that describe you?",
|
||||
"parameters": {
|
||||
"max_new_tokens": 40,
|
||||
"adapter_id": "predibase/customer_support",
|
||||
"details": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert (
|
||||
data["generated_text"]
|
||||
== "\n\nI’m not sure if I can come up with 3 unique words that describe me, but I’ll try.\n\n1. Creative\n2. Funny\n3."
|
||||
)
|
||||
assert data == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_lora_mistral_without_customer_support_adapter(
|
||||
lora_mistral, response_snapshot
|
||||
):
|
||||
response = requests.post(
|
||||
f"{lora_mistral.base_url}/generate",
|
||||
headers=lora_mistral.headers,
|
||||
json={
|
||||
"inputs": "What are 3 unique words that describe you?",
|
||||
"parameters": {
|
||||
"max_new_tokens": 40,
|
||||
"details": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert (
|
||||
data["generated_text"]
|
||||
== "\n\nI’m a very passionate person. I’m very driven. I’m very determined.\n\nWhat is your favorite thing about being a teacher?\n\nI love the fact"
|
||||
)
|
||||
assert data == response_snapshot
|
|
@@ -24,7 +24,7 @@ futures = "0.3.28"
hf-hub = { workspace = true }
itertools = "0.10"
jsonschema = { version = "0.17.1", features = ["draft202012"] }
metrics = "0.21.1"
metrics = "0.23.0"
metrics-exporter-prometheus = { version = "0.15.1", features = [] }
nohash-hasher = "0.2.0"
opentelemetry = { version = "0.20.0", features = ["rt-tokio"] }
@@ -91,14 +91,14 @@ impl Infer {
.limit_concurrent_requests
.try_acquire_owned()
.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
metrics::counter!("tgi_request_failure", "err" => "overloaded").increment(1);
tracing::error!("{err}");
err
})?;

// Validate request
let valid_request = self.validation.validate(request).await.map_err(|err| {
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
tracing::error!("{err}");
err
})?;
@@ -140,7 +140,7 @@ impl Infer {
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply(messages, grammar_with_prompt)
.map_err(|e| {
metrics::increment_counter!("tgi_request_failure", "err" => "template");
metrics::counter!("tgi_request_failure", "err" => "template").increment(1);
tracing::error!("{e}");
e
})
@@ -214,7 +214,7 @@ impl Infer {
})
} else {
let err = InferError::IncompleteGeneration;
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
tracing::error!("{err}");
Err(err)
}
@ -111,7 +111,7 @@ async fn queue_task(
|
|||
match cmd {
|
||||
QueueCommand::Append(entry, span) => {
|
||||
span.in_scope(|| state.append(*entry));
|
||||
metrics::increment_gauge!("tgi_queue_size", 1.0);
|
||||
metrics::gauge!("tgi_queue_size").increment(1.0);
|
||||
}
|
||||
QueueCommand::NextBatch {
|
||||
min_size,
|
||||
|
@ -124,7 +124,7 @@ async fn queue_task(
|
|||
let next_batch =
|
||||
state.next_batch(min_size, max_size, prefill_token_budget, token_budget);
|
||||
response_sender.send(next_batch).unwrap();
|
||||
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
|
||||
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
@ -226,7 +226,7 @@ impl State {
|
|||
// Filter entries where the response receiver was dropped (== entries where the request
|
||||
// was dropped by the client)
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
tracing::debug!("Dropping entry");
|
||||
continue;
|
||||
}
|
||||
|
@ -336,7 +336,7 @@ impl State {
|
|||
// Increment batch id
|
||||
self.next_batch_id += 1;
|
||||
|
||||
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
|
||||
metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
|
||||
|
||||
Some((batch_entries, batch, next_batch_span))
|
||||
}
|
||||
|
|
|
@ -148,8 +148,8 @@ pub(crate) async fn batching_task(
|
|||
let batch_size = batch.size;
|
||||
let batch_max_tokens = batch.max_tokens;
|
||||
let mut batches = vec![batch];
|
||||
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
|
||||
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
||||
|
||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||
|
@ -170,9 +170,11 @@ pub(crate) async fn batching_task(
|
|||
{
|
||||
// Tracking metrics
|
||||
if min_size.is_some() {
|
||||
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
||||
.increment(1);
|
||||
} else {
|
||||
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
|
@ -219,8 +221,8 @@ pub(crate) async fn batching_task(
|
|||
.await;
|
||||
waiting_tokens += 1;
|
||||
}
|
||||
metrics::gauge!("tgi_batch_current_size", 0.0);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
|
||||
metrics::gauge!("tgi_batch_current_size").set(0.0);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -234,7 +236,7 @@ async fn prefill(
|
|||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
||||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
|
@ -248,11 +250,15 @@ async fn prefill(
|
|||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
|
||||
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
|
||||
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
|
||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
|
||||
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration","method" => "prefill")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
|
@ -261,7 +267,7 @@ async fn prefill(
|
|||
generation_health.store(false, Ordering::SeqCst);
|
||||
let _ = client.clear_cache(Some(batch_id)).await;
|
||||
send_errors(err, entries);
|
||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
|
@ -276,7 +282,7 @@ async fn decode(
|
|||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
|
||||
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
|
@ -291,13 +297,18 @@ async fn decode(
|
|||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
if let Some(concat_duration) = timings.concat {
|
||||
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
|
||||
.record(concat_duration.as_secs_f64());
|
||||
}
|
||||
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
|
||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
|
@ -307,7 +318,7 @@ async fn decode(
|
|||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
send_errors(err, entries);
|
||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
|
@ -365,7 +376,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
|
|||
// request and we need to stop generating hence why we unwrap_or(true)
|
||||
let stopped = send_responses(generation, entry).map_err(|err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
err
|
||||
}).unwrap_or(true);
|
||||
if stopped {
|
||||
|
@ -381,7 +392,7 @@ fn send_responses(
|
|||
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||
// Return directly if the channel is disconnected
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
|
@ -407,7 +418,7 @@ fn send_responses(
|
|||
// Create last Token
|
||||
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
||||
let n = tokens_.ids.len();
|
||||
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
|
||||
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
|
||||
let mut iterator = tokens_
|
||||
.ids
|
||||
.into_iter()
|
||||
|
@ -472,7 +483,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
|||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = InferError::GenerationError(error.to_string());
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
|
||||
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
|
|
|
@ -126,7 +126,7 @@ async fn queue_task(
|
|||
match cmd {
|
||||
QueueCommand::Append(entry, span) => {
|
||||
span.in_scope(|| state.append(*entry));
|
||||
metrics::increment_gauge!("tgi_queue_size", 1.0);
|
||||
metrics::gauge!("tgi_queue_size").increment(1.0);
|
||||
}
|
||||
QueueCommand::NextBatch {
|
||||
min_size,
|
||||
|
@ -141,7 +141,7 @@ async fn queue_task(
|
|||
.instrument(span)
|
||||
.await;
|
||||
response_sender.send(next_batch).unwrap();
|
||||
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
|
||||
metrics::gauge!("tgi_queue_size").set(state.entries.len() as f64);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -248,7 +248,7 @@ impl State {
|
|||
// Filter entries where the response receiver was dropped (== entries where the request
|
||||
// was dropped by the client)
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
tracing::debug!("Dropping entry");
|
||||
continue;
|
||||
}
|
||||
|
@ -399,7 +399,7 @@ impl State {
|
|||
// Increment batch id
|
||||
self.next_batch_id += 1;
|
||||
|
||||
metrics::histogram!("tgi_batch_next_size", batch.size as f64);
|
||||
metrics::histogram!("tgi_batch_next_size").record(batch.size as f64);
|
||||
|
||||
Some((batch_entries, batch, next_batch_span))
|
||||
}
|
||||
|
|
|
@ -154,8 +154,8 @@ pub(crate) async fn batching_task(
|
|||
let batch_size = batch.size;
|
||||
let batch_max_tokens = batch.max_tokens;
|
||||
let mut batches = vec![batch];
|
||||
metrics::gauge!("tgi_batch_current_size", batch_size as f64);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens", batch_max_tokens as f64);
|
||||
metrics::gauge!("tgi_batch_current_size").set(batch_size as f64);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(batch_max_tokens as f64);
|
||||
|
||||
let min_size = if waiting_tokens >= max_waiting_tokens {
|
||||
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
|
||||
|
@ -176,9 +176,11 @@ pub(crate) async fn batching_task(
|
|||
{
|
||||
// Tracking metrics
|
||||
if min_size.is_some() {
|
||||
metrics::increment_counter!("tgi_batch_concat", "reason" => "backpressure");
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "backpressure")
|
||||
.increment(1);
|
||||
} else {
|
||||
metrics::increment_counter!("tgi_batch_concat", "reason" => "wait_exceeded");
|
||||
metrics::counter!("tgi_batch_concat", "reason" => "wait_exceeded")
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
entries.iter_mut().for_each(|(_, entry)| {
|
||||
|
@ -225,8 +227,8 @@ pub(crate) async fn batching_task(
|
|||
.await;
|
||||
waiting_tokens += 1;
|
||||
}
|
||||
metrics::gauge!("tgi_batch_current_size", 0.0);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens", 0.0);
|
||||
metrics::gauge!("tgi_batch_current_size").set(0.0);
|
||||
metrics::gauge!("tgi_batch_current_max_tokens").set(0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -240,7 +242,7 @@ async fn prefill(
|
|||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_id = batch.id;
|
||||
metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "prefill").increment(1);
|
||||
|
||||
match client.prefill(batch).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
|
@ -254,11 +256,15 @@ async fn prefill(
|
|||
// Filter next batch and remove requests that were stopped
|
||||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
|
||||
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
|
||||
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
|
||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
|
||||
metrics::histogram!("tgi_batch_forward_duration","method" => "prefill")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "prefill")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "prefill")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "prefill")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "prefill").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
|
@ -267,7 +273,7 @@ async fn prefill(
|
|||
generation_health.store(false, Ordering::SeqCst);
|
||||
let _ = client.clear_cache(Some(batch_id)).await;
|
||||
send_errors(err, entries);
|
||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "prefill").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
|
@ -282,7 +288,7 @@ async fn decode(
|
|||
) -> Option<CachedBatch> {
|
||||
let start_time = Instant::now();
|
||||
let batch_ids: Vec<u64> = batches.iter().map(|b| b.id).collect();
|
||||
metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");
|
||||
metrics::counter!("tgi_batch_inference_count", "method" => "decode").increment(1);
|
||||
|
||||
match client.decode(batches).await {
|
||||
Ok((generations, next_batch, timings)) => {
|
||||
|
@ -297,13 +303,18 @@ async fn decode(
|
|||
let next_batch = filter_batch(client, next_batch, entries).await;
|
||||
|
||||
if let Some(concat_duration) = timings.concat {
|
||||
metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_concat_duration", "method" => "decode")
|
||||
.record(concat_duration.as_secs_f64());
|
||||
}
|
||||
metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
|
||||
metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
|
||||
metrics::histogram!("tgi_batch_forward_duration", "method" => "decode")
|
||||
.record(timings.forward.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_decode_duration", "method" => "decode")
|
||||
.record(timings.decode.as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_filter_duration", "method" => "decode")
|
||||
.record(start_filtering_time.elapsed().as_secs_f64());
|
||||
metrics::histogram!("tgi_batch_inference_duration", "method" => "decode")
|
||||
.record(start_time.elapsed().as_secs_f64());
|
||||
metrics::counter!("tgi_batch_inference_success", "method" => "decode").increment(1);
|
||||
next_batch
|
||||
}
|
||||
// If we have an error, we discard the whole batch
|
||||
|
@ -313,7 +324,7 @@ async fn decode(
|
|||
let _ = client.clear_cache(Some(id)).await;
|
||||
}
|
||||
send_errors(err, entries);
|
||||
metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
|
||||
metrics::counter!("tgi_batch_inference_failure", "method" => "decode").increment(1);
|
||||
None
|
||||
}
|
||||
}
|
||||
|
@ -371,7 +382,7 @@ fn filter_send_generations(generations: Vec<Generation>, entries: &mut IntMap<u6
|
|||
// request and we need to stop generating hence why we unwrap_or(true)
|
||||
let stopped = send_responses(generation, entry).map_err(|err| {
|
||||
tracing::error!("Entry response channel error.");
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
err
|
||||
}).unwrap_or(true);
|
||||
if stopped {
|
||||
|
@ -387,7 +398,7 @@ fn send_responses(
|
|||
) -> Result<bool, Box<SendError<Result<InferStreamResponse, InferError>>>> {
|
||||
// Return directly if the channel is disconnected
|
||||
if entry.response_tx.is_closed() {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "dropped");
|
||||
metrics::counter!("tgi_request_failure", "err" => "dropped").increment(1);
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
|
@ -413,7 +424,7 @@ fn send_responses(
|
|||
// Create last Token
|
||||
let tokens_ = generation.tokens.expect("Non empty tokens in generation");
|
||||
let n = tokens_.ids.len();
|
||||
metrics::histogram!("tgi_request_skipped_tokens", (n - 1) as f64);
|
||||
metrics::histogram!("tgi_request_skipped_tokens").record((n - 1) as f64);
|
||||
let mut iterator = tokens_
|
||||
.ids
|
||||
.into_iter()
|
||||
|
@ -478,7 +489,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
|
|||
// Create and enter a span to link this function back to the entry
|
||||
let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
|
||||
let err = InferError::GenerationError(error.to_string());
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "generation");
|
||||
metrics::counter!("tgi_request_failure", "err" => "generation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
|
||||
// unwrap_or is valid here as we don't care if the receiver is gone.
|
||||
|
|
|
@ -384,7 +384,7 @@ pub struct CompletionRequest {
|
|||
/// UNUSED
|
||||
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
||||
/// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
|
||||
pub model: String,
|
||||
pub model: Option<String>,
|
||||
|
||||
/// The prompt to generate completions for.
|
||||
#[schema(example = "What is Deep Learning?")]
|
||||
|
@ -731,7 +731,7 @@ impl ChatCompletionChunk {
|
|||
pub(crate) struct ChatRequest {
|
||||
#[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
|
||||
/// [UNUSED] ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
|
||||
pub model: String,
|
||||
pub model: Option<String>,
|
||||
|
||||
/// A list of messages comprising the conversation so far.
|
||||
#[schema(example = "[{\"role\": \"user\", \"content\": \"What is Deep Learning?\"}]")]
|
||||
|
@ -848,7 +848,7 @@ pub enum ToolType {
|
|||
Function { function: FunctionName },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
|
||||
pub struct FunctionName {
|
||||
pub name: String,
|
||||
}
|
||||
|
|
|
@ -210,7 +210,11 @@ async fn main() -> Result<(), RouterError> {
|
|||
}
|
||||
let api = if use_api {
|
||||
if std::env::var("HF_HUB_OFFLINE") == Ok("1".to_string()) {
|
||||
let cache = Cache::default();
|
||||
let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
|
||||
.map_err(|_| ())
|
||||
.map(|cache_dir| Cache::new(cache_dir.into()))
|
||||
.unwrap_or_else(|_| Cache::default());
|
||||
|
||||
tracing::warn!("Offline mode active using cache defaults");
|
||||
Type::Cache(cache)
|
||||
} else {
|
||||
|
|
|
@ -11,10 +11,11 @@ use crate::kserve::{
|
|||
};
|
||||
use crate::validation::ValidationError;
|
||||
use crate::{
|
||||
BestOfSequence, Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest,
|
||||
GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig, HubTokenizerConfig, Info,
|
||||
Message, PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse,
|
||||
Usage, Validation,
|
||||
BestOfSequence, Details, ErrorResponse, FinishReason, FunctionName, GenerateParameters,
|
||||
GenerateRequest, GenerateResponse, GrammarType, HubModelInfo, HubProcessorConfig,
|
||||
HubTokenizerConfig, Info, Message, MessageChunk, MessageContent, OutputMessage, PrefillToken,
|
||||
SimpleToken, StreamDetails, StreamResponse, TextMessage, Token, TokenizeResponse,
|
||||
ToolCallDelta, ToolCallMessage, Url, Usage, Validation,
|
||||
};
|
||||
use crate::{
|
||||
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
|
||||
|
@ -185,7 +186,7 @@ pub(crate) async fn generate_internal(
|
|||
span: tracing::Span,
|
||||
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
|
||||
let start_time = Instant::now();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
|
||||
// Do not log ultra long inputs, like image payloads.
|
||||
tracing::debug!("Input: {}", &req.inputs[..1000.min(req.inputs.len())]);
|
||||
|
@ -301,25 +302,15 @@ pub(crate) async fn generate_internal(
|
|||
);
|
||||
|
||||
// Metrics
|
||||
metrics::increment_counter!("tgi_request_success");
|
||||
metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
|
||||
metrics::histogram!(
|
||||
"tgi_request_validation_duration",
|
||||
validation_time.as_secs_f64()
|
||||
);
|
||||
metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64());
|
||||
metrics::histogram!(
|
||||
"tgi_request_inference_duration",
|
||||
inference_time.as_secs_f64()
|
||||
);
|
||||
metrics::histogram!(
|
||||
"tgi_request_mean_time_per_token_duration",
|
||||
time_per_token.as_secs_f64()
|
||||
);
|
||||
metrics::histogram!(
|
||||
"tgi_request_generated_tokens",
|
||||
response.generated_text.generated_tokens as f64
|
||||
);
|
||||
metrics::counter!("tgi_request_success").increment(1);
|
||||
metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
|
||||
metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
|
||||
metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
|
||||
metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
|
||||
metrics::histogram!("tgi_request_mean_time_per_token_duration")
|
||||
.record(time_per_token.as_secs_f64());
|
||||
metrics::histogram!("tgi_request_generated_tokens")
|
||||
.record(response.generated_text.generated_tokens as f64);
|
||||
|
||||
// Send response
|
||||
let mut output_text = response.generated_text.text;
|
||||
|
@ -399,7 +390,7 @@ async fn generate_stream_internal(
|
|||
span: tracing::Span,
|
||||
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
|
||||
let start_time = Instant::now();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
|
||||
tracing::debug!("Input: {}", req.inputs);
|
||||
|
||||
|
@ -427,12 +418,12 @@ async fn generate_stream_internal(
|
|||
let best_of = req.parameters.best_of.unwrap_or(1);
|
||||
if best_of != 1 {
|
||||
let err = InferError::from(ValidationError::BestOfStream);
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
yield Ok(Event::from(err));
|
||||
} else if req.parameters.decoder_input_details {
|
||||
let err = InferError::from(ValidationError::PrefillDetailsStream);
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
yield Ok(Event::from(err));
|
||||
} else {
|
||||
|
@ -500,13 +491,13 @@ async fn generate_stream_internal(
|
|||
span.record("seed", format!("{:?}", generated_text.seed));
|
||||
|
||||
// Metrics
|
||||
metrics::increment_counter!("tgi_request_success");
|
||||
metrics::histogram!("tgi_request_duration", total_time.as_secs_f64());
|
||||
metrics::histogram!("tgi_request_validation_duration", validation_time.as_secs_f64());
|
||||
metrics::histogram!("tgi_request_queue_duration", queue_time.as_secs_f64());
|
||||
metrics::histogram!("tgi_request_inference_duration", inference_time.as_secs_f64());
|
||||
metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token.as_secs_f64());
|
||||
metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
|
||||
metrics::counter!("tgi_request_success").increment(1);
|
||||
metrics::histogram!("tgi_request_duration").record(total_time.as_secs_f64());
|
||||
metrics::histogram!("tgi_request_validation_duration").record(validation_time.as_secs_f64());
|
||||
metrics::histogram!("tgi_request_queue_duration").record(queue_time.as_secs_f64());
|
||||
metrics::histogram!("tgi_request_inference_duration").record(inference_time.as_secs_f64());
|
||||
metrics::histogram!("tgi_request_mean_time_per_token_duration").record(time_per_token.as_secs_f64());
|
||||
metrics::histogram!("tgi_request_generated_tokens").record(generated_text.generated_tokens as f64);
|
||||
|
||||
// StreamResponse
|
||||
end_reached = true;
|
||||
|
@ -553,7 +544,7 @@ async fn generate_stream_internal(
|
|||
// Skip if we already sent an error
|
||||
if !end_reached && !error {
|
||||
let err = InferError::IncompleteGeneration;
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
|
||||
metrics::counter!("tgi_request_failure", "err" => "incomplete").increment(1);
|
||||
tracing::error!("{err}");
|
||||
yield Ok(Event::from(err));
|
||||
}
|
||||
|
@ -572,8 +563,8 @@ request_body = CompletionRequest,
|
|||
responses(
|
||||
(status = 200, description = "Generated Chat Completion",
|
||||
content(
|
||||
("application/json" = Completion),
|
||||
("text/event-stream" = CompletionCompleteChunk),
|
||||
("application/json" = CompletionFinal),
|
||||
("text/event-stream" = Chunk),
|
||||
)),
|
||||
(status = 424, description = "Generation Error", body = ErrorResponse,
|
||||
example = json ! ({"error": "Request failed during generation"})),
|
||||
|
@ -604,9 +595,10 @@ async fn completions(
|
|||
Json(req): Json<CompletionRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
|
||||
let CompletionRequest {
|
||||
model,
|
||||
max_tokens,
|
||||
seed,
|
||||
stop,
|
||||
|
@ -625,7 +617,7 @@ async fn completions(
|
|||
|
||||
// if suffix is present throw an error
|
||||
if req.suffix.is_some() {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
|
@ -637,7 +629,7 @@ async fn completions(
|
|||
}
|
||||
|
||||
if req.prompt.0.len() > info.max_client_batch_size {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
|
@ -675,7 +667,7 @@ async fn completions(
|
|||
seed,
|
||||
top_n_tokens: None,
|
||||
grammar: None,
|
||||
..Default::default()
|
||||
adapter_id: model.as_ref().filter(|m| *m != "tgi").map(String::from),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
@ -820,6 +812,10 @@ async fn completions(
|
|||
}
|
||||
};
|
||||
|
||||
let stream = stream.chain(futures::stream::once(async {
|
||||
Ok(Event::default().data("[DONE]"))
|
||||
}));
|
||||
|
||||
let sse = Sse::new(stream).keep_alive(KeepAlive::default());
|
||||
Ok((headers, sse).into_response())
|
||||
} else {
|
||||
|
@ -1009,8 +1005,9 @@ async fn chat_completions(
|
|||
Json(req): Json<ChatRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
let ChatRequest {
|
||||
model,
|
||||
logprobs,
|
||||
max_tokens,
|
||||
messages,
|
||||
|
@ -1039,7 +1036,7 @@ async fn chat_completions(
|
|||
|
||||
// response_format and tools are mutually exclusive
|
||||
if response_format.is_some() && tools.as_ref().is_some() {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
Json(ErrorResponse {
|
||||
|
@ -1053,7 +1050,7 @@ async fn chat_completions(
|
|||
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
|
||||
Ok(grammar) => grammar,
|
||||
Err(err) => {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
|
@ -1082,7 +1079,7 @@ async fn chat_completions(
|
|||
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
|
||||
Ok(inputs) => inputs,
|
||||
Err(err) => {
|
||||
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
|
||||
metrics::counter!("tgi_request_failure", "err" => "validation").increment(1);
|
||||
tracing::error!("{err}");
|
||||
return Err((
|
||||
StatusCode::UNPROCESSABLE_ENTITY,
|
||||
|
@ -1116,7 +1113,7 @@ async fn chat_completions(
|
|||
seed,
|
||||
top_n_tokens: req.top_logprobs,
|
||||
grammar,
|
||||
..Default::default()
|
||||
adapter_id: model.filter(|m| *m != "tgi").map(String::from),
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -1178,6 +1175,11 @@ async fn chat_completions(
|
|||
span,
|
||||
)
|
||||
.await;
|
||||
|
||||
let response_stream = response_stream.chain(futures::stream::once(async {
|
||||
Ok(Event::default().data("[DONE]"))
|
||||
}));
|
||||
|
||||
let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
|
||||
Ok((headers, sse).into_response())
|
||||
} else {
|
||||
|
@ -1280,7 +1282,7 @@ async fn vertex_compatibility(
|
|||
Json(req): Json<VertexRequest>,
|
||||
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
|
||||
let span = tracing::Span::current();
|
||||
metrics::increment_counter!("tgi_request_count");
|
||||
metrics::counter!("tgi_request_count").increment(1);
|
||||
|
||||
// check that there's at least one instance
|
||||
if req.instances.is_empty() {
|
||||
|
@ -1454,6 +1456,14 @@ pub async fn run(
|
|||
GrammarType,
|
||||
ChatRequest,
|
||||
Message,
|
||||
MessageContent,
|
||||
MessageChunk,
|
||||
Url,
|
||||
FunctionName,
|
||||
OutputMessage,
|
||||
TextMessage,
|
||||
ToolCallMessage,
|
||||
ToolCallDelta,
|
||||
ChatCompletionComplete,
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionDelta,
|
||||
|
|
|
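Editor's note: the lib.rs and server.rs changes above are easiest to see from the client side. `model` is now `Option<String>` in both `CompletionRequest` and `ChatRequest`, and any value other than "tgi" is forwarded to the backend as the LoRA `adapter_id`. The sketch below is hypothetical: it assumes TGI's usual OpenAI-compatible chat route and a locally running server, and the adapter name is made up.

import requests

payload = {
    # Omit "model" entirely (it is optional now) or pass "tgi" to target the
    # base model; any other string is mapped to `adapter_id` by the router.
    "model": "my-lora-adapter",
    "messages": [{"role": "user", "content": "What is Deep Learning?"}],
    "max_tokens": 32,
}
# Endpoint and port are assumptions; adjust to your deployment.
response = requests.post("http://localhost:3000/v1/chat/completions", json=payload, timeout=60)
print(response.json())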
@ -157,7 +157,7 @@ impl Validation {
|
|||
));
|
||||
}
|
||||
|
||||
metrics::histogram!("tgi_request_input_length", input_length as f64);
|
||||
metrics::histogram!("tgi_request_input_length").record(input_length as f64);
|
||||
Ok((inputs, input_length, max_new_tokens))
|
||||
}
|
||||
// Return inputs without validation
|
||||
|
@ -384,7 +384,7 @@ impl Validation {
|
|||
ignore_eos_token: false,
|
||||
};
|
||||
|
||||
metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
|
||||
metrics::histogram!("tgi_request_max_new_tokens").record(max_new_tokens as f64);
|
||||
|
||||
Ok(ValidGenerateRequest {
|
||||
inputs,
|
||||
|
|
|
@ -21,13 +21,14 @@ gen-server:
|
|||
install-server: gen-server
|
||||
pip install pip --upgrade
|
||||
pip install -r requirements_cuda.txt
|
||||
pip install -e ".[bnb, accelerate, quantize, peft, outlines]"
|
||||
pip install -e ".[accelerate, quantize, peft, outlines]"
|
||||
|
||||
|
||||
install: install-cuda
|
||||
echo "Installed server"
|
||||
|
||||
install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention
|
||||
pip install -e ".[bnb]"
|
||||
|
||||
install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
|
||||
|
||||
|
@ -35,5 +36,5 @@ run-dev:
|
|||
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
|
||||
|
||||
export-requirements:
|
||||
poetry export -o requirements_cuda.txt --without-hashes
|
||||
poetry export -o requirements_cuda.txt --without-hashes --with cuda
|
||||
poetry export -o requirements_rocm.txt --without-hashes
|
||||
|
|
|
@ -59,3 +59,18 @@ def marlin_gemm(
|
|||
Matrix multiplication using Marlin kernels.
|
||||
"""
|
||||
...
|
||||
|
||||
# fp8 marlin
|
||||
def fp8_marlin_gemm(
|
||||
a: torch.Tensor,
|
||||
b_q_weight: torch.Tensor,
|
||||
b_scales: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
num_bits: int,
|
||||
size_m: int,
|
||||
size_n: int,
|
||||
size_k: int,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops._C.fp8_marlin_gemm(
|
||||
a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k
|
||||
)
|
||||
|
|
|
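Editor's note: the stub above only declares the new kernel's Python signature. Below is a minimal, hypothetical wrapper that illustrates the call shape; the weight, scale and workspace tensors must already be in Marlin's packed FP8 layout (produced elsewhere in this PR), so the function is defined but not invoked here.

import torch

try:
    import marlin_kernels  # compiled extension from this repo
except ImportError:
    marlin_kernels = None

def fp8_marlin_matmul(a, b_q_weight, b_scales, workspace, size_m, size_n, size_k):
    # Thin pass-through over the binding added above; num_bits is 8 for FP8 weights.
    assert marlin_kernels is not None, "marlin_kernels extension is not installed"
    return marlin_kernels.fp8_marlin_gemm(
        a, b_q_weight, b_scales, workspace, 8, size_m, size_n, size_k
    )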
@ -9,4 +9,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
m.def("gptq_marlin_repack", &gptq_marlin_repack,
|
||||
"Repack GPTQ parameters for Marlin");
|
||||
m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");
|
||||
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
|
||||
m.def("fp8_marlin_gemm", &fp8_marlin_gemm);
|
||||
}
|
||||
|
|
|
@ -27,4 +27,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
|
|||
torch::Tensor &b_scales, torch::Tensor &workspace,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k);
|
||||
|
||||
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||
int64_t num_bits, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k);
|
||||
|
||||
#endif
|
||||
|
|
File diff suppressed because it is too large
|
@ -9,6 +9,7 @@ setup(
|
|||
CUDAExtension(
|
||||
name="marlin_kernels",
|
||||
sources=[
|
||||
"marlin_kernels/fp8_marlin.cu",
|
||||
"marlin_kernels/gptq_marlin.cu",
|
||||
"marlin_kernels/gptq_marlin_repack.cu",
|
||||
"marlin_kernels/marlin_cuda_kernel.cu",
|
||||
|
|
|
@ -2,6 +2,7 @@ import torch
|
|||
from text_generation_server.layers import (
|
||||
TensorParallelEmbedding,
|
||||
)
|
||||
from text_generation_server.utils.weights import DefaultWeightsLoader
|
||||
|
||||
|
||||
class ProcessGroup:
|
||||
|
@ -42,7 +43,12 @@ class Weights:
|
|||
def test_weight_hub_files_offline_error():
|
||||
|
||||
vocab_size = 17
|
||||
weights = Weights(rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=256)
|
||||
weights = Weights(
|
||||
rank=0,
|
||||
world_size=1,
|
||||
vocab_size=vocab_size,
|
||||
hidden_dim=256,
|
||||
)
|
||||
embeddings = TensorParallelEmbedding("", weights)
|
||||
|
||||
input_ids = torch.arange(vocab_size)
|
||||
|
|
|
@ -1,13 +1,47 @@
|
|||
import pytest
|
||||
import torch
|
||||
from text_generation_server.utils.weights import Weights
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.exl2 import Exl2Weight
|
||||
from text_generation_server.layers.marlin import MarlinWeight
|
||||
from text_generation_server.utils.weights import (
|
||||
DefaultWeightsLoader,
|
||||
Weights,
|
||||
WeightsLoader,
|
||||
)
|
||||
from text_generation_server.layers.gptq import GPTQWeight, GPTQWeightsLoader
|
||||
from text_generation_server.layers.exl2 import Exl2Weight, Exl2WeightsLoader
|
||||
from text_generation_server.layers.marlin import MarlinWeight, MarlinWeightsLoader
|
||||
from types import SimpleNamespace
|
||||
from typing import List, Optional, Dict, Union
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gptq_weights_loader():
|
||||
return GPTQWeightsLoader(
|
||||
bits=4,
|
||||
groupsize=-1,
|
||||
desc_act=False,
|
||||
quant_method="gptq",
|
||||
quantize="gptq",
|
||||
sym=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gptq_weights_loader_awq():
|
||||
return GPTQWeightsLoader(
|
||||
bits=4,
|
||||
groupsize=-1,
|
||||
desc_act=False,
|
||||
quant_method="awq",
|
||||
quantize="awq",
|
||||
sym=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def marlin_weights_loader():
|
||||
return MarlinWeightsLoader(bits=4, is_marlin_24=False)
|
||||
|
||||
|
||||
dummy_file_system = {
|
||||
"test_weights": {
|
||||
"layer.0.weight": torch.tensor(
|
||||
|
@ -58,7 +92,7 @@ dummy_file_system = {
|
|||
dtype=torch.float32,
|
||||
),
|
||||
},
|
||||
"test_get_multi_weights_row": {
|
||||
"test_get_weights_row": {
|
||||
"weight.weight": torch.tensor(
|
||||
[
|
||||
[1, 2],
|
||||
|
@ -101,7 +135,7 @@ dummy_file_system = {
|
|||
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
"weight.s": torch.tensor([[0.5000], [0.2500]], dtype=torch.float16),
|
||||
},
|
||||
"test_get_multi_weights_row_gptq": {
|
||||
"test_get_weights_row_gptq": {
|
||||
"weight.qweight": torch.tensor(
|
||||
[
|
||||
[1, 2],
|
||||
|
@ -200,7 +234,7 @@ dummy_file_system = {
|
|||
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
|
||||
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
|
||||
},
|
||||
"test_get_multi_weights_row_exl2": {
|
||||
"test_get_weights_row_exl2": {
|
||||
"weight.q_weight": torch.tensor(
|
||||
[
|
||||
[1, 2],
|
||||
|
@ -245,7 +279,7 @@ dummy_file_system = {
|
|||
"weight.q_scale_max": torch.tensor([100], dtype=torch.float16),
|
||||
"weight.q_groups": torch.tensor([4], dtype=torch.int16),
|
||||
},
|
||||
"test_get_multi_weights_row_marlin": {
|
||||
"test_get_weights_row_marlin": {
|
||||
"weight.B": torch.tensor([[1, 2], [3, 4]], dtype=torch.int32),
|
||||
"weight.s": torch.tensor([[0.5], [0.25]], dtype=torch.float16),
|
||||
},
|
||||
|
@ -308,6 +342,7 @@ class MockWeights(Weights):
|
|||
dummy_fs,
|
||||
aliases: Optional[Dict[str, List[str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
weights_loader: Optional[WeightsLoader] = None,
|
||||
):
|
||||
routing = {}
|
||||
self.dummy_fs = dummy_fs
|
||||
|
@ -327,6 +362,9 @@ class MockWeights(Weights):
|
|||
self.dtype = dtype
|
||||
self.process_group = process_group
|
||||
self.prefix = prefix
|
||||
self.weights_loader = (
|
||||
DefaultWeightsLoader() if weights_loader is None else weights_loader
|
||||
)
|
||||
self._handles = {}
|
||||
|
||||
def _get_handle(self, filename: Union[Path, str]):
|
||||
|
@ -412,12 +450,10 @@ def test_get_weights_col_packed():
|
|||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = None
|
||||
block_sizes = 1
|
||||
|
||||
w = weights.get_weights_col_packed(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
|
||||
|
@ -448,12 +484,10 @@ def test_get_weights_col_packed_block_size():
|
|||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = None
|
||||
block_sizes = 2
|
||||
|
||||
w = weights.get_weights_col_packed(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
|
||||
|
@ -484,12 +518,10 @@ def test_get_weights_col_packed_block_size_arr():
|
|||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = None
|
||||
block_sizes = [1, 1]
|
||||
|
||||
w = weights.get_weights_col_packed(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
|
||||
|
@ -519,11 +551,9 @@ def test_get_multi_weights_col():
|
|||
)
|
||||
|
||||
prefixes = ["weight", "weight"]
|
||||
quantize = None
|
||||
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=prefixes,
|
||||
quantize=quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
@ -545,10 +575,10 @@ def test_get_multi_weights_col():
|
|||
)
|
||||
|
||||
|
||||
def test_get_multi_weights_row():
|
||||
def test_get_weights_row():
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_row",
|
||||
"test_get_weights_row",
|
||||
],
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
|
@ -557,11 +587,9 @@ def test_get_multi_weights_row():
|
|||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = None
|
||||
|
||||
w = weights.get_multi_weights_row(
|
||||
w = weights.get_weights_row(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
assert torch.allclose(
|
||||
|
@ -576,7 +604,7 @@ def test_get_multi_weights_row():
|
|||
# test_get_weights_col
|
||||
|
||||
|
||||
def test_get_weights_col_awq():
|
||||
def test_get_weights_col_awq(gptq_weights_loader_awq):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_weights_col_gptq",
|
||||
|
@ -585,14 +613,13 @@ def test_get_weights_col_awq():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader_awq,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "awq"
|
||||
|
||||
w = weights.get_weights_col(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
|
@ -617,7 +644,7 @@ def test_get_weights_col_awq():
|
|||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
def test_get_weights_col_gtpq():
|
||||
def test_get_weights_col_gtpq(gptq_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_weights_col_gptq",
|
||||
|
@ -626,14 +653,13 @@ def test_get_weights_col_gtpq():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "gptq"
|
||||
|
||||
w = weights.get_weights_col(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
|
@ -664,14 +690,13 @@ def test_get_weights_col_exl2():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=Exl2WeightsLoader(),
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "exl2"
|
||||
|
||||
w = weights.get_weights_col(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
scaled_scale_max = 0.3906 * 256
|
||||
|
@ -692,7 +717,7 @@ def test_get_weights_col_exl2():
|
|||
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
||||
|
||||
|
||||
def test_get_weights_col_marlin():
|
||||
def test_get_weights_col_marlin(marlin_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_weights_col_marlin",
|
||||
|
@ -701,14 +726,13 @@ def test_get_weights_col_marlin():
|
|||
dtype=torch.float16,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=marlin_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "marlin"
|
||||
|
||||
w = weights.get_weights_col(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
expected_weight = MarlinWeight(
|
||||
|
@ -723,7 +747,7 @@ def test_get_weights_col_marlin():
|
|||
# test_get_weights_col_packed
|
||||
|
||||
|
||||
def test_get_weights_col_packed_awq():
|
||||
def test_get_weights_col_packed_awq(gptq_weights_loader_awq):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_weights_col_packed_gptq",
|
||||
|
@ -732,15 +756,14 @@ def test_get_weights_col_packed_awq():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader_awq,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "awq"
|
||||
block_sizes = 1
|
||||
|
||||
w = weights.get_weights_col_packed(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
|
||||
|
@ -773,15 +796,14 @@ def test_get_weights_col_packed_exl2():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=Exl2WeightsLoader(),
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "exl2"
|
||||
block_sizes = 1
|
||||
|
||||
w = weights.get_weights_col_packed(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
|
||||
|
@ -803,7 +825,7 @@ def test_get_weights_col_packed_exl2():
|
|||
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
||||
|
||||
|
||||
def test_get_weights_col_packed_gptq():
|
||||
def test_get_weights_col_packed_gptq(gptq_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_weights_col_packed_gptq",
|
||||
|
@ -812,14 +834,13 @@ def test_get_weights_col_packed_gptq():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader,
|
||||
)
|
||||
|
||||
prefixes = ["weight"]
|
||||
quantize = "gptq"
|
||||
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=prefixes,
|
||||
quantize=quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
@ -842,7 +863,7 @@ def test_get_weights_col_packed_gptq():
|
|||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
def test_get_weights_col_packed_marlin():
|
||||
def test_get_weights_col_packed_marlin(marlin_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_weights_col_packed_marlin",
|
||||
|
@ -851,14 +872,13 @@ def test_get_weights_col_packed_marlin():
|
|||
dtype=torch.float16,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=marlin_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "marlin"
|
||||
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=[prefix],
|
||||
quantize=quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
@ -876,7 +896,7 @@ def test_get_weights_col_packed_marlin():
|
|||
# test_get_multi_weights_col
|
||||
|
||||
|
||||
def test_get_multi_weights_col_awq():
|
||||
def test_get_multi_weights_col_awq(gptq_weights_loader_awq):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_col_gptq",
|
||||
|
@ -885,14 +905,13 @@ def test_get_multi_weights_col_awq():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader_awq,
|
||||
)
|
||||
|
||||
prefixes = ["weight"]
|
||||
quantize = "awq"
|
||||
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=prefixes,
|
||||
quantize=quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
@ -924,22 +943,21 @@ def test_get_multi_weights_col_exl2():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=Exl2WeightsLoader(),
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "exl2"
|
||||
|
||||
try:
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=[prefix],
|
||||
quantize=quantize,
|
||||
dim=0,
|
||||
)
|
||||
except ValueError as e:
|
||||
assert e.args[0] == "get_multi_weights_col is not supported for exl2"
|
||||
|
||||
|
||||
def test_get_multi_weights_col_gptq():
|
||||
def test_get_multi_weights_col_gptq(gptq_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_col_gptq",
|
||||
|
@ -948,14 +966,13 @@ def test_get_multi_weights_col_gptq():
|
|||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader,
|
||||
)
|
||||
|
||||
prefixes = ["weight"]
|
||||
quantize = "gptq"
|
||||
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=prefixes,
|
||||
quantize=quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
@ -978,7 +995,7 @@ def test_get_multi_weights_col_gptq():
|
|||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
def test_get_multi_weights_col_marlin():
|
||||
def test_get_multi_weights_col_marlin(marlin_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_col_marlin",
|
||||
|
@ -987,14 +1004,13 @@ def test_get_multi_weights_col_marlin():
|
|||
dtype=torch.float16,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=marlin_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "marlin"
|
||||
|
||||
w = weights.get_multi_weights_col(
|
||||
prefixes=[prefix],
|
||||
quantize=quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
@ -1007,26 +1023,25 @@ def test_get_multi_weights_col_marlin():
|
|||
assert torch.allclose(w.s, expected_weight.s), "s mismatch"
|
||||
|
||||
|
||||
# test_get_multi_weights_row
|
||||
# test_get_weights_row
|
||||
|
||||
|
||||
def test_get_multi_weights_row_awq():
|
||||
def test_get_weights_row_awq(gptq_weights_loader_awq):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_row_gptq",
|
||||
"test_get_weights_row_gptq",
|
||||
],
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader_awq,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "awq"
|
||||
|
||||
w = weights.get_multi_weights_row(
|
||||
w = weights.get_weights_row(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
|
@ -1048,23 +1063,22 @@ def test_get_multi_weights_row_awq():
|
|||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
def test_get_multi_weights_row_exl2():
|
||||
def test_get_weights_row_exl2():
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_row_exl2",
|
||||
"test_get_weights_row_exl2",
|
||||
],
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=Exl2WeightsLoader(),
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "exl2"
|
||||
|
||||
w = weights.get_multi_weights_row(
|
||||
w = weights.get_weights_row(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
print(w)
|
||||
|
||||
|
@ -1086,23 +1100,22 @@ def test_get_multi_weights_row_exl2():
|
|||
assert torch.allclose(w.q_groups, expected_weight.q_groups), "q_groups mismatch"
|
||||
|
||||
|
||||
def test_get_multi_weights_row_gptq():
|
||||
def test_get_weights_row_gptq(gptq_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_row_gptq",
|
||||
"test_get_weights_row_gptq",
|
||||
],
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=gptq_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "gptq"
|
||||
|
||||
w = weights.get_multi_weights_row(
|
||||
w = weights.get_weights_row(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
expected_weight = GPTQWeight(
|
||||
|
@ -1124,23 +1137,22 @@ def test_get_multi_weights_row_gptq():
|
|||
assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch"
|
||||
|
||||
|
||||
def test_get_multi_weights_row_marlin():
|
||||
def test_get_weights_row_marlin(marlin_weights_loader):
|
||||
weights = MockWeights(
|
||||
[
|
||||
"test_get_multi_weights_row_marlin",
|
||||
"test_get_weights_row_marlin",
|
||||
],
|
||||
device="cpu",
|
||||
dtype=torch.float16,
|
||||
process_group=dummy_process_group,
|
||||
dummy_fs=dummy_file_system,
|
||||
weights_loader=marlin_weights_loader,
|
||||
)
|
||||
|
||||
prefix = "weight"
|
||||
quantize = "marlin"
|
||||
|
||||
w = weights.get_multi_weights_row(
|
||||
w = weights.get_weights_row(
|
||||
prefix=prefix,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
expected_weight = MarlinWeight(
|
||||
|
|
|
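Editor's note: the renames in this test file mirror the API change itself: row-parallel shards are now fetched with get_weights_row, and the quantization format is decided by the weights_loader handed to Weights instead of being re-derived per call. A condensed sketch of the unquantized path follows; it reuses the MockWeights helper and the dummy fixtures defined earlier in this test file, so it only runs in that context.

import torch
from text_generation_server.utils.weights import DefaultWeightsLoader

def row_shard_example():
    weights = MockWeights(
        ["test_get_weights_row"],
        device="cpu",
        dtype=torch.float32,
        process_group=dummy_process_group,
        dummy_fs=dummy_file_system,
        weights_loader=DefaultWeightsLoader(),  # also the default when omitted
    )
    # Previously: weights.get_multi_weights_row(prefix="weight", quantize=None)
    return weights.get_weights_row(prefix="weight", quantize=None)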
@ -91,6 +91,15 @@ def serve(
|
|||
f"LoRA adapters are enabled. This is an experimental feature and may not work as expected."
|
||||
)
|
||||
|
||||
# TODO: enable LoRA with CUDA graphs. For now, disable CUDA graphs if LoRA is enabled
|
||||
# and warn the user
|
||||
if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None:
|
||||
logger.warning(
|
||||
f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs."
|
||||
)
|
||||
global CUDA_GRAPHS
|
||||
CUDA_GRAPHS = None
|
||||
|
||||
# Downgrade enum into str for easier management later on
|
||||
quantize = None if quantize is None else quantize.value
|
||||
dtype = None if dtype is None else dtype.value
|
||||
|
@ -332,6 +341,7 @@ def quantize(
|
|||
upload_to_model_id: Optional[str] = None,
|
||||
percdamp: float = 0.01,
|
||||
act_order: bool = False,
|
||||
groupsize: int = 128,
|
||||
):
|
||||
if revision is None:
|
||||
revision = "main"
|
||||
|
@ -346,13 +356,14 @@ def quantize(
|
|||
quantize(
|
||||
model_id=model_id,
|
||||
bits=4,
|
||||
groupsize=128,
|
||||
groupsize=groupsize,
|
||||
output_dir=output_dir,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
upload_to_model_id=upload_to_model_id,
|
||||
percdamp=percdamp,
|
||||
act_order=act_order,
|
||||
sym=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
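Editor's note: the CLI change above exposes groupsize and pins symmetric quantization. The wrapper below is a hypothetical sketch of the forwarding call; the keyword list matches the internal quantize signature shown in this diff.

from text_generation_server.layers.gptq.quantize import quantize as gptq_quantize

def quantize_checkpoint(model_id: str, output_dir: str, groupsize: int = 128):
    # Mirrors what `text-generation-server quantize` now does: groupsize is
    # user-configurable and sym=True is passed explicitly.
    gptq_quantize(
        model_id=model_id,
        bits=4,
        groupsize=groupsize,
        output_dir=output_dir,
        revision="main",
        trust_remote_code=False,
        upload_to_model_id=None,
        percdamp=0.01,
        act_order=False,
        sym=True,
    )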
@ -1,6 +1,9 @@
|
|||
import torch
|
||||
from typing import List, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from text_generation_server.utils.weights import WeightsLoader, Weights
|
||||
|
||||
|
||||
@dataclass
|
||||
class Exl2Weight:
|
||||
|
@ -21,3 +24,60 @@ class Exl2Weight:
|
|||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.q_weight.device
|
||||
|
||||
|
||||
class Exl2WeightsLoader(WeightsLoader):
|
||||
"""Loader for exl2-quantized weights."""
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
raise RuntimeError("Column-packed weights are not supported for exl")
|
||||
|
||||
def get_weights_col(self, weights: Weights, prefix: str):
|
||||
try:
|
||||
q_weight = weights.get_tensor(f"{prefix}.q_weight")
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
q_scale = weights.get_tensor(f"{prefix}.q_scale")
|
||||
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
|
||||
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
|
||||
q_groups = weights.get_tensor(f"{prefix}.q_groups")
|
||||
|
||||
return Exl2Weight(
|
||||
q_weight=q_weight,
|
||||
q_scale=q_scale,
|
||||
q_invperm=q_invperm,
|
||||
q_scale_max=q_scale_max,
|
||||
q_groups=q_groups,
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
raise ValueError("get_multi_weights_col is not supported for exl2")
|
||||
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
try:
|
||||
q_weight = weights.get_tensor(f"{prefix}.q_weight")
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
q_scale = weights.get_tensor(f"{prefix}.q_scale")
|
||||
q_invperm = weights.get_tensor(f"{prefix}.q_invperm")
|
||||
q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max")
|
||||
q_groups = weights.get_tensor(f"{prefix}.q_groups")
|
||||
|
||||
return Exl2Weight(
|
||||
q_weight=q_weight,
|
||||
q_scale=q_scale,
|
||||
q_invperm=q_invperm,
|
||||
q_scale_max=q_scale_max,
|
||||
q_groups=q_groups,
|
||||
)
|
||||
|
|
|
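Editor's note: a small usage sketch for the loader above. The prefix is a hypothetical layer name and `weights` must be a Weights instance backed by real exl2 shards, so this is illustrative only.

from text_generation_server.layers.exl2 import Exl2WeightsLoader

def load_exl2_column(weights, prefix="model.layers.0.self_attn.q_proj"):
    loader = Exl2WeightsLoader()
    # Returns an Exl2Weight carrying q_weight/q_scale/q_invperm/q_scale_max/q_groups.
    w = loader.get_weights_col(weights, prefix)
    # Packed and fused loads are deliberately rejected for exl2:
    #   loader.get_weights_col_packed(...) raises RuntimeError
    #   loader.get_multi_weights_col(...)  raises ValueError
    return w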
@ -1,4 +1,23 @@
|
|||
from enum import Enum, auto
|
||||
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
|
||||
def get_fp8_linear() -> torch.nn.Module:
|
||||
"""
|
||||
Return an FP8 linear `Module` that is compatible with the current system.
|
||||
"""
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major == 8 and minor < 9:
|
||||
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
|
||||
|
||||
return GPTQMarlinFP8Linear
|
||||
|
||||
# On other systems let Torch decide if the hardware supports FP8.
|
||||
return Fp8Linear
|
||||
|
||||
|
||||
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
|
||||
|
|
|
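Editor's note on the capability check above: FP8 tensor cores only arrive with compute capability 8.9 (Ada) and later, so on Ampere (sm 8.0-8.8) the server falls back to the Marlin FP8 GEMM added in this PR. Callers only ever see a class, as get_linear does later in this diff; the sketch below is hypothetical and the weight/bias tensors are placeholders supplied by the caller.

from typing import Optional
import torch
from text_generation_server.layers.fp8 import get_fp8_linear

def build_fp8_linear(weight: torch.Tensor, bias: Optional[torch.Tensor]):
    # GPTQMarlinFP8Linear on sm 8.0-8.8, plain Fp8Linear everywhere else.
    linear_cls = get_fp8_linear()
    return linear_cls(weight, bias)  # both classes are constructed as (weight, bias) here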
@ -1,20 +1,14 @@
|
|||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Union
|
||||
from safetensors import SafetensorError
|
||||
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import (
|
||||
SYSTEM,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTQParams:
|
||||
bits: int
|
||||
checkpoint_format: Optional[str]
|
||||
groupsize: int
|
||||
desc_act: bool
|
||||
quant_method: str
|
||||
sym: bool
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -69,3 +63,345 @@ elif CAN_EXLLAMA:
|
|||
pass
|
||||
|
||||
from text_generation_server.layers.gptq.quant_linear import QuantLinear
|
||||
|
||||
|
||||
class GPTQWeightsLoader(WeightsLoader):
|
||||
"""
|
||||
Loader for GPTQ- and AWQ-quantized weights.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
bits: int,
|
||||
desc_act: bool,
|
||||
groupsize: int,
|
||||
quant_method: str,
|
||||
quantize: str,
|
||||
sym: bool,
|
||||
):
|
||||
self.bits = bits
|
||||
self.desc_act = desc_act
|
||||
self.groupsize = groupsize
|
||||
self.quant_method = quant_method
|
||||
self.quantize = quantize
|
||||
self.sym = sym
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
try:
|
||||
qweight = weights.get_packed_sharded(
|
||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
|
||||
)
|
||||
scales = weights.get_packed_sharded(
|
||||
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
scales = scales.to(dtype=weights.dtype)
|
||||
|
||||
self._get_gptq_params(weights)
|
||||
if can_use_gptq_marlin(
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
quantize=self.quantize,
|
||||
sym=self.sym,
|
||||
):
|
||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
sym=self.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
qzeros = weights.get_packed_sharded(
|
||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||
elif self.quantize == "gptq" and self.quant_method == "awq":
|
||||
log_once(
|
||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||
)
|
||||
from text_generation_server.layers.awq.conversion_utils import (
|
||||
fast_awq_to_gptq,
|
||||
)
|
||||
|
||||
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||
g_idx = (
|
||||
torch.arange(
|
||||
qweight.shape[0] * (32 // self.bits),
|
||||
device=qweight.device,
|
||||
)
|
||||
// self.groupsize
|
||||
).to(dtype=torch.int32)
|
||||
else:
|
||||
g_idx = None
|
||||
|
||||
return GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
use_exllama=False,
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
scales = torch.cat(
|
||||
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
self._get_gptq_params(weights)
|
||||
if can_use_gptq_marlin(
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
quantize=self.quantize,
|
||||
sym=self.sym,
|
||||
):
|
||||
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
sym=self.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
qzeros = torch.cat(
|
||||
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
from text_generation_server.layers.gptq import HAS_EXLLAMA
|
||||
|
||||
use_exllama = (
|
||||
self.bits == 4
|
||||
and HAS_EXLLAMA
|
||||
and self.quantize == "gptq"
|
||||
and not self.desc_act
|
||||
)
|
||||
|
||||
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
elif self.quantize == "gptq" and self.quant_method == "awq":
|
||||
log_once(
|
||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||
)
|
||||
from text_generation_server.layers.awq.conversion_utils import (
|
||||
fast_awq_to_gptq,
|
||||
)
|
||||
|
||||
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||
if use_exllama:
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = (
|
||||
torch.arange(
|
||||
qweight.shape[0] * (32 // self.bits),
|
||||
device=qweight.device,
|
||||
)
|
||||
// self.groupsize
|
||||
).to(dtype=torch.int32)
|
||||
else:
|
||||
g_idx = None
|
||||
|
||||
return GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
self._get_gptq_params(weights)
|
||||
if can_use_gptq_marlin(
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
quant_method=self.quant_method,
|
||||
quantize=self.quantize,
|
||||
sym=self.sym,
|
||||
):
|
||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||
try:
|
||||
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
if self.desc_act or self.groupsize == -1:
|
||||
scales = weights.get_tensor(f"{prefix}.scales")
|
||||
else:
|
||||
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
||||
|
||||
sharded_in_features = weights.process_group.size() > 1
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
sym=self.sym,
|
||||
sharded_infeatures=sharded_in_features,
|
||||
)
|
||||
|
||||
use_exllama = True
|
||||
if self.bits != 4:
|
||||
use_exllama = False
|
||||
|
||||
if self.desc_act:
|
||||
log_once(logger.warning, "Disabling exllama because desc_act=True")
|
||||
use_exllama = False
|
||||
|
||||
try:
|
||||
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
|
||||
if self.quantize == "gptq" and self.quant_method == "gptq":
|
||||
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
else:
|
||||
g_idx = None
|
||||
|
||||
if weights.process_group.size() > 1:
|
||||
if g_idx is not None:
|
||||
if (
|
||||
not torch.equal(
|
||||
g_idx.cpu(),
|
||||
torch.tensor(
|
||||
[i // self.groupsize for i in range(g_idx.shape[0])],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
)
|
||||
and not (g_idx == 0).all()
|
||||
):
|
||||
# Exllama implementation does not support row tensor parallelism with act-order, as
|
||||
# it would require reordering input activations that are split onto several GPUs
|
||||
use_exllama = False
|
||||
|
||||
from text_generation_server.layers.gptq import (
|
||||
HAS_EXLLAMA,
|
||||
CAN_EXLLAMA,
|
||||
GPTQWeight,
|
||||
)
|
||||
|
||||
if use_exllama:
|
||||
if not HAS_EXLLAMA:
|
||||
if CAN_EXLLAMA:
|
||||
log_once(
|
||||
logger.warning,
|
||||
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
|
||||
)
|
||||
use_exllama = False
|
||||
else:
|
||||
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
|
||||
|
||||
if use_exllama and self.groupsize != -1:
|
||||
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
||||
else:
|
||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||
scales = weights.get_tensor(f"{prefix}.scales")
|
||||
|
||||
if use_exllama and g_idx is not None:
|
||||
g_idx = g_idx - g_idx[0]
|
||||
|
||||
if self.quantize == "gptq" and self.quant_method == "awq":
|
||||
log_once(
|
||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||
)
|
||||
from text_generation_server.layers.awq.conversion_utils import (
|
||||
fast_awq_to_gptq,
|
||||
)
|
||||
|
||||
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||
if use_exllama:
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = (
|
||||
torch.arange(
|
||||
qweight.shape[0] * (32 // self.bits),
|
||||
device=qweight.device,
|
||||
)
|
||||
// self.groupsize
|
||||
).to(dtype=torch.int32)
|
||||
|
||||
return GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=self.bits,
|
||||
groupsize=self.groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
|
||||
def _get_gptq_params(self, weights: Weights):
|
||||
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
|
||||
self.bits = weights.get_tensor("gptq_bits").item()
|
||||
self.groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||
self.desc_act = False
|
||||
# `server quantize` used asymmetric quantization unconditionally
|
||||
# before the `gptq_sym` setting tensor was added.
|
||||
self.sym = (
|
||||
weights.get_tensor("gptq_sym").item()
|
||||
if weights._has_tensor("gptq_sym")
|
||||
else False
|
||||
)
|
||||
self.quant_method = "gptq"
|
||||
|
|
|
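Editor's note: construction sketch for the loader class above, with values copied from the test fixtures added in this PR. AWQ checkpoints go through the same class; when quantize="gptq" is combined with quant_method="awq", the qweight/qzeros tensors are repacked to GPTQ format on the fly via fast_awq_to_gptq.

from text_generation_server.layers.gptq import GPTQWeightsLoader

# Plain GPTQ checkpoint.
gptq_loader = GPTQWeightsLoader(
    bits=4, groupsize=-1, desc_act=False,
    quant_method="gptq", quantize="gptq", sym=True,
)

# AWQ checkpoint served through the AWQ path (see the gptq_weights_loader_awq fixture).
awq_loader = GPTQWeightsLoader(
    bits=4, groupsize=-1, desc_act=False,
    quant_method="awq", quantize="awq", sym=True,
)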
@ -16,6 +16,8 @@ from text_generation_server.layers.gptq.quant_linear import QuantLinear
|
|||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.utils.weights import DefaultWeightsLoader
|
||||
|
||||
DEV = torch.device("cuda:0")
|
||||
|
||||
|
||||
|
@ -869,6 +871,7 @@ def quantize(
|
|||
upload_to_model_id: Optional[str],
|
||||
percdamp: float,
|
||||
act_order: bool,
|
||||
sym: bool,
|
||||
):
|
||||
print("loading model")
|
||||
config = AutoConfig.from_pretrained(
|
||||
|
@ -891,6 +894,7 @@ def quantize(
|
|||
dtype=torch.float16,
|
||||
process_group=process_group,
|
||||
aliases={"embed_tokens.weight": ["lm_head.weight"]},
|
||||
weights_loader=DefaultWeightsLoader(),
|
||||
)
|
||||
hooks = []
|
||||
for name, module in model.named_modules():
|
||||
|
@ -943,6 +947,7 @@ def quantize(
|
|||
percdamp=percdamp,
|
||||
act_order=act_order,
|
||||
hooks=hooks,
|
||||
sym=sym,
|
||||
)
|
||||
print(time.time() - tick)
|
||||
|
||||
|
@ -954,6 +959,7 @@ def quantize(
|
|||
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
|
||||
state_dict["gptq_bits"] = torch.LongTensor([bits])
|
||||
state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
|
||||
state_dict["gptq_sym"] = torch.BoolTensor([sym])
|
||||
|
||||
max_shard_size = "10GB"
|
||||
shards, index = shard_checkpoint(
|
||||
|
|
|
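Editor's note: how the new sym flag survives a round trip. quantize stores it next to the existing gptq_bits/gptq_groupsize tensors, and GPTQWeightsLoader._get_gptq_params (earlier in this diff) reads it back, defaulting to False so checkpoints produced before this change keep their asymmetric behaviour. The helpers below are a standalone illustration of that convention, not repo code.

import torch

def write_gptq_metadata(state_dict, bits=4, groupsize=128, sym=True):
    # Same tensors `quantize` appends to the sharded checkpoint above.
    state_dict["gptq_bits"] = torch.LongTensor([bits])
    state_dict["gptq_groupsize"] = torch.LongTensor([groupsize])
    state_dict["gptq_sym"] = torch.BoolTensor([sym])
    return state_dict

def read_gptq_sym(state_dict):
    # Mirrors _get_gptq_params: a missing tensor means an old, asymmetric checkpoint.
    return state_dict["gptq_sym"].item() if "gptq_sym" in state_dict else False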
@ -106,9 +106,9 @@ def get_linear(weight, bias, quantize):
|
|||
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
|
||||
)
|
||||
elif quantize == "fp8":
|
||||
from text_generation_server.layers.fp8 import Fp8Linear
|
||||
from text_generation_server.layers.fp8 import get_fp8_linear
|
||||
|
||||
linear = Fp8Linear(weight, bias)
|
||||
linear = get_fp8_linear()(weight, bias)
|
||||
elif quantize == "bitsandbytes":
|
||||
try:
|
||||
from text_generation_server.layers.bnb import (
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from text_generation_server.layers.gptq import GPTQParams
|
||||
from loguru import logger
|
||||
from text_generation_server.layers.fp8 import fp8_quantize
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||
|
||||
try:
|
||||
import marlin_kernels
|
||||
|
@ -24,16 +26,132 @@ GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
|
|||
MARLIN_TILE_SIZE = 16
|
||||
|
||||
|
||||
def can_use_gptq_marlin(gptq_params: GPTQParams, quantize: str) -> bool:
|
||||
class MarlinWeightsLoader(WeightsLoader):
|
||||
"""Loader for Marlin-quantized weights."""
|
||||
|
||||
def __init__(self, *, bits: int, is_marlin_24: bool):
|
||||
self.bits = bits
|
||||
self.is_marlin_24 = is_marlin_24
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
if self.is_marlin_24:
|
||||
B = weights.get_packed_sharded(
|
||||
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
B_meta = weights.get_packed_sharded(
|
||||
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
s = weights.get_packed_sharded(
|
||||
f"{prefix}.s", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
|
||||
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
||||
else:
|
||||
B = weights.get_packed_sharded(
|
||||
f"{prefix}.B", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
s = weights.get_packed_sharded(
|
||||
f"{prefix}.s", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
|
||||
return weight
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
try:
|
||||
B = torch.cat(
|
||||
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `marlin` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
B_meta = torch.cat(
|
||||
[weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
s = torch.cat(
|
||||
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
||||
else:
|
||||
try:
|
||||
B = torch.cat(
|
||||
[weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `marlin` weight, make sure the model is already quantized"
|
||||
)
|
||||
s = torch.cat(
|
||||
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
|
||||
return weight
|
||||
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
try:
|
||||
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0)
|
||||
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
|
||||
if num_groups == 1:
|
||||
# The number of groups is 1 when groupsize == -1. share
|
||||
# scales between all shards in this case.
|
||||
s = weights.get_tensor(f"{prefix}.s")
|
||||
else:
|
||||
s = weights.get_sharded(f"{prefix}.s", dim=0)
|
||||
|
||||
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
||||
else:
|
||||
try:
|
||||
B = weights.get_sharded(f"{prefix}.B", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `marlin` weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
|
||||
if num_groups == 1:
|
||||
# The number of groups is 1 when groupsize == -1. share
|
||||
# scales between all shards in this case.
|
||||
s = weights.get_tensor(f"{prefix}.s")
|
||||
else:
|
||||
s = weights.get_sharded(f"{prefix}.s", dim=0)
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
|
||||
return weight
|
||||
|
||||
|
||||
def can_use_gptq_marlin(
|
||||
*, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
|
||||
) -> bool:
|
||||
return (
|
||||
SYSTEM == "cuda"
|
||||
and marlin_kernels is not None
|
||||
and has_sm_8_0
|
||||
and quantize == "gptq"
|
||||
and gptq_params.quant_method == "gptq"
|
||||
and gptq_params.bits in GPTQ_MARLIN_BITS
|
||||
and gptq_params.groupsize in GPTQ_MARLIN_GROUP_SIZES
|
||||
and gptq_params.sym
|
||||
and quant_method == "gptq"
|
||||
and bits in GPTQ_MARLIN_BITS
|
||||
and groupsize in GPTQ_MARLIN_GROUP_SIZES
|
||||
and sym
|
||||
)
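# Example (hypothetical values): a caller can use this predicate to choose
# between the Marlin repacking path and the regular GPTQ kernels, e.g.:
#
#     if can_use_gptq_marlin(
#         bits=4, groupsize=128, quant_method="gptq", quantize="gptq", sym=True
#     ):
#         ...  # repack the checkpoint with repack_gptq_for_marlin(...)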
|
||||
|
||||
|
||||
|
@ -339,6 +457,115 @@ class GPTQMarlin24Linear(nn.Module):
|
|||
return C
|
||||
|
||||
|
||||
class GPTQMarlinFP8Linear(nn.Module):
|
||||
"""
|
||||
FP8 GPTQ-Marlin linear layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
_check_marlin_kernels()
|
||||
assert marlin_kernels is not None
|
||||
|
||||
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
|
||||
|
||||
qweight, scale = fp8_quantize(weight)
|
||||
scale = scale.to(torch.float16)
|
||||
qweight, scales = repack_fp8_for_marlin(qweight, scale)
|
||||
|
||||
in_features = qweight.shape[0] * MARLIN_TILE_SIZE
|
||||
out_features = scales.shape[1]
|
||||
_check_valid_shape(in_features=in_features, out_features=out_features)
|
||||
|
||||
self.qweight = qweight
|
||||
self.scales = scales
|
||||
self.bias = bias if bias is not None else None
|
||||
|
||||
self.workspace = torch.zeros(
|
||||
out_features // 64 * 16, dtype=torch.int, device=qweight.device
|
||||
)
|
||||
|
||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||
assert marlin_kernels is not None
|
||||
|
||||
A_flat = A.view(-1, A.shape[-1])
|
||||
C = marlin_kernels.fp8_marlin_gemm(
|
||||
A_flat,
|
||||
self.qweight,
|
||||
self.scales,
|
||||
self.workspace,
|
||||
8,
|
||||
A_flat.shape[0],
|
||||
self.scales.shape[1],
|
||||
A_flat.shape[1],
|
||||
)
|
||||
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
|
||||
|
||||
if self.bias is not None:
|
||||
C += self.bias
|
||||
|
||||
return C
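# Illustrative construction (shapes and device are hypothetical): wrap an fp16
# weight so that the matmul runs through the Marlin FP8 GEMM kernel.
#
#     w = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
#     layer = GPTQMarlinFP8Linear(weight=w, bias=None)
#     y = layer(torch.randn(2, 4096, dtype=torch.float16, device="cuda"))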
|
||||
|
||||
|
||||
def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Repack FP8 weights to gptq format (packed int32 elements).
|
||||
"""
|
||||
assert fp8_tensor.dtype == torch.float8_e4m3fn
|
||||
|
||||
if fp8_tensor.shape[0] % 4 != 0:
|
||||
raise ValueError(
|
||||
f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}"
|
||||
)
|
||||
|
||||
# Reshape to prepare for packing
|
||||
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
|
||||
|
||||
# Convert fp8 to uint8 (byte) representation
|
||||
byte_tensor = reshaped.view(torch.uint8)
|
||||
|
||||
# Pack 4 uint8 values into one int32
|
||||
packed = torch.zeros(
|
||||
fp8_tensor.shape[0] // 4,
|
||||
fp8_tensor.shape[1],
|
||||
dtype=torch.int32,
|
||||
device=fp8_tensor.device,
|
||||
)
|
||||
|
||||
for i in range(4):
|
||||
packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)
|
||||
|
||||
return packed
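# Illustrative inverse of the packing above, handy as a sanity check; this
# helper is a sketch only (it assumes torch >= 2.1 for float8 support) and is
# not part of the module.
def unpack_int32_to_fp8(packed: torch.Tensor) -> torch.Tensor:
    # Pull the 4 bytes back out of each int32 and reinterpret them as fp8 rows.
    byte_tensor = torch.stack(
        [((packed >> (i * 8)) & 0xFF).to(torch.uint8) for i in range(4)], dim=1
    )
    return byte_tensor.view(torch.float8_e4m3fn).reshape(-1, packed.shape[1])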
|
||||
|
||||
|
||||
def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
|
||||
"""
|
||||
Repack FP8 tensor for GPTQ-Marlin.
|
||||
"""
|
||||
|
||||
out_features, in_features = weight.shape
|
||||
|
||||
# Torch linear layer weights have shape [out_features, in_features],
|
||||
# GPTQ-quantized weights use [in_features/pack_factor, out_features],
|
||||
# so transpose before packing.
|
||||
qweight = pack_fp8_as_int32(weight.t())
|
||||
|
||||
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
|
||||
repacked = marlin_kernels.gptq_marlin_repack(
|
||||
qweight, perm, in_features, out_features, 8
|
||||
)
|
||||
|
||||
scales = scale.reshape(1, 1).repeat(1, out_features)
|
||||
scales = permute_scales(scales)
|
||||
|
||||
return repacked, scales
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarlinWeight:
|
||||
"""
|
||||
|
|
|
@ -102,7 +102,7 @@ class PositionRotaryEmbedding(nn.Module):
|
|||
max_position_embeddings=rope_scaling[
|
||||
"original_max_position_embeddings"
|
||||
],
|
||||
base=10000.0,
|
||||
base=base,
|
||||
device=inv_freq.device,
|
||||
scaling_factor=scaling_factor,
|
||||
extrapolation_factor=1,
|
||||
|
|
|
@ -52,7 +52,7 @@ class TensorParallelHead(SuperLayer):
|
|||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
except:
|
||||
# ...otherwise they are quantized.
|
||||
weight = weights.get_weights_col(prefix, config.quantize)
|
||||
weight = weights.get_weights_col(prefix)
|
||||
should_gather = weights.process_group.size() > 1
|
||||
elif weights.process_group.size() > 1:
|
||||
try:
|
||||
|
@ -129,9 +129,7 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
@classmethod
|
||||
def load_gate_up(cls, config, prefix: str, weights, bias: bool):
|
||||
"""Specific method when the QKV was joined after the fact"""
|
||||
weight = weights.get_weights_col_packed_gate_up(
|
||||
prefix, quantize=config.quantize
|
||||
)
|
||||
weight = weights.get_weights_col_packed_gate_up(prefix)
|
||||
if bias:
|
||||
raise NotImplementedError("packed_gate_up only implemented without bias")
|
||||
else:
|
||||
|
@ -152,7 +150,6 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
"""Specific method when the QKV was joined after the fact"""
|
||||
weight = weights.get_weights_col_packed_qkv(
|
||||
prefix,
|
||||
quantize=config.quantize,
|
||||
num_heads=num_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
)
|
||||
|
@ -165,7 +162,7 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_weights_col(prefix, config.quantize)
|
||||
weight = weights.get_weights_col(prefix)
|
||||
if bias:
|
||||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||
else:
|
||||
|
@ -178,14 +175,12 @@ class TensorParallelColumnLinear(SuperLayer):
|
|||
if config.quantize == "exl2":
|
||||
linears = []
|
||||
for prefix in prefixes:
|
||||
weight = weights.get_weights_col(prefix, config.quantize)
|
||||
weight = weights.get_weights_col(prefix)
|
||||
b = weights.get_tensor(f"{prefix}.bias") if bias else None
|
||||
linears.append(get_linear(weight, b, config.quantize))
|
||||
linear = LayerConcat(linears)
|
||||
else:
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes, quantize=config.quantize, dim=dim
|
||||
)
|
||||
weight = weights.get_multi_weights_col(prefixes, dim=dim)
|
||||
if bias:
|
||||
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
|
||||
bias = torch.cat(b, dim=dim)
|
||||
|
@ -202,7 +197,7 @@ class TensorParallelRowLinear(SuperLayer):
|
|||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
weight = weights.get_weights_row(prefix)
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
|
|
|
@ -804,6 +804,10 @@ def get_model(
|
|||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
aliases={
|
||||
"lm_head.weight": ["transformer.word_embeddings.weight"],
|
||||
"transformer.word_embeddings.weight": ["lm_head.weight"],
|
||||
},
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=RWConfig,
|
||||
|
|
|
@ -20,6 +20,7 @@ from text_generation_server.utils import (
|
|||
from text_generation_server.models import Model
|
||||
from text_generation_server.utils.chunks import concat_text_chunks
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.quantization import get_loader
|
||||
from text_generation_server.utils.tokens import batch_top_tokens
|
||||
from text_generation_server.models.types import (
|
||||
Batch,
|
||||
|
@ -546,12 +547,17 @@ class CausalLM(Model):
|
|||
tokenizer.pad_token_id = config.pad_token_id
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
weights_loader = get_loader(
|
||||
quantize=quantize, model_id=model_id, revision=revision
|
||||
)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
filenames,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
process_group=self.process_group,
|
||||
weights_loader=weights_loader,
|
||||
)
|
||||
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
prefix = ""
|
||||
model = model_class(prefix, config, weights)
|
||||
|
|
|
@ -163,7 +163,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
|
|
@ -105,6 +105,12 @@ class DbrxFFNConfig(PretrainedConfig):
|
|||
|
||||
|
||||
class DbrxConfig(PretrainedConfig):
|
||||
attribute_map = {
|
||||
"hidden_size": "d_model",
|
||||
"num_attention_heads": "n_heads",
|
||||
"num_hidden_layers": "n_layers",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = 2048,
|
||||
|
@ -157,6 +163,12 @@ class DbrxConfig(PretrainedConfig):
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def num_key_value_heads(self):
|
||||
# We can't use the attribute map, since the number of KV
|
||||
# heads is not top-level.
|
||||
return self.attn_config.kv_n_heads
|
||||
|
||||
|
||||
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
|
||||
return x.view(1) if len(x.size()) == 0 else x
|
||||
|
|
|
@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
|
|
@ -141,7 +141,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
|
|
@ -61,7 +61,6 @@ def _load_qkv_gptq(config, prefix: str, weights):
|
|||
# Weights
|
||||
weight = weights.get_weights_col_packed_qkv(
|
||||
f"{prefix}.c_attn",
|
||||
config.quantize,
|
||||
config.num_attention_heads,
|
||||
config.num_attention_heads,
|
||||
)
|
||||
|
@ -137,7 +136,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||
"""load_row, but with transposed weight matrices."""
|
||||
|
||||
if config.quantize == "gptq":
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
weight = weights.get_weights_row(prefix)
|
||||
else:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T.contiguous()
|
||||
|
||||
|
@ -155,9 +154,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||
def load_col(config, prefix: str, weights, bias: bool):
|
||||
"""load_col, but with transposed weight matrices."""
|
||||
if config.quantize == "gptq":
|
||||
weight = weights.get_multi_weights_col(
|
||||
[prefix], quantize=config.quantize, dim=1
|
||||
)
|
||||
weight = weights.get_multi_weights_col([prefix], dim=1)
|
||||
else:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T.contiguous()
|
||||
|
||||
|
|
|
@ -135,7 +135,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ from text_generation_server.layers.rotary import (
|
|||
|
||||
|
||||
def load_row(config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
weight = weights.get_weights_row(prefix)
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
|
@ -64,7 +64,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||
|
||||
|
||||
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
|
||||
weight = weights.get_multi_weights_col([prefix], quantize=config.quantize, dim=0)
|
||||
weight = weights.get_multi_weights_col([prefix], dim=0)
|
||||
if isinstance(weight, torch.Tensor):
|
||||
# Only on non quantized versions
|
||||
weight = (
|
||||
|
|
|
@ -85,7 +85,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from text_generation_server.layers.attention import (
|
|||
|
||||
|
||||
def load_row(config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
weight = weights.get_weights_row(prefix)
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
|
@ -42,6 +42,7 @@ class RWConfig(PretrainedConfig):
|
|||
attribute_map = {
|
||||
"num_hidden_layers": "n_layer",
|
||||
"num_attention_heads": "n_head",
|
||||
"num_key_value_heads": "n_head_kv",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
|
|
|
@ -17,6 +17,7 @@ from text_generation_server.layers import (
|
|||
TensorParallelEmbedding,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.gptq import GPTQWeightsLoader
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
|
@ -81,11 +82,13 @@ def _load_multi_mqa_gptq(
|
|||
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
|
||||
qzeros = qzeros.to(device=weights.device)
|
||||
|
||||
gptq_params = weights._get_gptq_params()
|
||||
if gptq_params.quant_method == "gptq":
|
||||
loader = weights.weights_loader
|
||||
assert isinstance(loader, GPTQWeightsLoader)
|
||||
loader._get_gptq_params(weights)
|
||||
if loader.quant_method == "gptq":
|
||||
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
|
||||
g_idx = g_idx.to(device=weights.device)
|
||||
elif gptq_params.quant_method == "awq":
|
||||
elif loader.quant_method == "awq":
|
||||
g_idx = None
|
||||
from text_generation_server.layers.awq.conversion_utils import (
|
||||
fast_awq_to_gptq,
|
||||
|
@ -100,8 +103,8 @@ def _load_multi_mqa_gptq(
|
|||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
groupsize=gptq_params.groupsize,
|
||||
bits=loader.bits,
|
||||
groupsize=loader.groupsize,
|
||||
use_exllama=HAS_EXLLAMA,
|
||||
)
|
||||
|
||||
|
@ -197,9 +200,7 @@ def load_col(config, prefix: str, weights, bias: bool):
|
|||
if config.transpose:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
|
||||
else:
|
||||
weight = weights.get_multi_weights_col(
|
||||
[prefix], quantize=config.quantize, dim=0
|
||||
)
|
||||
weight = weights.get_multi_weights_col([prefix], dim=0)
|
||||
|
||||
if bias:
|
||||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||
|
@ -212,7 +213,7 @@ def load_row(config, prefix: str, weights, bias: bool):
|
|||
if config.transpose:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
|
||||
else:
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
weight = weights.get_weights_row(prefix)
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
|
|
|
@ -126,7 +126,6 @@ def _load_gqa(config, prefix: str, weights):
|
|||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ from text_generation_server.layers import (
|
|||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
from text_generation_server.utils.weights import DefaultWeightsLoader
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
|
@ -682,7 +683,7 @@ class Idefics2Connector(nn.Module):
|
|||
class Idefics2ForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = config.quantize
|
||||
config.vision_config.quantize = None
|
||||
config.vision_config.speculator = config.speculator
|
||||
config.text_config.quantize = config.quantize
|
||||
config.text_config.speculator = config.speculator
|
||||
|
@ -695,16 +696,28 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
|||
name="text_model",
|
||||
)
|
||||
self.dtype = weights.dtype
|
||||
self.vision_model = Idefics2VisionTransformer(
|
||||
prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
|
||||
config=vision_config,
|
||||
weights=weights,
|
||||
)
|
||||
self.connector = Idefics2Connector(
|
||||
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
# The vision and connector models are not quantized.
|
||||
with weights.use_loader(DefaultWeightsLoader()):
|
||||
self.vision_model = Idefics2VisionTransformer(
|
||||
prefix=(
|
||||
f"{prefix}.model.vision_model" if prefix else "model.vision_model"
|
||||
),
|
||||
config=vision_config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
quantize = config.quantize
|
||||
try:
|
||||
config.quantize = None
|
||||
self.connector = Idefics2Connector(
|
||||
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
finally:
|
||||
config.quantize = quantize
|
||||
|
||||
self.config = config
|
||||
self.image_seq_len = config.perceiver_config.resampler_n_latents
|
||||
self.image_token_id = config.image_token_id
|
||||
|
|
|
@ -49,6 +49,7 @@ from text_generation_server.models.globals import (
|
|||
from text_generation_server.layers.attention import Seqlen
|
||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||
from text_generation_server.utils.quantization import get_loader
|
||||
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
|
||||
|
||||
from text_generation_server.utils.import_utils import (
|
||||
|
@ -880,12 +881,16 @@ class FlashCausalLM(Model):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
weights_loader = get_loader(quantize, model_id, revision)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames, device, dtype, process_group=self.process_group, aliases=aliases
|
||||
filenames,
|
||||
device,
|
||||
dtype,
|
||||
process_group=self.process_group,
|
||||
aliases=aliases,
|
||||
weights_loader=weights_loader,
|
||||
)
|
||||
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
prefix = ""
|
||||
model = model_class(prefix, config, weights)
|
||||
|
@ -904,13 +909,12 @@ class FlashCausalLM(Model):
|
|||
self.num_layers = config.num_hidden_layers
|
||||
# Validation is done in the model itself
|
||||
if num_kv_heads is None:
|
||||
# Order is important here.
|
||||
for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
|
||||
num_kv_heads = getattr(config, attr, None)
|
||||
if num_kv_heads is not None:
|
||||
break
|
||||
num_kv_heads = getattr(config, "num_key_value_heads", None)
|
||||
# GPT-2 workaround
|
||||
if num_kv_heads is None:
|
||||
raise ValueError("Cannot get the number of key/value heads")
|
||||
num_kv_heads = getattr(config, "n_head", None)
|
||||
if num_kv_heads is None:
|
||||
raise ValueError("Cannot get the number of key/value heads")
|
||||
self.num_kv_heads = (
|
||||
num_kv_heads // self.process_group.size()
|
||||
if num_kv_heads > 1
|
||||
|
|
|
@ -23,6 +23,7 @@ from text_generation_server.utils import (
|
|||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.quantization import get_loader
|
||||
|
||||
|
||||
class IDEFICSSharded(IdeficsCausalLM):
|
||||
|
@ -70,6 +71,9 @@ class IDEFICSSharded(IdeficsCausalLM):
|
|||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
weights_loader = get_loader(
|
||||
quantize=quantize, model_id=model_id, revision=revision
|
||||
)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
|
@ -77,6 +81,7 @@ class IDEFICSSharded(IdeficsCausalLM):
|
|||
device=device,
|
||||
dtype=dtype,
|
||||
process_group=self.process_group,
|
||||
weights_loader=weights_loader,
|
||||
)
|
||||
|
||||
model = IdeficsForVisionText2Text(config, weights)
|
||||
|
|
|
@ -28,6 +28,7 @@ from text_generation_server.models.types import (
|
|||
GeneratedText,
|
||||
)
|
||||
from text_generation_server.utils.chunks import concat_text_chunks
|
||||
from text_generation_server.utils.quantization import get_loader
|
||||
from text_generation_server.utils.tokens import batch_top_tokens, Sampling
|
||||
from dataclasses import dataclass
|
||||
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
|
||||
|
@ -448,8 +449,17 @@ class Mamba(Model):
|
|||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
weights_loader = get_loader(
|
||||
quantize=quantize, model_id=model_id, revision=revision
|
||||
)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
weights = Weights(
|
||||
filenames,
|
||||
device,
|
||||
dtype,
|
||||
process_group=self.process_group,
|
||||
weights_loader=weights_loader,
|
||||
)
|
||||
model = MambaModel(config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(Mamba, self).__init__(
|
||||
|
|
|
@ -18,6 +18,7 @@ from text_generation_server.utils import (
|
|||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.chunks import concat_text_chunks
|
||||
from text_generation_server.utils.quantization import get_loader
|
||||
from text_generation_server.utils.tokens import batch_top_tokens
|
||||
from text_generation_server.models import Model
|
||||
from text_generation_server.models.types import (
|
||||
|
@ -586,6 +587,9 @@ class Seq2SeqLM(Model):
|
|||
)
|
||||
tokenizer.bos_token_id = config.decoder_start_token_id
|
||||
|
||||
weights_loader = get_loader(
|
||||
quantize=quantize, model_id=model_id, revision=revision
|
||||
)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
|
@ -594,6 +598,7 @@ class Seq2SeqLM(Model):
|
|||
dtype=dtype,
|
||||
process_group=self.process_group,
|
||||
aliases=aliases,
|
||||
weights_loader=weights_loader,
|
||||
)
|
||||
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
from typing import Optional
|
||||
import os
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from text_generation_server.utils.weights import DefaultWeightsLoader, WeightsLoader
|
||||
|
||||
|
||||
@dataclass
|
||||
class _QuantizerConfig:
|
||||
bits: int
|
||||
checkpoint_format: Optional[str]
|
||||
desc_act: bool
|
||||
groupsize: int
|
||||
quant_method: str
|
||||
sym: bool
|
||||
|
||||
|
||||
# We should probably do this with Pydantic JSON deserialization,
|
||||
# but for now we'll stay close to the old _set_gptq_params.
|
||||
def _get_quantizer_config(model_id, revision):
|
||||
bits = 4
|
||||
groupsize = -1
|
||||
quant_method = "gptq"
|
||||
checkpoint_format = None
|
||||
sym = True
|
||||
desc_act = False
|
||||
|
||||
filename = "config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(model_id, filename=filename, revision=revision)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
bits = data["quantization_config"]["bits"]
|
||||
groupsize = data["quantization_config"]["group_size"]
|
||||
# Order is important here, desc_act is missing on some real models
|
||||
quant_method = data["quantization_config"]["quant_method"]
|
||||
checkpoint_format = data["quantization_config"].get("checkpoint_format")
|
||||
sym = data["quantization_config"]["sym"]
|
||||
desc_act = data["quantization_config"]["desc_act"]
|
||||
except Exception:
|
||||
filename = "quantize_config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(
|
||||
model_id, filename=filename, revision=revision
|
||||
)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
bits = data["bits"]
|
||||
groupsize = data["group_size"]
|
||||
sym = data["sym"]
|
||||
desc_act = data["desc_act"]
|
||||
if "version" in data and data["version"] == "GEMM":
|
||||
quant_method = "awq"
|
||||
except Exception:
|
||||
filename = "quant_config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(
|
||||
model_id, filename=filename, revision=revision
|
||||
)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
bits = data["w_bit"]
|
||||
groupsize = data["q_group_size"]
|
||||
desc_act = data["desc_act"]
|
||||
if "version" in data and data["version"] == "GEMM":
|
||||
quant_method = "awq"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return _QuantizerConfig(
|
||||
bits=bits,
|
||||
groupsize=groupsize,
|
||||
quant_method=quant_method,
|
||||
checkpoint_format=checkpoint_format,
|
||||
sym=sym,
|
||||
desc_act=desc_act,
|
||||
)
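# Sketch of the Pydantic-based parsing hinted at above; it assumes `pydantic`
# is available, and the class and field defaults are illustrative only,
# mirroring the `quantization_config` section of config.json.
from pydantic import BaseModel


class _PydanticQuantizerConfig(BaseModel):
    bits: int = 4
    group_size: int = -1
    quant_method: str = "gptq"
    checkpoint_format: Optional[str] = None
    sym: bool = True
    desc_act: bool = False


# Example: _PydanticQuantizerConfig(**data.get("quantization_config", {}))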
|
||||
|
||||
|
||||
def get_loader(
|
||||
quantize: Optional[str], model_id: str, revision: Optional[str]
|
||||
) -> WeightsLoader:
|
||||
quantizer_config = _get_quantizer_config(model_id, revision)
|
||||
if quantize in {"awq", "gptq"}:
|
||||
from text_generation_server.layers.gptq import GPTQWeightsLoader
|
||||
|
||||
return GPTQWeightsLoader(
|
||||
bits=quantizer_config.bits,
|
||||
desc_act=quantizer_config.desc_act,
|
||||
groupsize=quantizer_config.groupsize,
|
||||
quant_method=quantizer_config.quant_method,
|
||||
quantize=quantize,
|
||||
sym=quantizer_config.sym,
|
||||
)
|
||||
elif quantize == "exl2":
|
||||
from text_generation_server.layers.exl2 import Exl2WeightsLoader
|
||||
|
||||
return Exl2WeightsLoader()
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.marlin import MarlinWeightsLoader
|
||||
|
||||
return MarlinWeightsLoader(
|
||||
bits=quantizer_config.bits,
|
||||
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
|
||||
)
|
||||
else:
|
||||
return DefaultWeightsLoader()
|
|
@ -1,13 +1,89 @@
|
|||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from safetensors import safe_open, SafetensorError
|
||||
from safetensors import safe_open
|
||||
import torch
|
||||
from loguru import logger
|
||||
from huggingface_hub import hf_hub_download
|
||||
import json
|
||||
from text_generation_server.layers.gptq import GPTQParams
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
|
||||
class WeightsLoader(ABC):
|
||||
"""
|
||||
Instances of this type implement higher-level weight loading.
|
||||
|
||||
At a low level, every weight is stored in the Safetensors format.
The interpretation of the weights may differ however; for instance, they
could be packed or quantized weights. Loaders are responsible for
|
||||
interpreting the raw tensors, sharding tensors in a manner compatible
|
||||
with the format, etc.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: "Weights",
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
"""
|
||||
Get the packed weights at the given prefix with column-splitting for
|
||||
tensor parallelism. This method should be used when multiple different
|
||||
weights are packed into a tensor, for instance, query/key/value
|
||||
weights or a gate/up projection.
|
||||
|
||||
The `block_sizes` determines the proportions of the packed tensors.
|
||||
The columns are split in equally sized blocks when `block_sizes` is an
|
||||
`int`, or in blocks proportional to the given sizes. For instance
|
||||
`[2, 1, 1]` will divide an input with dimensionality `1024` in
|
||||
`[512, 256, 256]`.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_weights_col(self, weights: "Weights", prefix: str):
|
||||
"""
|
||||
Get weights at the given prefix and apply column-splitting for tensor
|
||||
parallelism.
|
||||
"""
|
||||
return weights.get_multi_weights_col([prefix], 0)
|
||||
|
||||
@abstractmethod
|
||||
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
"""
|
||||
Get the weights at the given prefixes, column-split them for tensor
|
||||
parallelism, and then concatenate the weights along the given dimension.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||
"""
|
||||
Get the weights at the given prefix and apply row-splitting for tensor
|
||||
parallelism.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class DefaultWeightsLoader(WeightsLoader):
|
||||
"""
|
||||
Loader that uses tensors as-is, except for applying sharding
|
||||
and/or concatenation.
|
||||
"""
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: "Weights",
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
return weights.get_packed_sharded(
|
||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
return torch.cat(w, dim=dim)
|
||||
|
||||
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||
return weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
|
||||
|
||||
class Weights:
|
||||
|
@ -17,6 +93,7 @@ class Weights:
|
|||
device,
|
||||
dtype,
|
||||
process_group,
|
||||
weights_loader: WeightsLoader,
|
||||
aliases: Optional[Dict[str, List[str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
):
|
||||
|
@ -37,6 +114,7 @@ class Weights:
|
|||
self.dtype = dtype
|
||||
self.process_group = process_group
|
||||
self.prefix = prefix
|
||||
self.weights_loader = weights_loader
|
||||
self._handles = {}
|
||||
|
||||
def _get_handle(self, filename):
|
||||
|
@ -69,6 +147,13 @@ class Weights:
|
|||
slice_ = f.get_slice(tensor_name)
|
||||
return slice_
|
||||
|
||||
def _has_tensor(self, tensor_name: str):
|
||||
try:
|
||||
self.get_filename(tensor_name)
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_shape(self, tensor_name: str):
|
||||
return self._get_slice(tensor_name).get_shape()
|
||||
|
||||
|
@ -176,300 +261,31 @@ class Weights:
|
|||
def get_weights_col_packed_qkv(
|
||||
self,
|
||||
prefix: str,
|
||||
quantize: str,
|
||||
num_heads: int,
|
||||
num_key_value_heads: int,
|
||||
):
|
||||
return self.get_weights_col_packed(
|
||||
prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads]
|
||||
prefix, [num_heads, num_key_value_heads, num_key_value_heads]
|
||||
)
|
||||
|
||||
def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
|
||||
return self.get_weights_col_packed(prefix, quantize, 2)
|
||||
def get_weights_col_packed_gate_up(self, prefix: str):
|
||||
return self.get_weights_col_packed(prefix, 2)
|
||||
|
||||
def get_weights_col_packed(
|
||||
self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
|
||||
):
|
||||
def get_weights_col_packed(self, prefix: str, block_sizes: Union[int, List[int]]):
|
||||
"""
|
||||
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
|
||||
already alternating Q,K,V within the main tensor.
|
||||
|
||||
The columns are split in equally sized blocks when blocks is an `int`, or
|
||||
in blocks proportional given to the sizes. For instance `[2, 1, 1]` will
|
||||
divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
|
||||
convenient for e.g. splitting QKV without knowing the storage details of
|
||||
quantized weights.
|
||||
"""
|
||||
if quantize in ["gptq", "awq"]:
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
return self.weights_loader.get_weights_col_packed(self, prefix, block_sizes)
|
||||
|
||||
try:
|
||||
qweight = self.get_packed_sharded(
|
||||
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
|
||||
)
|
||||
scales = self.get_packed_sharded(
|
||||
f"{prefix}.scales", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
scales = scales.to(dtype=self.dtype)
|
||||
def get_weights_col(self, prefix: str):
|
||||
return self.weights_loader.get_weights_col(self, prefix)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
if can_use_gptq_marlin(gptq_params, quantize):
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
qzeros = self.get_packed_sharded(
|
||||
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
if quantize == "gptq" and gptq_params.quant_method == "gptq":
|
||||
g_idx = self.get_tensor(f"{prefix}.g_idx")
|
||||
elif quantize == "gptq" and gptq_params.quant_method == "awq":
|
||||
log_once(
|
||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||
)
|
||||
from text_generation_server.layers.awq.conversion_utils import (
|
||||
fast_awq_to_gptq,
|
||||
)
|
||||
|
||||
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||
g_idx = (
|
||||
torch.arange(
|
||||
qweight.shape[0] * (32 // gptq_params.bits),
|
||||
device=qweight.device,
|
||||
)
|
||||
// gptq_params.groupsize
|
||||
).to(dtype=torch.int32)
|
||||
else:
|
||||
g_idx = None
|
||||
|
||||
weight = GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
groupsize=gptq_params.groupsize,
|
||||
use_exllama=False,
|
||||
)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Weight,
|
||||
MarlinWeight,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
B = self.get_packed_sharded(
|
||||
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
B_meta = self.get_packed_sharded(
|
||||
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
s = self.get_packed_sharded(
|
||||
f"{prefix}.s", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
weight = GPTQMarlin24Weight(
|
||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||
)
|
||||
else:
|
||||
B = self.get_packed_sharded(
|
||||
f"{prefix}.B", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
s = self.get_packed_sharded(
|
||||
f"{prefix}.s", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
else:
|
||||
weight = self.get_packed_sharded(
|
||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||
)
|
||||
return weight
|
||||
|
||||
def get_weights_col(self, prefix: str, quantize: str):
|
||||
if quantize == "exl2":
|
||||
from text_generation_server.layers.exl2 import Exl2Weight
|
||||
|
||||
try:
|
||||
q_weight = self.get_tensor(f"{prefix}.q_weight")
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
q_scale = self.get_tensor(f"{prefix}.q_scale")
|
||||
q_invperm = self.get_tensor(f"{prefix}.q_invperm")
|
||||
q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
|
||||
q_groups = self.get_tensor(f"{prefix}.q_groups")
|
||||
|
||||
return Exl2Weight(
|
||||
q_weight=q_weight,
|
||||
q_scale=q_scale,
|
||||
q_invperm=q_invperm,
|
||||
q_scale_max=q_scale_max,
|
||||
q_groups=q_groups,
|
||||
)
|
||||
|
||||
return self.get_multi_weights_col([prefix], quantize, 0)
|
||||
|
||||
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
|
||||
if quantize == "exl2":
|
||||
raise ValueError("get_multi_weights_col is not supported for exl2")
|
||||
elif quantize in ["gptq", "awq"]:
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
try:
|
||||
qweight = torch.cat(
|
||||
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
scales = torch.cat(
|
||||
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
if can_use_gptq_marlin(gptq_params, quantize):
|
||||
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
qzeros = torch.cat(
|
||||
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
from text_generation_server.layers.gptq import HAS_EXLLAMA
|
||||
|
||||
use_exllama = (
|
||||
gptq_params.bits == 4
|
||||
and HAS_EXLLAMA
|
||||
and quantize == "gptq"
|
||||
and not gptq_params.desc_act
|
||||
)
|
||||
|
||||
if quantize == "gptq" and gptq_params.quant_method == "gptq":
|
||||
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
elif quantize == "gptq" and gptq_params.quant_method == "awq":
|
||||
log_once(
|
||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||
)
|
||||
from text_generation_server.layers.awq.conversion_utils import (
|
||||
fast_awq_to_gptq,
|
||||
)
|
||||
|
||||
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||
if use_exllama:
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = (
|
||||
torch.arange(
|
||||
qweight.shape[0] * (32 // gptq_params.bits),
|
||||
device=qweight.device,
|
||||
)
|
||||
// gptq_params.groupsize
|
||||
).to(dtype=torch.int32)
|
||||
else:
|
||||
g_idx = None
|
||||
|
||||
weight = GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
groupsize=gptq_params.groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Weight,
|
||||
MarlinWeight,
|
||||
)
|
||||
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
try:
|
||||
B = torch.cat(
|
||||
[self.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
B_meta = torch.cat(
|
||||
[self.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
s = torch.cat(
|
||||
[self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
weight = GPTQMarlin24Weight(
|
||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||
)
|
||||
else:
|
||||
try:
|
||||
B = torch.cat(
|
||||
[self.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight, make sure the model is already quantized"
|
||||
)
|
||||
s = torch.cat(
|
||||
[self.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
|
||||
else:
|
||||
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
weight = torch.cat(w, dim=dim)
|
||||
|
||||
return weight
|
||||
def get_multi_weights_col(self, prefixes: List[str], dim: int):
|
||||
return self.weights_loader.get_multi_weights_col(self, prefixes, dim)
|
||||
|
||||
def get_tensor_shard(self, var, dim):
|
||||
world_size = self.process_group.size()
|
||||
|
@ -487,324 +303,22 @@ class Weights:
|
|||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_multi_weights_row(self, prefix: str, quantize: str):
|
||||
if quantize == "exl2":
|
||||
from text_generation_server.layers.exl2 import Exl2Weight
|
||||
def get_weights_row(self, prefix: str):
|
||||
return self.weights_loader.get_weights_row(self, prefix)
|
||||
|
||||
try:
|
||||
q_weight = self.get_tensor(f"{prefix}.q_weight")
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `exl2`-quantized weight, make sure the model is already quantized."
|
||||
)
|
||||
@contextmanager
|
||||
def use_loader(self, weights_loader: WeightsLoader):
|
||||
"""
|
||||
This method is a context manager that can be used to use `Weights` with
|
||||
a different loader for the duration of the context.
|
||||
"""
|
||||
|
||||
q_scale = self.get_tensor(f"{prefix}.q_scale")
|
||||
q_invperm = self.get_tensor(f"{prefix}.q_invperm")
|
||||
q_scale_max = self.get_tensor(f"{prefix}.q_scale_max")
|
||||
q_groups = self.get_tensor(f"{prefix}.q_groups")
|
||||
|
||||
return Exl2Weight(
|
||||
q_weight=q_weight,
|
||||
q_scale=q_scale,
|
||||
q_invperm=q_invperm,
|
||||
q_scale_max=q_scale_max,
|
||||
q_groups=q_groups,
|
||||
)
|
||||
|
||||
elif quantize == "gptq":
|
||||
from text_generation_server.layers.marlin import (
|
||||
can_use_gptq_marlin,
|
||||
repack_gptq_for_marlin,
|
||||
)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
if can_use_gptq_marlin(gptq_params, quantize):
|
||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||
try:
|
||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
if gptq_params.desc_act or gptq_params.groupsize == -1:
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
else:
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
|
||||
sharded_in_features = self.process_group.size() > 1
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=qweight,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
desc_act=gptq_params.desc_act,
|
||||
groupsize=gptq_params.groupsize,
|
||||
sym=gptq_params.sym,
|
||||
sharded_infeatures=sharded_in_features,
|
||||
)
|
||||
|
||||
use_exllama = True
|
||||
if gptq_params.bits != 4:
|
||||
use_exllama = False
|
||||
|
||||
if gptq_params.desc_act:
|
||||
log_once(logger.warning, "Disabling exllama because desc_act=True")
|
||||
use_exllama = False
|
||||
|
||||
try:
|
||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
|
||||
)
|
||||
|
||||
if gptq_params.quant_method == "gptq":
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
elif gptq_params.quant_method == "awq":
|
||||
g_idx = None
|
||||
|
||||
if self.process_group.size() > 1:
|
||||
if g_idx is not None:
|
||||
if (
|
||||
not torch.equal(
|
||||
g_idx.cpu(),
|
||||
torch.tensor(
|
||||
[
|
||||
i // gptq_params.groupsize
|
||||
for i in range(g_idx.shape[0])
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
)
|
||||
and not (g_idx == 0).all()
|
||||
):
|
||||
# Exllama implementation does not support row tensor parallelism with act-order, as
|
||||
# it would require to reorder input activations that are split unto several GPUs
|
||||
use_exllama = False
|
||||
|
||||
from text_generation_server.layers.gptq import (
|
||||
HAS_EXLLAMA,
|
||||
CAN_EXLLAMA,
|
||||
GPTQWeight,
|
||||
)
|
||||
|
||||
if use_exllama:
|
||||
if not HAS_EXLLAMA:
|
||||
if CAN_EXLLAMA:
|
||||
log_once(
|
||||
logger.warning,
|
||||
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
|
||||
)
|
||||
use_exllama = False
|
||||
else:
|
||||
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
|
||||
|
||||
if use_exllama and gptq_params.groupsize != -1:
|
||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
else:
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
|
||||
if use_exllama and g_idx is not None:
|
||||
g_idx = g_idx - g_idx[0]
|
||||
|
||||
if gptq_params.quant_method == "awq":
|
||||
log_once(
|
||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||
)
|
||||
from text_generation_server.layers.awq.conversion_utils import (
|
||||
fast_awq_to_gptq,
|
||||
)
|
||||
|
||||
qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
|
||||
if use_exllama:
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = (
|
||||
torch.arange(
|
||||
qweight.shape[0] * (32 // gptq_params.bits),
|
||||
device=qweight.device,
|
||||
)
|
||||
// gptq_params.groupsize
|
||||
).to(dtype=torch.int32)
|
||||
|
||||
weight = GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
groupsize=gptq_params.groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
elif quantize == "awq":
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
|
||||
try:
|
||||
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `awq` weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
g_idx = None
|
||||
use_exllama = False
|
||||
|
||||
weight = GPTQWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
bits=gptq_params.bits,
|
||||
groupsize=gptq_params.groupsize,
|
||||
use_exllama=use_exllama,
|
||||
)
|
||||
elif quantize == "marlin":
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.layers.marlin import (
|
||||
GPTQMarlin24Weight,
|
||||
MarlinWeight,
|
||||
)
|
||||
|
||||
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
|
||||
if is_marlin_24:
|
||||
try:
|
||||
B = self.get_sharded(f"{prefix}.B_24", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
B_meta = self.get_sharded(f"{prefix}.B_meta", dim=0)
|
||||
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
|
||||
if num_groups == 1:
|
||||
# The number of groups is 1 when groupsize == -1. share
|
||||
# scales between all shards in this case.
|
||||
s = self.get_tensor(f"{prefix}.s")
|
||||
else:
|
||||
s = self.get_sharded(f"{prefix}.s", dim=0)
|
||||
|
||||
gptq_params = self._get_gptq_params()
|
||||
weight = GPTQMarlin24Weight(
|
||||
B=B, B_meta=B_meta, s=s, bits=gptq_params.bits
|
||||
)
|
||||
else:
|
||||
try:
|
||||
B = self.get_sharded(f"{prefix}.B", dim=0)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
"Cannot load `marlin` weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
|
||||
if num_groups == 1:
|
||||
# The number of groups is 1 when groupsize == -1. share
|
||||
# scales between all shards in this case.
|
||||
s = self.get_tensor(f"{prefix}.s")
|
||||
else:
|
||||
s = self.get_sharded(f"{prefix}.s", dim=0)
|
||||
weight = MarlinWeight(B=B, s=s)
|
||||
else:
|
||||
weight = self.get_sharded(f"{prefix}.weight", dim=1)
|
||||
return weight
|
||||
|
||||
def _get_gptq_params(self) -> GPTQParams:
|
||||
old_loader = self.weights_loader
|
||||
self.weights_loader = weights_loader
|
||||
try:
|
||||
bits = self.get_tensor("gptq_bits").item()
|
||||
groupsize = self.get_tensor("gptq_groupsize").item()
|
||||
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
|
||||
desc_act = False
|
||||
sym = False
|
||||
quant_method = "gptq"
|
||||
except (SafetensorError, RuntimeError) as e:
|
||||
try:
|
||||
bits = self.gptq_bits
|
||||
groupsize = self.gptq_groupsize
|
||||
checkpoint_format = getattr(self, "gptq_checkpoint_format", None)
|
||||
desc_act = getattr(self, "gptq_desc_act", False)
|
||||
quant_method = getattr(self, "quant_method", "gptq")
|
||||
sym = getattr(self, "sym", True)
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
return GPTQParams(
|
||||
bits=bits,
|
||||
checkpoint_format=checkpoint_format,
|
||||
desc_act=desc_act,
|
||||
groupsize=groupsize,
|
||||
quant_method=quant_method,
|
||||
sym=sym,
|
||||
)
|
||||
|
||||
def _set_gptq_params(self, model_id, revision):
|
||||
filename = "config.json"
|
||||
|
||||
self.quant_method = None
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(
|
||||
model_id, filename=filename, revision=revision
|
||||
)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
self.gptq_bits = data["quantization_config"]["bits"]
|
||||
self.gptq_groupsize = data["quantization_config"]["group_size"]
|
||||
# Order is important here, desc_act is missing on some real models
|
||||
self.quant_method = data["quantization_config"]["quant_method"]
|
||||
self.gptq_checkpoint_format = data["quantization_config"].get(
|
||||
"checkpoint_format"
|
||||
)
|
||||
self.gptq_sym = data["quantization_config"]["sym"]
|
||||
self.gptq_desc_act = data["quantization_config"]["desc_act"]
|
||||
except Exception:
|
||||
filename = "quantize_config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(
|
||||
model_id, filename=filename, revision=revision
|
||||
)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
self.gptq_bits = data["bits"]
|
||||
self.gptq_groupsize = data["group_size"]
|
||||
self.gptq_sym = data["sym"]
|
||||
self.gptq_desc_act = data["desc_act"]
|
||||
if "version" in data and data["version"] == "GEMM":
|
||||
self.quant_method = "awq"
|
||||
except Exception:
|
||||
filename = "quant_config.json"
|
||||
try:
|
||||
if os.path.exists(os.path.join(model_id, filename)):
|
||||
filename = os.path.join(model_id, filename)
|
||||
else:
|
||||
filename = hf_hub_download(
|
||||
model_id, filename=filename, revision=revision
|
||||
)
|
||||
with open(filename, "r") as f:
|
||||
data = json.load(f)
|
||||
self.gptq_bits = data["w_bit"]
|
||||
self.gptq_groupsize = data["q_group_size"]
|
||||
self.gptq_desc_act = data["desc_act"]
|
||||
if "version" in data and data["version"] == "GEMM":
|
||||
self.quant_method = "awq"
|
||||
except Exception:
|
||||
if self.quant_method is None:
|
||||
if "awq" in model_id.lower():
|
||||
self.quant_method = "awq"
|
||||
elif "gptq" in model_id.lower():
|
||||
self.quant_method = "gptq"
|
||||
yield
|
||||
finally:
|
||||
self.weights_loader = old_loader
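# Example usage (as in the Idefics2 change in this commit): load an unquantized
# submodule while the surrounding model keeps its quantizing loader.
#
#     with weights.use_loader(DefaultWeightsLoader()):
#         vision_model = Idefics2VisionTransformer(
#             prefix="model.vision_model", config=vision_config, weights=weights
#         )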
|
||||
|
||||
|
||||
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
|
||||
|
|
|
@ -155,7 +155,7 @@ def check_openapi(check: bool):
|
|||
filename,
|
||||
],
|
||||
capture_output=True,
|
||||
).stdout.decode()
|
||||
).stdout.decode("utf-8")
|
||||
os.remove(tmp_filename)
|
||||
|
||||
if diff:
|
||||
|
@ -164,11 +164,27 @@ def check_openapi(check: bool):
|
|||
"OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it"
|
||||
)
|
||||
|
||||
return True
|
||||
else:
|
||||
os.rename(tmp_filename, filename)
|
||||
print("OpenAPI documentation updated.")
|
||||
return True
|
||||
errors = subprocess.run(
|
||||
[
|
||||
"swagger-cli",
|
||||
# allow for trailing whitespace since it's not significant
|
||||
# and the precommit hook will remove it
|
||||
"validate",
|
||||
filename,
|
||||
],
|
||||
capture_output=True,
|
||||
).stderr.decode("utf-8")
|
||||
# The openapi spec fails on `exclusive_minimum`, which is expected to be a boolean where
|
||||
# utoipa outputs a value instead: https://github.com/juhaku/utoipa/issues/969
|
||||
if not errors.startswith("Swagger schema validation failed."):
|
||||
print(errors)
|
||||
raise Exception(
|
||||
f"OpenAPI documentation is invalid, `swagger-cli validate` showed some error:\n {errors}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
|
|