Merge branch 'main' into ci_amd3

This commit is contained in:
fxmarty 2024-07-08 13:06:39 +02:00
commit 8c590be463
84 changed files with 2061 additions and 3022 deletions

View File

@ -11,10 +11,30 @@ jobs:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Rust
uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
- name: Install Protocol Buffers compiler
run: |
sudo apt-get update
sudo apt-get install -y protobuf-compiler libprotobuf-dev
- name: Install Launcher
id: install-launcher
run: cargo install --path launcher/
- name: Check launcher Docs are up-to-date
- name: Install router
id: install-router
run: cargo install --path router/
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.x'
- name: Check that documentation is up-to-date
run: |
echo text-generation-launcher --help
python update_doc.py --check

View File

@ -11,6 +11,11 @@ on:
# - rocm
# - xpu
required: true
release-tests:
description: "Run release integration tests"
required: true
default: false
type: boolean
jobs:
build-and-push:
@ -195,7 +200,7 @@ jobs:
runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
env:
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main') && '--release' || '' }}
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }}
steps:
- name: Checkout repository
uses: actions/checkout@v4

View File

@ -20,7 +20,14 @@ on:
- "Dockerfile_amd"
- "Dockerfile_intel"
branches:
- 'main'
- "main"
workflow_dispatch:
inputs:
release-tests:
description: "Run release integration tests"
required: true
default: false
type: boolean
jobs:
build:
@ -33,4 +40,6 @@ jobs:
uses: ./.github/workflows/build.yaml # calls the one above ^
with:
hardware: ${{ matrix.hardware }}
# https://github.com/actions/runner/issues/2206
release-tests: ${{ inputs.release-tests == true }}
secrets: inherit

View File

@ -9,7 +9,7 @@ members = [
resolver = "2"
[workspace.package]
version = "2.1.1-dev0"
version = "2.1.2-dev0"
edition = "2021"
authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference"

View File

@ -4,7 +4,7 @@ WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
FROM chef as planner
FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
@ -38,7 +38,7 @@ RUN cargo build --profile release-opt
# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 as pytorch-install
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install
ARG PYTORCH_VERSION=2.3.0
ARG PYTHON_VERSION=3.10
@ -81,7 +81,7 @@ RUN case ${TARGETPLATFORM} in \
/opt/conda/bin/conda clean -ya
# CUDA kernels builder image
FROM pytorch-install as kernel-builder
FROM pytorch-install AS kernel-builder
ARG MAX_JOBS=8
@ -90,7 +90,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
&& rm -rf /var/lib/apt/lists/*
# Build Flash Attention CUDA kernels
FROM kernel-builder as flash-att-builder
FROM kernel-builder AS flash-att-builder
WORKDIR /usr/src
@ -100,7 +100,7 @@ COPY server/Makefile-flash-att Makefile
RUN make build-flash-attention
# Build Flash Attention v2 CUDA kernels
FROM kernel-builder as flash-att-v2-builder
FROM kernel-builder AS flash-att-v2-builder
WORKDIR /usr/src
@ -110,14 +110,14 @@ COPY server/Makefile-flash-att-v2 Makefile
RUN make build-flash-attention-v2-cuda
# Build Transformers exllama kernels
FROM kernel-builder as exllama-kernels-builder
FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
# Build Transformers exllamav2 kernels
FROM kernel-builder as exllamav2-kernels-builder
FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .
@ -125,42 +125,42 @@ COPY server/exllamav2_kernels/ .
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
# Build Transformers awq kernels
FROM kernel-builder as awq-kernels-builder
FROM kernel-builder AS awq-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-awq Makefile
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq
# Build eetq kernels
FROM kernel-builder as eetq-kernels-builder
FROM kernel-builder AS eetq-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-eetq Makefile
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
# Build marlin kernels
FROM kernel-builder as marlin-kernels-builder
FROM kernel-builder AS marlin-kernels-builder
WORKDIR /usr/src
COPY server/marlin/ .
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
# Build Lorax Punica kernels
FROM kernel-builder as lorax-punica-builder
FROM kernel-builder AS lorax-punica-builder
WORKDIR /usr/src
COPY server/Makefile-lorax-punica Makefile
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
# Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
# Build specific version of transformers
RUN python setup.py build
# Build vllm CUDA kernels
FROM kernel-builder as vllm-builder
FROM kernel-builder AS vllm-builder
WORKDIR /usr/src
@ -172,13 +172,13 @@ COPY server/Makefile-vllm Makefile
RUN make build-vllm-cuda
# Build mamba kernels
FROM kernel-builder as mamba-builder
FROM kernel-builder AS mamba-builder
WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile
RUN make build-all
# Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base
# Conda env
ENV PATH=/opt/conda/bin:$PATH \
@ -260,7 +260,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
# AWS Sagemaker compatible image
FROM base as sagemaker
FROM base AS sagemaker
COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

View File

@ -4,7 +4,7 @@ WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
FROM chef as planner
FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
@ -37,7 +37,7 @@ COPY launcher launcher
RUN cargo build --profile release-opt
# Text Generation Inference base image for ROCm
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
@ -118,7 +118,7 @@ ARG BUILD_CAFFE2="0" \
RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
@ -150,26 +150,26 @@ COPY server/Makefile-flash-att-v2 Makefile
RUN make build-flash-attention-v2-rocm
# Build Transformers CUDA kernels (gpt-neox and bloom)
FROM kernel-builder as custom-kernels-builder
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN python setup.py build
# Build exllama kernels
FROM kernel-builder as exllama-kernels-builder
FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .
RUN python setup.py build
# Build exllama v2 kernels
FROM kernel-builder as exllamav2-kernels-builder
FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .
RUN python setup.py build
FROM base as base-copy
FROM base AS base-copy
# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
@ -208,7 +208,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
# AWS Sagemaker compatible image
FROM base as sagemaker
FROM base AS sagemaker
COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

View File

@ -5,7 +5,7 @@ WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
FROM chef as planner
FROM chef AS planner
COPY Cargo.lock Cargo.lock
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
@ -40,7 +40,7 @@ RUN cargo build --profile release-opt
# Text Generation Inference base image for Intel
FROM intel/intel-extension-for-pytorch:2.1.30-xpu as xpu
FROM intel/intel-extension-for-pytorch:2.1.30-xpu AS xpu
USER root
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
@ -95,7 +95,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
# Text Generation Inference base image for Intel-cpu
FROM ubuntu:22.04 as cpu
FROM ubuntu:22.04 AS cpu
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
curl \
@ -172,6 +172,6 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
# Install launcher
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
FROM ${PLATFORM} as final
FROM ${PLATFORM} AS final
ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]

View File

@ -79,7 +79,7 @@ model=HuggingFaceH4/zephyr-7b-beta
volume=$PWD/data
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.1.0 --model-id $model
ghcr.io/huggingface/text-generation-inference:2.1.1 --model-id $model
```
And then you can make requests like
@ -93,7 +93,7 @@ curl 127.0.0.1:8080/generate_stream \
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`. Please note that CPU is not the intended platform for this project, so performance might be subpar.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.1.0-rocm --model-id $model` instead of the command above.
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.1.1-rocm --model-id $model` instead of the command above.
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
```

View File

@ -10,7 +10,7 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0"
},
"version": "2.0.1"
"version": "2.1.2-dev0"
},
"paths": {
"/": {
@ -19,7 +19,6 @@
"Text Generation Inference"
],
"summary": "Generate tokens if `stream == false` or a stream of tokens if `stream == true`",
"description": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
"operationId": "compat_generate",
"requestBody": {
"content": {
@ -108,7 +107,6 @@
"Text Generation Inference"
],
"summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "generate",
"requestBody": {
"content": {
@ -192,7 +190,6 @@
"Text Generation Inference"
],
"summary": "Generate a stream of tokens using Server-Sent Events",
"description": "Generate a stream of token using Server-Sent Events",
"operationId": "generate_stream",
"requestBody": {
"content": {
@ -276,7 +273,6 @@
"Text Generation Inference"
],
"summary": "Health check method",
"description": "Health check method",
"operationId": "health",
"responses": {
"200": {
@ -305,7 +301,6 @@
"Text Generation Inference"
],
"summary": "Text Generation Inference endpoint info",
"description": "Text Generation Inference endpoint info",
"operationId": "get_model_info",
"responses": {
"200": {
@ -327,7 +322,6 @@
"Text Generation Inference"
],
"summary": "Prometheus metrics scrape endpoint",
"description": "Prometheus metrics scrape endpoint",
"operationId": "metrics",
"responses": {
"200": {
@ -349,7 +343,6 @@
"Text Generation Inference"
],
"summary": "Tokenize inputs",
"description": "Tokenize inputs",
"operationId": "tokenize",
"requestBody": {
"content": {
@ -394,7 +387,6 @@
"Text Generation Inference"
],
"summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "chat_completions",
"requestBody": {
"content": {
@ -483,7 +475,6 @@
"Text Generation Inference"
],
"summary": "Generate tokens",
"description": "Generate tokens",
"operationId": "completions",
"requestBody": {
"content": {
@ -626,7 +617,6 @@
"type": "object",
"required": [
"id",
"object",
"created",
"model",
"system_fingerprint",
@ -653,9 +643,6 @@
"type": "string",
"example": "mistralai/Mistral-7B-Instruct-v0.2"
},
"object": {
"type": "string"
},
"system_fingerprint": {
"type": "string"
},
@ -697,7 +684,6 @@
"type": "object",
"required": [
"id",
"object",
"created",
"model",
"system_fingerprint",
@ -723,9 +709,6 @@
"type": "string",
"example": "mistralai/Mistral-7B-Instruct-v0.2"
},
"object": {
"type": "string"
},
"system_fingerprint": {
"type": "string"
}
@ -756,34 +739,19 @@
"nullable": true
},
"message": {
"$ref": "#/components/schemas/Message"
"$ref": "#/components/schemas/OutputMessage"
}
}
},
"ChatCompletionDelta": {
"type": "object",
"required": [
"role"
],
"properties": {
"content": {
"type": "string",
"example": "What is Deep Learning?",
"nullable": true
"oneOf": [
{
"$ref": "#/components/schemas/TextMessage"
},
"role": {
"type": "string",
"example": "user"
},
"tool_calls": {
"allOf": [
{
"$ref": "#/components/schemas/DeltaToolCall"
}
],
"nullable": true
{
"$ref": "#/components/schemas/ToolCallDelta"
}
}
]
},
"ChatCompletionLogprob": {
"type": "object",
@ -903,6 +871,15 @@
"example": 0.1,
"nullable": true
},
"response_format": {
"allOf": [
{
"$ref": "#/components/schemas/GrammarType"
}
],
"default": "null",
"nullable": true
},
"seed": {
"type": "integer",
"format": "int64",
@ -969,6 +946,38 @@
}
}
},
"Chunk": {
"type": "object",
"required": [
"id",
"created",
"choices",
"model",
"system_fingerprint"
],
"properties": {
"choices": {
"type": "array",
"items": {
"$ref": "#/components/schemas/CompletionComplete"
}
},
"created": {
"type": "integer",
"format": "int64",
"minimum": 0
},
"id": {
"type": "string"
},
"model": {
"type": "string"
},
"system_fingerprint": {
"type": "string"
}
}
},
"CompatGenerateRequest": {
"type": "object",
"required": [
@ -988,6 +997,55 @@
}
}
},
"Completion": {
"oneOf": [
{
"allOf": [
{
"$ref": "#/components/schemas/Chunk"
},
{
"type": "object",
"required": [
"object"
],
"properties": {
"object": {
"type": "string",
"enum": [
"text_completion"
]
}
}
}
]
},
{
"allOf": [
{
"$ref": "#/components/schemas/CompletionFinal"
},
{
"type": "object",
"required": [
"object"
],
"properties": {
"object": {
"type": "string",
"enum": [
"text_completion"
]
}
}
}
]
}
],
"discriminator": {
"propertyName": "object"
}
},
"CompletionComplete": {
"type": "object",
"required": [
@ -1017,15 +1075,15 @@
}
}
},
"CompletionCompleteChunk": {
"CompletionFinal": {
"type": "object",
"required": [
"id",
"object",
"created",
"choices",
"model",
"system_fingerprint"
"system_fingerprint",
"choices",
"usage"
],
"properties": {
"choices": {
@ -1037,19 +1095,21 @@
"created": {
"type": "integer",
"format": "int64",
"example": "1706270835",
"minimum": 0
},
"id": {
"type": "string"
},
"model": {
"type": "string"
},
"object": {
"type": "string"
"type": "string",
"example": "mistralai/Mistral-7B-Instruct-v0.2"
},
"system_fingerprint": {
"type": "string"
},
"usage": {
"$ref": "#/components/schemas/Usage"
}
}
},
@ -1081,12 +1141,7 @@
"example": "mistralai/Mistral-7B-Instruct-v0.2"
},
"prompt": {
"type": "array",
"items": {
"type": "string"
},
"description": "The prompt to generate completions for.",
"example": "What is Deep Learning?"
"$ref": "#/components/schemas/Prompt"
},
"repetition_penalty": {
"type": "number",
@ -1100,6 +1155,15 @@
"nullable": true,
"minimum": 0
},
"stop": {
"type": "array",
"items": {
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens.",
"example": "null",
"nullable": true
},
"stream": {
"type": "boolean"
},
@ -1121,15 +1185,6 @@
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
"example": 0.95,
"nullable": true
},
"stop": {
"type": "array",
"items": {
"type": "string"
},
"description": "Up to 4 sequences where the API will stop generating further tokens.",
"example": "null",
"nullable": true
}
}
},
@ -1272,8 +1327,16 @@
"GenerateParameters": {
"type": "object",
"properties": {
"adapter_id": {
"type": "string",
"description": "Lora adapter id",
"default": "null",
"example": "null",
"nullable": true
},
"best_of": {
"type": "integer",
"description": "Generate best_of sequences and return the one with the highest token logprobs.",
"default": "null",
"example": 1,
"nullable": true,
@ -1282,20 +1345,24 @@
},
"decoder_input_details": {
"type": "boolean",
"description": "Whether to return decoder input token logprobs and ids.",
"default": "false"
},
"details": {
"type": "boolean",
"description": "Whether to return generation details.",
"default": "true"
},
"do_sample": {
"type": "boolean",
"description": "Activate logits sampling.",
"default": "false",
"example": true
},
"frequency_penalty": {
"type": "number",
"format": "float",
"description": "The parameter for frequency penalty. 1.0 means no penalty\nPenalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.",
"default": "null",
"example": 0.1,
"nullable": true,
@ -1313,6 +1380,7 @@
"max_new_tokens": {
"type": "integer",
"format": "int32",
"description": "Maximum number of tokens to generate.",
"default": "100",
"example": "20",
"nullable": true,
@ -1321,6 +1389,7 @@
"repetition_penalty": {
"type": "number",
"format": "float",
"description": "The parameter for repetition penalty. 1.0 means no penalty.\nSee [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.",
"default": "null",
"example": 1.03,
"nullable": true,
@ -1328,6 +1397,7 @@
},
"return_full_text": {
"type": "boolean",
"description": "Whether to prepend the prompt to the generated text",
"default": "null",
"example": false,
"nullable": true
@ -1335,6 +1405,7 @@
"seed": {
"type": "integer",
"format": "int64",
"description": "Random sampling seed.",
"default": "null",
"example": "null",
"nullable": true,
@ -1346,6 +1417,7 @@
"items": {
"type": "string"
},
"description": "Stop generating tokens if a member of `stop` is generated.",
"example": [
"photographer"
],
@ -1354,6 +1426,7 @@
"temperature": {
"type": "number",
"format": "float",
"description": "The value used to modulate the logits distribution.",
"default": "null",
"example": 0.5,
"nullable": true,
@ -1362,6 +1435,7 @@
"top_k": {
"type": "integer",
"format": "int32",
"description": "The number of highest probability vocabulary tokens to keep for top-k-filtering.",
"default": "null",
"example": 10,
"nullable": true,
@ -1370,6 +1444,7 @@
"top_n_tokens": {
"type": "integer",
"format": "int32",
"description": "The number of highest probability vocabulary tokens to keep for top-n-filtering.",
"default": "null",
"example": 5,
"nullable": true,
@ -1379,6 +1454,7 @@
"top_p": {
"type": "number",
"format": "float",
"description": "Top-p value for nucleus sampling.",
"default": "null",
"example": 0.95,
"nullable": true,
@ -1387,6 +1463,7 @@
},
"truncate": {
"type": "integer",
"description": "Truncate inputs tokens to the given size.",
"default": "null",
"example": "null",
"nullable": true,
@ -1395,6 +1472,7 @@
"typical_p": {
"type": "number",
"format": "float",
"description": "Typical Decoding mass\nSee [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.",
"default": "null",
"example": 0.95,
"nullable": true,
@ -1403,6 +1481,7 @@
},
"watermark": {
"type": "boolean",
"description": "Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).",
"default": "false",
"example": true
}
@ -1495,13 +1574,14 @@
"max_concurrent_requests",
"max_best_of",
"max_stop_sequences",
"max_input_length",
"max_input_tokens",
"max_total_tokens",
"waiting_served_ratio",
"max_batch_total_tokens",
"max_waiting_tokens",
"validation_workers",
"max_client_batch_size",
"router",
"version"
],
"properties": {
@ -1538,7 +1618,7 @@
"example": "128",
"minimum": 0
},
"max_input_length": {
"max_input_tokens": {
"type": "integer",
"example": "1024",
"minimum": 0
@ -1581,6 +1661,11 @@
"example": "e985a63cdc139290c5f700ff1929f0b5942cced2",
"nullable": true
},
"router": {
"type": "string",
"description": "Router Info",
"example": "text-generation-router"
},
"sha": {
"type": "string",
"example": "null",
@ -1593,7 +1678,6 @@
},
"version": {
"type": "string",
"description": "Router Info",
"example": "0.5.0"
},
"waiting_served_ratio": {
@ -1606,13 +1690,12 @@
"Message": {
"type": "object",
"required": [
"role"
"role",
"content"
],
"properties": {
"content": {
"type": "string",
"example": "My name is David and I",
"nullable": true
"$ref": "#/components/schemas/MessageContent"
},
"name": {
"type": "string",
@ -1622,13 +1705,6 @@
"role": {
"type": "string",
"example": "user"
},
"tool_calls": {
"type": "array",
"items": {
"$ref": "#/components/schemas/ToolCall"
},
"nullable": true
}
}
},
@ -1658,6 +1734,12 @@
}
}
},
"Prompt": {
"type": "array",
"items": {
"type": "string"
}
},
"SimpleToken": {
"type": "object",
"required": [
@ -1817,9 +1899,7 @@
"$ref": "#/components/schemas/FunctionDefinition"
},
"id": {
"type": "integer",
"format": "int32",
"minimum": 0
"type": "string"
},
"type": {
"type": "string"
@ -1830,20 +1910,22 @@
"oneOf": [
{
"type": "object",
"required": [
"FunctionName"
],
"properties": {
"FunctionName": {
"type": "string"
}
}
"default": null,
"nullable": true
},
{
"type": "string",
"enum": [
"OneOf"
]
"type": "string"
},
{
"type": "object",
"required": [
"function"
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionName"
}
}
}
]
},

View File

@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--device=/dev/kfd --device=/dev/dri --group-add video \
--ipc=host --shm-size 256g --net host -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.1.0-rocm \
ghcr.io/huggingface/text-generation-inference:2.1.1-rocm \
--model-id $model
```

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.1.0 \
ghcr.io/huggingface/text-generation-inference:2.1.1 \
--model-id $model
```

View File

@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
ghcr.io/huggingface/text-generation-inference:2.1.0 \
ghcr.io/huggingface/text-generation-inference:2.1.1 \
--model-id $model
```
@ -88,7 +88,7 @@ curl 127.0.0.1:8080/generate \
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
```bash
docker run ghcr.io/huggingface/text-generation-inference:2.1.0 --help
docker run ghcr.io/huggingface/text-generation-inference:2.1.1 --help
```
</Tip>

View File

@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware
- [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
- [Gemma](https://huggingface.co/google/gemma-7b)
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
- [Gemma2](https://huggingface.co/google/gemma-2-9b)
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct)

View File

@ -5,85 +5,80 @@
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"id": 2323,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.7890625,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.625,
"text": "request"
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3359375,
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8779297,
"id": 262,
"logprob": -1.6230469,
"special": false,
"text": "Test"
"text": " "
},
{
"id": 2009,
"logprob": -1.2744141,
"id": 3270,
"logprob": -2.046875,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1425781,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.9238281,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.6933594,
"id": 13204,
"logprob": -0.076660156,
"special": false,
"text": "\n"
"text": ".method"
},
{
"id": 3057,
"logprob": -1.4648438,
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": "Test"
"text": " =="
},
{
"id": 2009,
"logprob": -0.15600586,
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " request"
"text": " '"
},
{
"id": 13,
"logprob": -0.8027344,
"id": 3019,
"logprob": -0.10821533,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.23022461,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.0069885254,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.02218628,
"special": false,
"text": "\n"
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
"generated_text": "\n \"\"\"\n if request.method == 'POST"
}

View File

@ -5,85 +5,80 @@
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"id": 2323,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.84375,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6015625,
"text": "request"
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 29899,
"logprob": -1.5625,
"id": 13,
"logprob": -2.2539062,
"special": false,
"text": "-"
"text": "."
},
{
"id": 1454,
"logprob": -0.20410156,
"id": 578,
"logprob": -0.15563965,
"special": false,
"text": "for"
"text": " The"
},
{
"id": 29899,
"id": 3622,
"logprob": -0.8203125,
"special": false,
"text": " server"
},
{
"id": 706,
"logprob": 0.0,
"special": false,
"text": "-"
"text": " has"
},
{
"id": 9342,
"id": 539,
"logprob": 0.0,
"special": false,
"text": "comment"
"text": " not"
},
{
"id": 29901,
"id": 3686,
"logprob": 0.0,
"special": false,
"text": ":"
"text": " yet"
},
{
"id": 396,
"logprob": -0.27685547,
"special": false,
"text": " #"
},
{
"id": 29906,
"logprob": -0.4970703,
"special": false,
"text": "2"
},
{
"id": 29900,
"logprob": -0.80615234,
"special": false,
"text": "0"
},
{
"id": 29896,
"id": 3288,
"logprob": 0.0,
"special": false,
"text": "1"
"text": " sent"
},
{
"id": 29955,
"logprob": -1.0751953,
"id": 904,
"logprob": 0.0,
"special": false,
"text": "7"
"text": " any"
},
{
"id": 828,
"logprob": 0.0,
"special": false,
"text": " data"
},
{
"id": 382,
"logprob": -1.5517578,
"special": false,
"text": ".\n\n"
}
],
"top_tokens": null
},
"generated_text": "Test request-for-comment: #2017"
"generated_text": "Test request. The server has not yet sent any data.\n\n"
}

View File

@ -6,87 +6,82 @@
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"id": 2323,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.828125,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.609375,
"text": "request"
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3300781,
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8740234,
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": "Test"
"text": " "
},
{
"id": 2009,
"logprob": -1.2646484,
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.7158203,
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": "\n"
"text": ".method"
},
{
"id": 3057,
"logprob": -1.4667969,
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": "Test"
"text": " =="
},
{
"id": 2009,
"logprob": -0.15344238,
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " request"
"text": " '"
},
{
"id": 13,
"logprob": -0.81591797,
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22973633,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007045746,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021957397,
"special": false,
"text": "\n"
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
@ -95,87 +90,82 @@
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"id": 2323,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.84375,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.59375,
"text": "request"
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3378906,
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8779297,
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": "Test"
"text": " "
},
{
"id": 2009,
"logprob": -1.2636719,
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.6992188,
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": "\n"
"text": ".method"
},
{
"id": 3057,
"logprob": -1.4589844,
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": "Test"
"text": " =="
},
{
"id": 2009,
"logprob": -0.15344238,
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " request"
"text": " '"
},
{
"id": 13,
"logprob": -0.79052734,
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22937012,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007041931,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.022140503,
"special": false,
"text": "\n"
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
@ -184,87 +174,82 @@
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"id": 2323,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.84375,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.609375,
"text": "request"
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3261719,
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8730469,
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": "Test"
"text": " "
},
{
"id": 2009,
"logprob": -1.2587891,
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.6894531,
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": "\n"
"text": ".method"
},
{
"id": 3057,
"logprob": -1.46875,
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": "Test"
"text": " =="
},
{
"id": 2009,
"logprob": -0.1541748,
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " request"
"text": " '"
},
{
"id": 13,
"logprob": -0.80322266,
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22912598,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.0070495605,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021606445,
"special": false,
"text": "\n"
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
"generated_text": "\n \"\"\"\n if request.method == 'POST"
},
{
"details": {
@ -273,86 +258,81 @@
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"id": 2323,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.84375,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6015625,
"text": "request"
"id": 1715,
"logprob": -11.34375,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3320312,
"id": 198,
"logprob": -2.5742188,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.875,
"id": 262,
"logprob": -1.6220703,
"special": false,
"text": "Test"
"text": " "
},
{
"id": 2009,
"logprob": -1.2646484,
"id": 3270,
"logprob": -2.0410156,
"special": false,
"text": " \"\"\"\n"
},
{
"id": 262,
"logprob": -0.015281677,
"special": false,
"text": " "
},
{
"id": 422,
"logprob": -2.1445312,
"special": false,
"text": " if"
},
{
"id": 1715,
"logprob": -0.92333984,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.6884766,
"id": 13204,
"logprob": -0.07672119,
"special": false,
"text": "\n"
"text": ".method"
},
{
"id": 3057,
"logprob": -1.4589844,
"id": 624,
"logprob": -0.021987915,
"special": false,
"text": "Test"
"text": " =="
},
{
"id": 2009,
"logprob": -0.15185547,
"id": 364,
"logprob": -0.39208984,
"special": false,
"text": " request"
"text": " '"
},
{
"id": 13,
"logprob": -0.79833984,
"id": 3019,
"logprob": -0.10638428,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22827148,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.006996155,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021560669,
"special": false,
"text": "\n"
"text": "POST"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
"generated_text": "\n \"\"\"\n if request.method == 'POST"
}
]

View File

@ -1,130 +1,124 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 20,
"finish_reason": "eos_token",
"generated_tokens": 19,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 415,
"logprob": -0.039886475,
"logprob": -0.03665161,
"special": false,
"text": " The"
},
{
"id": 12072,
"logprob": -0.1430664,
"logprob": -0.13549805,
"special": false,
"text": " cow"
},
{
"id": 349,
"logprob": -0.056488037,
"logprob": -0.05819702,
"special": false,
"text": " is"
},
{
"id": 6328,
"logprob": -0.6855469,
"logprob": -0.6826172,
"special": false,
"text": " standing"
},
{
"id": 356,
"logprob": -0.1685791,
"logprob": -0.1607666,
"special": false,
"text": " on"
},
{
"id": 272,
"logprob": -0.50097656,
"logprob": -0.5073242,
"special": false,
"text": " the"
},
{
"id": 10305,
"logprob": -0.017303467,
"logprob": -0.016418457,
"special": false,
"text": " beach"
},
{
"id": 304,
"logprob": -1.3564453,
"logprob": -1.3916016,
"special": false,
"text": " and"
},
{
"id": 272,
"logprob": -0.017868042,
"logprob": -0.020217896,
"special": false,
"text": " the"
},
{
"id": 13088,
"logprob": -0.0027103424,
"logprob": -0.0028133392,
"special": false,
"text": " chicken"
},
{
"id": 349,
"logprob": -0.003156662,
"logprob": -0.003145218,
"special": false,
"text": " is"
},
{
"id": 6398,
"logprob": -0.37304688,
"logprob": -0.37060547,
"special": false,
"text": " sitting"
},
{
"id": 356,
"logprob": -0.034576416,
"logprob": -0.034851074,
"special": false,
"text": " on"
},
{
"id": 264,
"logprob": -0.29418945,
"logprob": -0.2878418,
"special": false,
"text": " a"
},
{
"id": 17972,
"logprob": -0.042877197,
"logprob": -0.046051025,
"special": false,
"text": " pile"
},
{
"id": 302,
"logprob": -0.00028443336,
"logprob": -0.00028848648,
"special": false,
"text": " of"
},
{
"id": 2445,
"logprob": -0.023223877,
"logprob": -0.025772095,
"special": false,
"text": " money"
},
{
"id": 28723,
"logprob": -0.018157959,
"logprob": -0.018127441,
"special": false,
"text": "."
},
{
"id": 32002,
"logprob": -0.00018393993,
"logprob": -0.00019824505,
"special": true,
"text": "<end_of_utterance>"
},
{
"id": 2,
"logprob": -1.1920929e-07,
"special": true,
"text": "</s>"
}
],
"top_tokens": null

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.8359375,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6171875,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3417969,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8730469,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -1.2626953,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.7060547,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.4482422,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.15246582,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.796875,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22766113,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007045746,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021759033,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
}

View File

@ -0,0 +1,89 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.7890625,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.625,
"text": "request"
}
],
"seed": 0,
"tokens": [
{
"id": 29899,
"logprob": -1.4980469,
"special": false,
"text": "-"
},
{
"id": 1454,
"logprob": -0.19433594,
"special": false,
"text": "for"
},
{
"id": 29899,
"logprob": 0.0,
"special": false,
"text": "-"
},
{
"id": 9342,
"logprob": 0.0,
"special": false,
"text": "comment"
},
{
"id": 29901,
"logprob": 0.0,
"special": false,
"text": ":"
},
{
"id": 396,
"logprob": -0.27392578,
"special": false,
"text": " #"
},
{
"id": 29906,
"logprob": -0.49389648,
"special": false,
"text": "2"
},
{
"id": 29900,
"logprob": -0.81103516,
"special": false,
"text": "0"
},
{
"id": 29896,
"logprob": 0.0,
"special": false,
"text": "1"
},
{
"id": 29955,
"logprob": -1.0800781,
"special": false,
"text": "7"
}
],
"top_tokens": null
},
"generated_text": "Test request-for-comment: #2017"
}

View File

@ -0,0 +1,358 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.8828125,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.5859375,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3359375,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8623047,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -1.2451172,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.6923828,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.4492188,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.15197754,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.8022461,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22583008,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007095337,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021652222,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.796875,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.625,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3476562,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8789062,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -1.2734375,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.703125,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.4677734,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.15454102,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.7973633,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.23278809,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.006980896,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.022033691,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.9296875,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.5703125,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3203125,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8486328,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -1.2480469,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.7060547,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.4511719,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.1529541,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.81396484,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22180176,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.007133484,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021835327,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 1,
"logprob": null,
"text": "<s>"
},
{
"id": 4321,
"logprob": -9.84375,
"text": "Test"
},
{
"id": 2009,
"logprob": -9.6171875,
"text": "request"
}
],
"seed": null,
"tokens": [
{
"id": 13,
"logprob": -2.3261719,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.8691406,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -1.2597656,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -1.7070312,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -1.4550781,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.1538086,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.79345703,
"special": false,
"text": "\n"
},
{
"id": 3057,
"logprob": -0.22924805,
"special": false,
"text": "Test"
},
{
"id": 2009,
"logprob": -0.0070266724,
"special": false,
"text": " request"
},
{
"id": 13,
"logprob": -0.021942139,
"special": false,
"text": "\n"
}
],
"top_tokens": null
},
"generated_text": "\nTest request\nTest request\nTest request\n"
}
]

View File

@ -5,7 +5,9 @@ from testing_utils import is_flaky_async, SYSTEM, require_backend_async
@pytest.fixture(scope="module")
def flash_llama_gptq_handle(launcher):
with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle:
with launcher(
"astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="gptq"
) as handle:
yield handle

View File

@ -62,7 +62,7 @@ async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot)
response.generated_text
== " The cow is standing on the beach and the chicken is sitting on a pile of money."
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 20
assert response.details.generated_tokens == 19
assert response == response_snapshot

View File

@ -433,8 +433,17 @@ pub struct CompletionRequest {
pub stop: Option<Vec<String>>,
}
#[derive(Clone, Serialize, ToSchema)]
#[serde(tag = "object")]
enum Completion {
#[serde(rename = "text_completion")]
Chunk(Chunk),
#[serde(rename = "text_completion")]
Final(CompletionFinal),
}
#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
pub(crate) struct Completion {
pub(crate) struct CompletionFinal {
pub id: String,
#[schema(example = "1706270835")]
pub created: u64,
@ -453,6 +462,15 @@ pub(crate) struct CompletionComplete {
pub finish_reason: String,
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct Chunk {
pub id: String,
pub created: u64,
pub choices: Vec<CompletionComplete>,
pub model: String,
pub system_fingerprint: String,
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletion {
pub id: String,
@ -614,15 +632,6 @@ impl ChatCompletion {
}
}
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct CompletionCompleteChunk {
pub id: String,
pub created: u64,
pub choices: Vec<CompletionComplete>,
pub model: String,
pub system_fingerprint: String,
}
#[derive(Clone, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk {
pub id: String,

View File

@ -1,5 +1,6 @@
use axum::http::HeaderValue;
use clap::Parser;
use clap::Subcommand;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;
@ -27,6 +28,9 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[command(subcommand)]
command: Option<Commands>,
#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]
@ -85,10 +89,15 @@ struct Args {
max_client_batch_size: usize,
}
#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}
#[tokio::main]
async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();
// Pattern match configuration
let Args {
max_concurrent_requests,
@ -119,10 +128,17 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
command,
} = args;
// Launch Tokio runtime
init_logging(otlp_endpoint, otlp_service_name, json_output);
let print_schema_command = match command {
Some(Commands::PrintSchema) => true,
None => {
// only init logging if we are not running the print schema command
init_logging(otlp_endpoint, otlp_service_name, json_output);
false
}
};
// Validate args
if max_input_tokens >= max_total_tokens {
@ -388,6 +404,7 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
print_schema_command,
)
.await?;
Ok(())

View File

@ -19,8 +19,8 @@ use crate::{
use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest,
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
VertexResponse,
};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType};
@ -705,7 +705,7 @@ async fn completions(
.as_secs();
event
.json_data(CompletionCompleteChunk {
.json_data(Completion::Chunk(Chunk {
id: "".to_string(),
created: current_time,
@ -718,7 +718,7 @@ async fn completions(
model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
})
}))
.unwrap_or_else(|_e| Event::default())
};
@ -931,7 +931,7 @@ async fn completions(
.collect::<Result<Vec<_>, _>>()
.map_err(|(status, Json(err))| (status, Json(err)))?;
let response = Completion {
let response = Completion::Final(CompletionFinal {
id: "".to_string(),
created: current_time,
model: info.model_id.clone(),
@ -946,7 +946,7 @@ async fn completions(
completion_tokens,
total_tokens,
},
};
});
// headers similar to `generate` but aggregated
let mut headers = HeaderMap::new();
@ -1387,10 +1387,10 @@ async fn tokenize(
/// Prometheus metrics scrape endpoint
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String))
get,
tag = "Text Generation Inference",
path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render()
@ -1430,6 +1430,7 @@ pub async fn run(
messages_api_enabled: bool,
grammar_support: bool,
max_client_batch_size: usize,
print_schema_command: bool,
) -> Result<(), WebServerError> {
// OpenAPI documentation
#[derive(OpenApi)]
@ -1463,7 +1464,10 @@ pub async fn run(
ChatCompletion,
CompletionRequest,
CompletionComplete,
CompletionCompleteChunk,
Chunk,
Completion,
CompletionFinal,
Prompt,
GenerateParameters,
PrefillToken,
Token,
@ -1500,6 +1504,12 @@ pub async fn run(
struct ApiDoc;
// Create state
if print_schema_command {
let api_doc = ApiDoc::openapi();
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
println!("{}", api_doc);
std::process::exit(0);
}
// Open connection, get model info and warmup
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (

View File

@ -8,6 +8,9 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.utils import weight_hub_files, download_weights
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
@pytest.fixture(scope="session")
@ -16,7 +19,10 @@ def default_bloom():
revision = "main"
filenames = weight_hub_files(model_id, revision, ".safetensors")
download_weights(filenames, model_id, revision)
return BLOOMSharded(model_id)
return BLOOMSharded(
model_id,
model_class=BloomForCausalLM,
)
@pytest.fixture(scope="session")

View File

@ -10,7 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
@pytest.fixture(scope="session")
def default_causal_lm():
return CausalLM("gpt2")
return CausalLM.fallback("gpt2")
@pytest.fixture(scope="session")

View File

@ -1,13 +1,12 @@
import pytest
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM
@pytest.fixture(scope="session")
def default_santacoder():
return SantaCoder("bigcode/santacoder")
return CausalLM.fallback(model_id="bigcode/santacoder")
@pytest.fixture

View File

@ -20,7 +20,7 @@ def mt0_small_tokenizer():
@pytest.fixture(scope="session")
def default_seq2seq_lm():
return Seq2SeqLM("bigscience/mt0-small")
return Seq2SeqLM.fallback("bigscience/mt0-small")
@pytest.fixture

View File

@ -110,7 +110,7 @@ class PositionRotaryEmbedding(nn.Module):
beta_fast=32,
beta_slow=1,
)
elif rope_scaling["type"] == "su":
elif rope_scaling["type"] in ["su", "longrope"]:
short_factor = torch.tensor(
rope_scaling["short_factor"], dtype=torch.float32, device=device
)

View File

@ -11,17 +11,27 @@ from pathlib import Path
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
)
from text_generation_server.models.bloom import BloomCausalLMBatch
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
PhiConfig,
PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
T5ForConditionalGeneration,
)
from text_generation_server.utils.import_utils import SYSTEM
@ -41,9 +51,6 @@ __all__ = [
"CausalLM",
"GalacticaSharded",
"Seq2SeqLM",
"SantaCoder",
"OPTSharded",
"T5Sharded",
"get_model",
]
@ -53,38 +60,65 @@ FLASH_ATTENTION = True
try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_gpt2 import FlashGPT2
from text_generation_server.models.flash_neox import FlashNeoXSharded
from text_generation_server.models.flash_llama import (
FlashLlama,
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from text_generation_server.models.flash_qwen2 import (
FlashQwen2,
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashCohereForCausalLM,
)
from text_generation_server.models.flash_cohere import (
FlashCohere,
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
)
from text_generation_server.models.flash_gemma import (
FlashGemma,
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
from text_generation_server.models.flash_gemma2 import (
FlashGemma2,
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
FlashDbrxForCausalLM,
DbrxConfig,
)
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
)
from text_generation_server.models.pali_gemma import (
PaliGemma,
PaliGemmaBatch,
)
from text_generation_server.models.flash_santacoder import (
FlashSantacoderSharded,
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
)
from text_generation_server.models.idefics import IDEFICSSharded
from text_generation_server.models.llava_next import LlavaNext
from text_generation_server.models.idefics2 import Idefics2
from text_generation_server.models.flash_mistral import FlashMistral
from text_generation_server.models.flash_mixtral import FlashMixtral
from text_generation_server.models.flash_phi import FlashPhi
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
from text_generation_server.models.flash_dbrx import FlashDbrx
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
FlashStarcoder2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
FlashMixtralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}")
@ -93,21 +127,7 @@ except ImportError as e:
if FLASH_ATTENTION:
__all__.append(FlashCausalLM)
__all__.append(FlashGPT2)
__all__.append(FlashNeoXSharded)
__all__.append(FlashRWSharded)
__all__.append(FlashSantacoderSharded)
__all__.append(FlashLlama)
__all__.append(IDEFICSSharded)
__all__.append(FlashMistral)
__all__.append(FlashMixtral)
__all__.append(FlashDbrx)
__all__.append(FlashPhi)
__all__.append(FlashQwen2)
__all__.append(FlashStarcoder2)
__all__.append(FlashGemma)
__all__.append(FlashGemma2)
__all__.append(FlashCohere)
MAMBA_AVAILABLE = True
MAMBA_IMPORT_ERROR = None
@ -150,6 +170,11 @@ class ModelType(enum.Enum):
"name": "Gemma",
"url": "https://huggingface.co/google/gemma-7b",
}
PALIGEMMA = {
"type": "paligemma",
"name": "PaliGemma",
"url": "https://huggingface.co/google/paligemma-3b-pt-224",
}
GEMMA2 = {
"type": "gemma2",
"name": "Gemma2",
@ -452,13 +477,16 @@ def get_model(
)
if model_id.startswith("facebook/galactica"):
return GalacticaSharded(
model_id,
revision,
return CausalLM(
model_id=model_id,
# Yes galactica is just an OPT model.
model_class=OPTForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
batch_class=GalacticaCausalLMBatch,
)
if (
@ -467,22 +495,26 @@ def get_model(
and model_id.startswith("bigcode/")
):
if FLASH_ATTENTION:
return FlashSantacoderSharded(
model_id,
revision,
return FlashCausalLM(
model_id=model_id,
model_class=FlashSantacoderForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
num_kv_heads=1,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else:
return SantaCoder(
model_id,
revision,
return CausalLM.fallback(
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
@ -490,38 +522,44 @@ def get_model(
)
if model_type == BLOOM:
return BLOOMSharded(
model_id,
revision,
return CausalLM(
model_id=model_id,
model_class=BloomForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
batch_class=BloomCausalLMBatch,
)
elif model_type == MPT:
return MPTSharded(
model_id,
revision,
return CausalLM(
model_id=model_id,
model_class=MPTForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
batch_class=CausalLMBatchKeysLast,
)
elif model_type == GPT2:
if FLASH_ATTENTION:
try:
return FlashGPT2(
model_id,
revision,
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPT2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
except RuntimeError as e:
# Lots of legacy models with various weight names.
logger.warning(f"Couldn't load flash gpt2 variant: {e}")
return CausalLM(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -532,7 +570,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
else:
return CausalLM(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -542,25 +580,28 @@ def get_model(
)
elif model_type == GPT_NEOX:
if FLASH_ATTENTION:
return FlashNeoXSharded(
model_id,
revision,
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPTNeoXForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
return GPTNeoxSharded(
model_id,
revision,
return CausalLM(
model_id=model_id,
model_class=GPTNeoxForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
return CausalLM(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -571,16 +612,18 @@ def get_model(
elif model_type == PHI:
if FLASH_ATTENTION:
return FlashPhi(
model_id,
revision,
return FlashCausalLM(
model_id=model_id,
model_class=FlashPhiForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
else:
return CausalLM(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -595,9 +638,11 @@ def get_model(
"Legacy phi-msft is not supported with Flash Attention"
)
else:
return Phi(
model_id,
revision,
return CausalLM(
model_id=model_id,
model_class=PhiForCausalLM,
config_class=PhiConfig,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
@ -606,9 +651,10 @@ def get_model(
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
if FLASH_ATTENTION:
return FlashLlama(
model_id,
revision,
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
@ -618,7 +664,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
else:
return CausalLM(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -628,18 +674,22 @@ def get_model(
)
if model_type == GEMMA:
if FLASH_ATTENTION:
return FlashGemma(
model_id,
revision,
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemmaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
else:
return CausalLM(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -649,18 +699,22 @@ def get_model(
)
elif model_type == GEMMA2:
if FLASH_ATTENTION:
return FlashGemma2(
model_id,
revision,
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemma2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
else:
return CausalLM(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -671,18 +725,20 @@ def get_model(
if model_type == COHERE:
if FLASH_ATTENTION:
return FlashCohere(
model_id,
revision,
return FlashCausalLM(
model_id=model_id,
model_class=FlashCohereForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
else:
return CausalLM(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -693,18 +749,23 @@ def get_model(
if model_type == DBRX:
if FLASH_ATTENTION:
return FlashDbrx(
model_id,
revision,
return FlashCausalLM(
model_id=model_id,
model_class=FlashDbrxForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
# Dbrx works better in bfloat16.
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DbrxConfig,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
else:
return CausalLM(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -718,27 +779,37 @@ def get_model(
if FLASH_ATTENTION:
if config_dict.get("alibi", False):
raise NotImplementedError("sharded is not supported for this model")
return FlashRWSharded(
model_id,
revision,
return FlashCausalLM(
model_id=model_id,
model_class=FlashRWForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
)
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
else:
if FLASH_ATTENTION and not config_dict.get("alibi", False):
return FlashRWSharded(
model_id,
revision,
return FlashCausalLM(
model_id=model_id,
model_class=FlashRWForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
)
else:
return RW(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -749,18 +820,20 @@ def get_model(
if model_type == MISTRAL:
if FLASH_ATTENTION:
return FlashMistral(
model_id,
revision,
return FlashCausalLM(
model_id=model_id,
model_class=FlashMistralForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
else:
return CausalLM(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -771,18 +844,20 @@ def get_model(
if model_type == MIXTRAL:
if FLASH_ATTENTION:
return FlashMixtral(
model_id,
revision,
return FlashCausalLM(
model_id=model_id,
model_class=FlashMixtralForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
else:
return CausalLM(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -793,19 +868,22 @@ def get_model(
if model_type == STARCODER2:
if FLASH_ATTENTION:
return FlashStarcoder2(
model_id,
revision,
return FlashCausalLM(
model_id=model_id,
model_class=FlashStarcoder2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
)
else:
return CausalLM(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -816,17 +894,20 @@ def get_model(
if model_type == QWEN2:
if FLASH_ATTENTION:
return FlashQwen2(
model_id,
revision,
return FlashCausalLM(
model_id=model_id,
model_class=Qwen2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
else:
return CausalLM(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -836,9 +917,10 @@ def get_model(
)
if model_type == OPT:
return OPTSharded(
model_id,
revision,
return CausalLM(
model_id=model_id,
model_class=OPTForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
@ -846,13 +928,20 @@ def get_model(
)
if model_type == T5:
return T5Sharded(
model_id,
revision,
return Seq2SeqLM(
model_id=model_id,
model_class=T5ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
aliases={
"shared.weight": [
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
},
)
if model_type == IDEFICS:
if FLASH_ATTENTION:
@ -868,34 +957,45 @@ def get_model(
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == IDEFICS2:
if FLASH_ATTENTION:
return Idefics2(
model_id,
revision,
return VlmCausalLM(
model_id=model_id,
model_class=Idefics2ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "paligemma":
if model_type == PALIGEMMA:
if FLASH_ATTENTION:
return PaliGemma(
model_id,
revision,
return VlmCausalLM(
model_id=model_id,
model_class=PaliGemmaForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
batch_class=PaliGemmaBatch,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == LLAVA_NEXT:
if FLASH_ATTENTION:
return LlavaNext(
model_id,
revision,
return VlmCausalLM(
model_class=LlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
@ -919,7 +1019,7 @@ def get_model(
elif quantize == "exl2":
raise NotImplementedError("exl2 quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -928,7 +1028,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
return Seq2SeqLM(
return Seq2SeqLM.fallback(
model_id,
revision,
quantize=quantize,
@ -940,7 +1040,7 @@ def get_model(
auto_map = config_dict.get("auto_map", None)
if trust_remote_code and auto_map is not None:
if "AutoModelForCausalLM" in auto_map.keys():
return CausalLM(
return CausalLM.fallback(
model_id,
revision,
quantize=quantize,
@ -949,7 +1049,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if "AutoModelForSeq2SeqLM" in auto_map.keys():
return Seq2SeqLM(
return Seq2SeqLM.fallback(
model_id,
revision,
quantize=quantize,

View File

@ -4,22 +4,12 @@ import torch.distributed
from typing import Optional, Type
from transformers import (
AutoTokenizer,
AutoConfig,
PreTrainedTokenizerBase,
)
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class BloomCausalLMBatch(CausalLMBatch):
@ -37,69 +27,6 @@ class BloomCausalLMBatch(CausalLMBatch):
class BLOOMSharded(CausalLM):
# Builds a BLOOM model: loads the tokenizer and config for `model_id`, opens the
# safetensors weights under the "transformer" prefix, instantiates
# `BloomForCausalLM(config, weights)` and passes the result to the parent
# initializer.
#
# Parameters:
#   model_id          -- forwarded to tokenizer/config `from_pretrained` and to `weight_files`.
#   revision          -- forwarded to tokenizer, config and weight-file lookup.
#   quantize          -- stored on `config.quantize`.
#   speculator        -- stored on `config.speculator`.
#   dtype             -- explicit dtype; when None, float16 is used if CUDA is
#                        available and float32 otherwise.
#   trust_remote_code -- forwarded to tokenizer and config loading.
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
# Device and default dtype depend only on CUDA availability here.
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
# Tokenizer pads and truncates on the left.
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
# `slow_but_exact` and `tp_parallel` are passed through as extra config kwargs.
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
slow_but_exact=False,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
# Pad token id is hard-coded.
# NOTE(review): presumably the BLOOM tokenizer's pad id -- confirm.
config.pad_token_id = 3
config.quantize = quantize
config.speculator = speculator
# NOTE(review): presumably synchronizes ranks before weights are read -- confirm.
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
prefix="transformer",
)
# Only these two quantization schemes trigger `_set_gptq_params` here.
if config.quantize in ["gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = BloomForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
# Two-argument super starts the MRO lookup after CausalLM, so
# CausalLM.__init__ itself is not executed by this call.
super(CausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
# Batch class associated with this model: always `BloomCausalLMBatch`.
@property
def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch

View File

@ -1,13 +1,25 @@
import torch
import time
import torch.distributed
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForCausalLM,
PreTrainedTokenizerBase,
)
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.models import Model
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models.types import (
Batch,
@ -478,10 +490,88 @@ class CausalLMBatch(Batch):
return len(self.requests)
# Batch dataclass that declares a single boolean field, `keys_head_dim_last`,
# defaulting to False.
# NOTE(review): the name suggests a key-cache layout flag, but no reader of the
# field is visible here -- confirm against the code that consumes it.
@dataclass
class CausalLMBatchKeysLast(Batch):
keys_head_dim_last: bool = False
class CausalLM(Model):
# Generic constructor: loads the tokenizer, config and safetensors weights for
# `model_id`, instantiates `model_class(prefix, config, weights)` with an empty
# prefix, and passes the result to the parent initializer.
#
# Parameters:
#   model_id          -- forwarded to tokenizer/config `from_pretrained` and to `weight_files`.
#   model_class       -- callable invoked as `model_class(prefix, config, weights)`.
#   revision          -- forwarded to tokenizer, config and weight-file lookup.
#   quantize          -- stored on `config.quantize`.
#   speculator        -- stored on `config.speculator`.
#   dtype             -- explicit dtype; when None a device-dependent default is chosen.
#   default_dtype     -- dtype used on CUDA and XPU when `dtype` is None.
#   trust_remote_code -- forwarded to tokenizer and config loading.
#   tokenizer_class   -- class whose `from_pretrained` builds the tokenizer.
#   config_class      -- class whose `from_pretrained` builds the config.
#   batch_class       -- stored on `self.batch_class`.
def __init__(
self,
model_id: str,
model_class,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
default_dtype=torch.float16,
trust_remote_code: bool = False,
tokenizer_class=AutoTokenizer,
config_class=AutoConfig,
batch_class=CausalLMBatch,
):
self.batch_class = batch_class
self.process_group, rank, world_size = initialize_torch_distributed()
# Device / default-dtype selection: CUDA first; on "ipex" systems XPU when
# available, else CPU with bfloat16; on anything else CPU with float32.
# An explicitly passed `dtype` always wins over these defaults.
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
# Tokenizer pads and truncates on the left.
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = config_class.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
# Quantization and speculator settings are attached to the config object.
config.quantize = quantize
config.speculator = speculator
# When the tokenizer has no pad token id, reuse the one from the config.
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = config.pad_token_id
# NOTE(review): presumably synchronizes ranks before weights are read -- confirm.
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
# These four quantization schemes trigger `_set_gptq_params` on the weights.
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
# The model is always constructed with an empty prefix from this constructor.
prefix = ""
model = model_class(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
# Parent initializer is always told that this model requires padding.
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@classmethod
def fallback(
cls,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
@ -537,7 +627,12 @@ class CausalLM(Model):
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
super(CausalLM, self).__init__(
self = cls.__new__(
cls,
)
self.batch_class = CausalLMBatch
super().__init__(
self,
model_id=model_id,
model=model,
tokenizer=tokenizer,
@ -545,15 +640,11 @@ class CausalLM(Model):
dtype=dtype,
device=device,
)
return self
@property
def batch_type(self) -> Type[CausalLMBatch]:
return CausalLMBatch
def decode(self, generated_ids: List[int]) -> str:
return self.tokenizer.decode(
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return self.batch_class
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None

View File

@ -815,7 +815,7 @@ class BloomModel(BloomPreTrainedModel):
class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.transformer = BloomModel(config, weights)

View File

@ -446,7 +446,7 @@ class CLIPEncoder(nn.Module):
class CLIPTextTransformer(nn.Module):
def __init__(self, config: CLIPTextConfig):
def __init__(self, prefix: str, config: CLIPTextConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
@ -536,9 +536,9 @@ class CLIPTextModel(CLIPPreTrainedModel):
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
def __init__(self, config: CLIPTextConfig):
def __init__(self, prefix, config: CLIPTextConfig):
super().__init__(config)
self.text_model = CLIPTextTransformer(config)
self.text_model = CLIPTextTransformer(prefix, config)
# Initialize weights and apply final processing
self.post_init()

View File

@ -363,9 +363,9 @@ class CohereMLP(nn.Module):
class FlashCohereLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = FlashCohereAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
@ -416,18 +416,19 @@ class FlashCohereLayer(nn.Module):
class FlashCohereModel(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
FlashCohereLayer(
prefix,
layer_id,
config,
weights,
@ -436,7 +437,7 @@ class FlashCohereModel(torch.nn.Module):
]
)
self.norm = FastLayerNorm.load_no_bias(
prefix="model.norm", weights=weights, eps=config.layer_norm_eps
prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps
)
self.gradient_checkpointing = False
@ -486,10 +487,15 @@ class FlashCohereModel(torch.nn.Module):
class FlashCohereForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = FlashCohereModel(config, weights)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = FlashCohereModel(prefix, config, weights)
try:
self.lm_head = SpeculativeHead.load(
config,
@ -499,7 +505,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
except RuntimeError:
self.lm_head = SpeculativeHead.load(
config,
prefix="model.embed_tokens",
prefix=f"{prefix}.embed_tokens",
weights=weights,
)
self.logit_scale = config.logit_scale

View File

@ -593,9 +593,9 @@ class DenseMoE(nn.Module):
class DbrxLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"transformer.blocks.{layer_id}"
prefix = f"{prefix}.blocks.{layer_id}"
self.attn = DbrxNormAttentionNorm(
prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
@ -637,16 +637,17 @@ class DbrxLayer(nn.Module):
class DbrxModel(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix="transformer.wte", weights=weights
prefix=f"{prefix}.wte", weights=weights
)
self.layers = nn.ModuleList(
[
DbrxLayer(
prefix,
layer_id,
config,
weights,
@ -655,7 +656,7 @@ class DbrxModel(torch.nn.Module):
]
)
self.norm = FastLayerNorm.load_no_bias(
prefix="transformer.norm_f", weights=weights, eps=1e-5
prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5
)
self.head_size = self.layers[0].attn.self_attn.head_size
@ -702,10 +703,15 @@ class DbrxModel(torch.nn.Module):
class FlashDbrxForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = DbrxModel(config, weights)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.model = DbrxModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head",

View File

@ -102,7 +102,7 @@ class Gemma2Config(PretrainedConfig):
class Gemma2FastRMSNorm(FastRMSNorm):
@classmethod
def load(cls, prefix, weights, eps=1e-6):
def load(cls, prefix: str, weights, eps=1e-6):
dtype = weights.dtype
weights.dtype = torch.float32
weight = weights.get_tensor(f"{prefix}.weight") + 1
@ -123,7 +123,7 @@ class Gemma2FastRMSNorm(FastRMSNorm):
return hidden_states.to(self.dtype), residual
def load_attention(config, prefix, weights):
def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
@ -305,7 +305,7 @@ class Gemma2MLP(nn.Module):
class FlashGemma2Layer(nn.Module):
def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool):
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
super().__init__()
self.self_attn = FlashGemma2Attention(
prefix=f"{prefix}.self_attn",
@ -376,7 +376,7 @@ class FlashGemma2Layer(nn.Module):
class FlashGemma2Model(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__()
process_group = weights.process_group
@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module):
class FlashGemma2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
def __init__(self, prefix: str, config, weights, *, causal: bool = True):
super().__init__()
embed_norm = config.hidden_size**0.5

View File

@ -102,7 +102,7 @@ class GemmaConfig(PretrainedConfig):
class GemmaFastRMSNorm(FastRMSNorm):
@classmethod
def load(cls, prefix, weights, eps=1e-6):
def load(cls, prefix: str, weights, eps=1e-6):
dtype = weights.dtype
weights.dtype = torch.float32
weight = weights.get_tensor(f"{prefix}.weight") + 1
@ -123,7 +123,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
return hidden_states.to(self.dtype), residual
def load_attention(config, prefix, weights):
def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
@ -261,7 +261,7 @@ class FlashGemmaAttention(torch.nn.Module):
class GemmaMLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
@ -299,7 +299,7 @@ class GemmaMLP(nn.Module):
class FlashGemmaLayer(nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__()
self.self_attn = FlashGemmaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
@ -354,7 +354,7 @@ class FlashGemmaLayer(nn.Module):
class FlashGemmaModel(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
def __init__(self, prefix: str, config, weights, causal: bool):
super().__init__()
process_group = weights.process_group
@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module):
class FlashGemmaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, causal: bool):
def __init__(self, prefix: str, config, weights, *, causal: bool = True):
super().__init__()
embed_norm = config.hidden_size**0.5

View File

@ -261,7 +261,7 @@ class FlashGPT2Attention(torch.nn.Module):
class GPT2MLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
act = config.activation_function
self.act = (
@ -298,7 +298,7 @@ class GPT2MLP(nn.Module):
class FlashGPT2Layer(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.self_attn = FlashGPT2Attention(
prefix=f"{prefix}.attn", config=config, weights=weights
@ -350,7 +350,7 @@ class FlashGPT2Layer(nn.Module):
class FlashGPT2Model(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
process_group = weights.process_group
@ -414,7 +414,7 @@ class FlashGPT2Model(torch.nn.Module):
class FlashGPT2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(

View File

@ -54,7 +54,7 @@ if SYSTEM == "rocm":
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
def load_attention(config, prefix, weights, layer_id):
def load_attention(config, prefix: str, weights, layer_id):
# Only defined in granite.
bias = getattr(config, "attention_bias", False)
head_size = config.hidden_size // config.num_attention_heads
@ -467,7 +467,7 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(

View File

@ -248,7 +248,7 @@ class MistralAttention(torch.nn.Module):
class MistralMLP(nn.Module):
def __init__(self, prefix, config, weights, layer_id):
def __init__(self, prefix: str, config, weights, layer_id):
super().__init__()
self.hidden_act = config.hidden_act
self.act = (
@ -328,7 +328,7 @@ class MistralMLP(nn.Module):
class MistralLayer(nn.Module):
def __init__(self, prefix, config, weights, layer_id):
def __init__(self, prefix: str, config, weights, layer_id):
super().__init__()
self.self_attn = MistralAttention(
prefix=f"{prefix}.self_attn",
@ -392,7 +392,7 @@ class MistralLayer(nn.Module):
class MistralModel(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
process_group = weights.process_group
@ -462,7 +462,7 @@ class MistralModel(torch.nn.Module):
class FlashMistralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, name=None):
def __init__(self, prefix: str, config, weights, name=None):
if name is None:
name = "model"
super().__init__()

View File

@ -116,7 +116,7 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:
return x.view(1) if len(x.size()) == 0 else x
def load_attention(config, prefix, weights):
def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
@ -155,7 +155,7 @@ def _load_gqa(config, prefix: str, weights):
)
def _load_experts(config, prefix, mat, weights):
def _load_experts(config, prefix: str, mat, weights):
if config.quantize is not None:
raise NotImplementedError("Mixtral does not support weight quantization yet.")
@ -475,7 +475,7 @@ class DenseMoE(nn.Module):
class MixtralLayer(nn.Module):
def __init__(self, prefix, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
@ -536,7 +536,7 @@ class MixtralLayer(nn.Module):
class MixtralModel(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
@ -610,7 +610,7 @@ class MixtralModel(torch.nn.Module):
class FlashMixtralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = MixtralModel(prefix, config, weights)

View File

@ -305,12 +305,12 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.config = config
self.embed_in = TensorParallelEmbedding(
prefix="gpt_neox.embed_in", weights=weights
prefix=f"{prefix}.embed_in", weights=weights
)
self.layers = nn.ModuleList(
@ -320,7 +320,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
]
)
self.final_layer_norm = FastLayerNorm.load(
prefix="gpt_neox.final_layer_norm",
prefix=f"{prefix}.final_layer_norm",
weights=weights,
eps=config.layer_norm_eps,
)
@ -370,9 +370,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix, config, weights):
super().__init__(config)
self.gpt_neox = FlashGPTNeoXModel(config, weights)
if not prefix:
prefix = "gpt_neox"
else:
prefix = f"{prefix}.gpt_neox"
self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights)
self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights

View File

@ -258,9 +258,9 @@ class PhiMLP(nn.Module):
class FlashPhiLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = FlashPhiAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
@ -307,18 +307,19 @@ class FlashPhiLayer(nn.Module):
class FlashPhiModel(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
FlashPhiLayer(
prefix,
layer_id,
config,
weights,
@ -378,10 +379,15 @@ class FlashPhiModel(torch.nn.Module):
class FlashPhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = FlashPhiModel(config, weights)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = FlashPhiModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head",

View File

@ -203,9 +203,9 @@ class Qwen2MLP(nn.Module):
class Qwen2Layer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = Qwen2Attention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
@ -260,17 +260,18 @@ class Qwen2Layer(nn.Module):
class Qwen2Model(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
Qwen2Layer(
prefix,
layer_id,
config,
weights,
@ -279,7 +280,7 @@ class Qwen2Model(torch.nn.Module):
]
)
self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.gradient_checkpointing = False
@ -331,10 +332,15 @@ class Qwen2Model(torch.nn.Module):
class Qwen2ForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = Qwen2Model(config, weights)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = Qwen2Model(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head",

View File

@ -127,7 +127,7 @@ class FlashRWAttention(torch.nn.Module):
def __init__(
self,
config,
prefix,
prefix: str,
weights,
):
super().__init__()
@ -236,7 +236,7 @@ class FlashRWLargeAttention(torch.nn.Module):
def __init__(
self,
config,
prefix,
prefix: str,
weights,
):
super().__init__()
@ -358,7 +358,7 @@ class FlashRWLargeAttention(torch.nn.Module):
class FlashMLP(nn.Module):
def __init__(self, config, prefix, weights):
def __init__(self, config, prefix: str, weights):
super().__init__()
self.act = torch.nn.functional.gelu
@ -380,6 +380,7 @@ class FlashRWLayer(nn.Module):
def __init__(
self,
layer_id,
prefix: str,
config,
weights,
):
@ -388,7 +389,7 @@ class FlashRWLayer(nn.Module):
parallel_attn = config.parallel_attn
self.parallel_attn = parallel_attn
prefix = f"transformer.h.{layer_id}"
prefix = f"{prefix}.h.{layer_id}"
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
@ -479,7 +480,7 @@ class FlashRWLayer(nn.Module):
class FlashRWLayerNorm(nn.Module):
def __init__(self, config, prefix, weights):
def __init__(self, config, prefix: str, weights):
super().__init__()
self.num_ln = config.num_ln_in_parallel_attn
@ -518,9 +519,9 @@ class FlashRWLayerNorm(nn.Module):
class FlashRWLargeLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, layer_id, prefix: str, config, weights):
super().__init__()
prefix = f"transformer.h.{layer_id}"
prefix = f"{prefix}.h.{layer_id}"
self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
@ -580,18 +581,18 @@ class FlashRWPreTrainedModel(PreTrainedModel):
class FlashRWModel(FlashRWPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.config = config
self.word_embeddings = TensorParallelEmbedding(
prefix="transformer.word_embeddings", weights=weights
prefix=f"{prefix}.word_embeddings", weights=weights
)
if config.new_decoder_architecture:
self.h = nn.ModuleList(
[
FlashRWLargeLayer(layer_id, config, weights)
FlashRWLargeLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
@ -599,14 +600,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
else:
self.h = nn.ModuleList(
[
FlashRWLayer(layer_id, config, weights)
FlashRWLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = self.h[0].self_attention.num_heads_kv
self.ln_f = FastLayerNorm.load(
prefix="transformer.ln_f",
prefix=f"{prefix}.ln_f",
weights=weights,
eps=config.layer_norm_epsilon,
)
@ -653,10 +654,15 @@ class FlashRWModel(FlashRWPreTrainedModel):
class FlashRWForCausalLM(FlashRWPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.transformer = FlashRWModel(config, weights)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.transformer = FlashRWModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)

View File

@ -346,16 +346,16 @@ class MLP(nn.Module):
class Block(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"transformer.h.{layer_id}"
prefix = f"{prefix}.h.{layer_id}"
self.ln_1 = FastLayerNorm.load(
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
)
self.ln_2 = FastLayerNorm.load(
prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
)
self.attn = FlashMQAttention(
self.self_attn = FlashMQAttention(
prefix=f"{prefix}.attn",
config=config,
weights=weights,
@ -378,7 +378,7 @@ class Block(nn.Module):
max_s,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn(
hidden_states = self.self_attn(
hidden_states,
cu_seqlen_prefill,
kv_cache,
@ -396,25 +396,26 @@ class Block(nn.Module):
class FlashSantacoderModel(nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.config = config
self.process_group = weights.process_group
self.wte = TensorParallelEmbedding(
prefix="transformer.wte",
prefix=f"{prefix}.wte",
weights=weights,
reduce=False,
)
self.wpe = TensorParallelEmbedding(
prefix="transformer.wpe",
prefix=f"{prefix}.wpe",
weights=weights,
reduce=False,
)
self.h = nn.ModuleList(
self.layers = nn.ModuleList(
[
Block(
prefix,
layer_id,
config,
weights,
@ -426,8 +427,8 @@ class FlashSantacoderModel(nn.Module):
prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
)
self.head_size = self.h[0].attn.head_size
self.num_heads = self.h[0].attn.num_heads
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
def forward(
self,
@ -446,7 +447,7 @@ class FlashSantacoderModel(nn.Module):
torch.distributed.all_reduce(hidden_states, group=self.process_group)
residual = None
for i, layer in enumerate(self.h):
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
@ -464,11 +465,18 @@ class FlashSantacoderModel(nn.Module):
class FlashSantacoderForCausalLM(nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix, config, weights):
super().__init__()
self.transformer = FlashSantacoderModel(config, weights)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
config.transpose = config.architectures[0].startswith("GPT2")
self.model = FlashSantacoderModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config, prefix="transformer.wte", weights=weights
config, prefix=f"{prefix}.wte", weights=weights
)
def forward(
@ -485,7 +493,7 @@ class FlashSantacoderForCausalLM(nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.transformer(
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,

View File

@ -417,14 +417,14 @@ class Starcoder2Layer(nn.Module):
class Starcoder2Model(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
@ -437,7 +437,7 @@ class Starcoder2Model(torch.nn.Module):
]
)
self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
prefix="model.norm", weights=weights, eps=config.norm_epsilon
prefix=f"{prefix}.norm", weights=weights, eps=config.norm_epsilon
)
self.gradient_checkpointing = False
@ -489,10 +489,15 @@ class Starcoder2Model(torch.nn.Module):
class FlashStarcoder2ForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix, config, weights):
super().__init__()
self.model = Starcoder2Model(config, weights)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = Starcoder2Model(prefix, config, weights)
try:
self.lm_head = SpeculativeHead.load(
config,
@ -502,7 +507,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
except RuntimeError:
self.lm_head = SpeculativeHead.load(
config,
prefix="model.embed_tokens",
prefix=f"{prefix}.embed_tokens",
weights=weights,
)

View File

@ -136,7 +136,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
self.config = config
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
self.language_model = load_text_model(
self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
@ -180,7 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
):
inputs_embeds = self.language_model.embed_tokens(input_ids)
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
@ -269,7 +269,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
input_ids, inputs_embeds, image_features
)
hidden_states = self.language_model.model(
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
@ -283,5 +283,5 @@ class LlavaNextForConditionalGeneration(nn.Module):
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.language_model.lm_head(hidden_states)
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -783,7 +783,7 @@ class MPTPreTrainedModel(PreTrainedModel):
class MPTModel(MPTPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
# config._validate_config()
super().__init__(config)
self.world_size = weights.process_group.size()
@ -809,13 +809,13 @@ class MPTModel(MPTPreTrainedModel):
f"Requested norm type ({config.norm_type}) is not implemented within this repo."
)
self.wte = TensorParallelEmbedding("transformer.wte", weights)
self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights)
if not self.alibi:
self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights)
self.blocks = nn.ModuleList(
[
MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights)
for i in range(config.n_layers)
]
)
@ -1085,13 +1085,19 @@ class MPTModel(MPTPreTrainedModel):
class MPTForCausalLM(MPTPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
if not config.tie_word_embeddings:
raise ValueError("MPTForCausalLM only supports tied word embeddings")
self.transformer = MPTModel(config, weights)
self.transformer = MPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config, prefix="transformer.wte", weights=weights
config, prefix=f"{prefix}.wte", weights=weights
)
self.logit_scale = None
if config.logit_scale is not None:

View File

@ -404,24 +404,24 @@ class GPTNeoXMLP(nn.Module):
class GPTNeoXLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, layer_id, prefix: str, config, weights):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.input_layernorm",
prefix=f"{prefix}.layers.{layer_id}.input_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm",
prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.attention = GPTNeoXAttention(
config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights
config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights
)
self.mlp = GPTNeoXMLP(
config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights
config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights
)
def forward(
@ -472,23 +472,23 @@ class GPTNeoXLayer(nn.Module):
class GPTNeoXModel(GPTNeoXPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.embed_in = TensorParallelEmbedding(
prefix="gpt_neox.embed_in", weights=weights
prefix=f"{prefix}.embed_in", weights=weights
)
self.layers = nn.ModuleList(
[
GPTNeoXLayer(layer_id, config, weights)
GPTNeoXLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.final_layer_norm = nn.LayerNorm.load(
prefix="gpt_neox.final_layer_norm",
prefix=f"{prefix}.final_layer_norm",
weights=weights,
eps=config.layer_norm_eps,
)
@ -640,9 +640,15 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.gpt_neox = GPTNeoXModel(config, weights)
if not prefix:
prefix = "gpt_neox"
else:
prefix = f"{prefix}.gpt_neox"
self.gpt_neox = GPTNeoXModel(prefix, config, weights)
self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights
)

View File

@ -94,11 +94,11 @@ class OPTLearnedPositionalEmbedding(nn.Module):
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, weights):
def __init__(self, prefix: str, weights):
super().__init__()
self.offset = 2
self.weight = nn.Parameter(
weights.get_tensor("model.decoder.embed_positions.weight")
weights.get_tensor(f"{prefix}.decoder.embed_positions.weight")
)
def forward(
@ -311,11 +311,11 @@ class OPTAttention(nn.Module):
class OPTDecoderLayer(nn.Module):
def __init__(self, layer_id: int, config: OPTConfig, weights):
def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights):
super().__init__()
self.process_group = weights.process_group
self.hidden_size = config.hidden_size
prefix = f"model.decoder.layers.{layer_id}"
prefix = f"{prefix}.decoder.layers.{layer_id}"
self.self_attn = OPTAttention(
config,
prefix=f"{prefix}.self_attn",
@ -429,7 +429,7 @@ class OPTPreTrainedModel(PreTrainedModel):
class OPTDecoder(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights):
def __init__(self, prefix: str, config: OPTConfig, weights):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.layerdrop
@ -438,20 +438,26 @@ class OPTDecoder(OPTPreTrainedModel):
self.vocab_size = config.vocab_size
self.embed_tokens = TensorParallelEmbedding(
prefix="model.decoder.embed_tokens", weights=weights
prefix=f"{prefix}.decoder.embed_tokens", weights=weights
)
self.embed_positions = OPTLearnedPositionalEmbedding(weights)
self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = FastLinear.load(
config, prefix="model.decoder.project_out", weights=weights, bias=False
config,
prefix=f"{prefix}.decoder.project_out",
weights=weights,
bias=False,
)
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = FastLinear.load(
config, prefix="model.decoder.project_in", weights=weights, bias=False
config,
prefix=f"{prefix}.decoder.project_in",
weights=weights,
bias=False,
)
else:
self.project_in = None
@ -461,14 +467,14 @@ class OPTDecoder(OPTPreTrainedModel):
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm.load(
prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS
prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS
)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList(
[
OPTDecoderLayer(layer_id, config, weights)
OPTDecoderLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
@ -686,9 +692,9 @@ class OPTDecoder(OPTPreTrainedModel):
class OPTModel(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights):
def __init__(self, prefix: str, config: OPTConfig, weights):
super().__init__(config)
self.decoder = OPTDecoder(config, weights)
self.decoder = OPTDecoder(prefix, config, weights)
# Initialize weights and apply final processing
def forward(
@ -743,13 +749,18 @@ class OPTModel(OPTPreTrainedModel):
class OPTForCausalLM(OPTPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix, config, weights):
super().__init__(config)
self.model = OPTModel(config, weights)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = OPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config, prefix="model.decoder.embed_tokens", weights=weights
config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights
)
def forward(

View File

@ -248,16 +248,16 @@ class PhiBlock(nn.Module):
# PhiModel implements the embedding layer and the transformer blocks.
class PhiModel(nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.tp_rank = weights.process_group.rank()
self.tp_world_size = weights.process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="transformer.embd.wte", weights=weights
prefix=f"{prefix}.embd.wte", weights=weights
)
self.blocks = nn.ModuleList(
[
PhiBlock(f"transformer.h.{layer_id}", config, weights)
PhiBlock(f"{prefix}.h.{layer_id}", config, weights)
for layer_id in range(config.n_layer)
]
)
@ -289,9 +289,15 @@ class PhiModel(nn.Module):
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
class PhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = PhiModel(config, weights)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.model = PhiModel(prefix, config, weights)
self.lm_head = PhiCausalLMHead(config, weights)
def forward(

View File

@ -10,7 +10,12 @@ import numpy as np
from loguru import logger
from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers import (
PreTrainedTokenizerBase,
AutoConfig,
AutoTokenizer,
GenerationConfig,
)
from typing import Iterable, Optional, Tuple, List, Type, Dict
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
@ -21,6 +26,12 @@ from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.dist import RANK
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
hub,
)
from text_generation_server.models.types import (
Batch,
Tokens,
@ -798,29 +809,120 @@ class FlashCausalLMBatch(Batch):
return len(self.requests)
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashCausalLM(Model):
def __init__(
self,
model_id: str,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
num_layers: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
rank: int = 0,
world_size: int = 1,
sliding_window: Optional[int] = None,
model_class,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
lora_adapter_ids: Optional[list] = [],
tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
config_class: PreTrainedTokenizerBase = AutoConfig,
default_dtype=torch.float16,
aliases=None,
# Used for Santacoder override of config
num_kv_heads=None,
skip_special_tokens: bool = True,
):
self.num_layers = num_layers
self.num_kv_heads = num_kv_heads
self.head_size = head_size
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError(f"{model_class} is only available on GPU")
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = config_class.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device, dtype, process_group=self.process_group, aliases=aliases
)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = model_class(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
# VLM models define the config we care about in their text_config
text_config = getattr(config, "text_config", None)
if text_config is not None:
config = text_config
if getattr(config, "sliding_window", None) is not None:
set_sliding_window(config.sliding_window)
else:
config.sliding_window = None
self.num_layers = config.num_hidden_layers
# Validation is done in the model itself
if num_kv_heads is None:
# Order is important here.
for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
num_kv_heads = getattr(config, attr, None)
if num_kv_heads is not None:
break
if num_kv_heads is None:
raise ValueError("Cannot get the number of key/value heads")
self.num_kv_heads = (
num_kv_heads // self.process_group.size()
if num_kv_heads > 1
else num_kv_heads
)
assert self.num_kv_heads > 0
self.head_size = config.hidden_size // config.num_attention_heads
self.cuda_graphs = {}
self.kv_cache = []
super(FlashCausalLM, self).__init__(
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
@ -829,7 +931,7 @@ class FlashCausalLM(Model):
device=device,
rank=rank,
world_size=world_size,
sliding_window=sliding_window,
sliding_window=config.sliding_window,
)
@property
@ -1577,3 +1679,72 @@ class FlashCausalLM(Model):
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)
@property
def supports_adapter_loading(self) -> bool:
    """Report that this model accepts adapters; always ``True`` here."""
    adapters_supported = True
    return adapters_supported
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
# Builds a lookup from (layer index, adapter target name) to a pair of
# (dotted weight name, attribute of the layer that holds that projection).
# NOTE(review): the keys written below are (int, str) tuples, while the
# return annotation says Dict[str, ...] - confirm which one callers expect.
layer_weights = {}
# Every per-layer weight name below is built as "model.layers.<i>.<suffix>".
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
# The q/k/v targets all point at the same `query_key_value` attribute;
# only the weight-name string differs between the three entries.
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
# TODO: this is a hack to avoid the gate_proj for
# FlashStarcoder2 that doesn't have these layers
if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
# gate_proj and up_proj both point at the single `gate_up_proj` attribute.
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
# The head is registered once, under index 0, with no "model.layers" prefix.
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
    """Return the module-level ``ADAPTER_LAYERS`` list (the same object, not a copy)."""
    known_targets = ADAPTER_LAYERS
    return known_targets
@property
def default_traced_adapter_layers(self) -> List[str]:
    """Return a fresh list naming the default traced targets: q_proj and v_proj."""
    traced_by_default = ["q_proj", "v_proj"]
    return traced_by_default
def get_num_layers_for_type(self, layer_type: str) -> int:
    """Return how many layers exist for ``layer_type``.

    ``"lm_head"`` always yields 1; any other value yields the length of
    ``self.model.model.layers``.
    """
    # Answer the head case first so the model is not touched for it.
    if layer_type == "lm_head":
        return 1
    return len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
    """Return whether ``layer_type`` is a member of the module-level ``ROW_PARALLEL`` set."""
    listed = layer_type in ROW_PARALLEL
    return listed

View File

@ -1,75 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoTokenizer, AutoConfig
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashCohereForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashCohere(FlashCausalLM):
    """``FlashCausalLM`` subclass that builds a ``FlashCohereForCausalLM`` model."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and safetensors weights, then build the model.

        Args:
            model_id: Identifier passed to every ``from_pretrained`` call and
                to ``weight_files``.
            revision: Optional revision forwarded alongside ``model_id``.
            quantize: Stored on the config as ``config.quantize``.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Weight dtype; ``torch.float16`` is used when ``None``.
            trust_remote_code: Forwarded to the tokenizer and config loaders.

        Raises:
            NotImplementedError: If CUDA is not available.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # Only a CUDA path exists for this model, so bail out early otherwise.
        if not torch.cuda.is_available():
            raise NotImplementedError("FlashCohere is only available on GPU")
        device = torch.device(f"cuda:{rank}")
        if dtype is None:
            dtype = torch.float16

        loaded_tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
            use_fast=True,
            from_slow=False,
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        shard_files = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        weights = Weights(
            shard_files, device, dtype, process_group=self.process_group
        )
        if config.quantize in ("gptq", "awq", "marlin"):
            weights._set_gptq_params(model_id, revision)

        model = FlashCohereForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)

        # The cache geometry is read from the instantiated inner model.
        backbone = model.model
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=loaded_tokenizer,
            num_layers=len(backbone.layers),
            num_kv_heads=backbone.num_key_value_heads,
            head_size=backbone.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

View File

@ -1,100 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
FlashDbrxForCausalLM,
DbrxConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashDbrx(FlashCausalLM):
    """``FlashCausalLM`` subclass that builds a ``FlashDbrxForCausalLM`` model."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and safetensors weights, then build the model.

        The tokenizer is resolved with a three-step fallback:
        ``GPT2TokenizerFast`` on ``model_id``, then ``AutoTokenizer`` on
        ``model_id``, then ``GPT2TokenizerFast`` on the hard-coded
        ``Xenova/dbrx-instruct-tokenizer`` repository.

        Args:
            model_id: Identifier passed to the tokenizer/config loaders and
                to ``weight_files``.
            revision: Optional revision forwarded alongside ``model_id``.
            quantize: Stored on the config as ``config.quantize``.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Weight dtype; ``torch.bfloat16`` is used when ``None``.
            trust_remote_code: Forwarded to the tokenizer and config loaders.

        Raises:
            NotImplementedError: If CUDA is not available.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.bfloat16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashDBRX is only available on GPU")

        # Identical keyword arguments are used for every tokenizer attempt.
        tokenizer_kwargs = dict(
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
            use_fast=True,
            from_slow=False,
        )

        # `except Exception` (rather than a bare `except:`) keeps the
        # best-effort fallback for load failures while letting
        # KeyboardInterrupt / SystemExit propagate instead of kicking off
        # another download attempt.
        try:
            tokenizer = GPT2TokenizerFast.from_pretrained(
                model_id, **tokenizer_kwargs
            )
        except Exception:
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    model_id, **tokenizer_kwargs
                )
            except Exception:
                # FIXME: change back to model id once the tokenizer.json is merged
                # NOTE(review): the caller's `revision` is also forwarded to
                # this fallback repository - confirm that is intended.
                tokenizer = GPT2TokenizerFast.from_pretrained(
                    "Xenova/dbrx-instruct-tokenizer", **tokenizer_kwargs
                )

        config = DbrxConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashDbrxForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(FlashDbrx, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

View File

@ -1,83 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoConfig, AutoTokenizer
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashGemma(FlashCausalLM):
    """``FlashCausalLM`` subclass that builds a ``FlashGemmaForCausalLM`` model."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and safetensors weights, then build the model.

        Device and default dtype (used only when ``dtype`` is ``None``):
        CUDA -> bfloat16; ``SYSTEM == "ipex"`` with XPU available -> float16;
        ``SYSTEM == "ipex"`` without XPU -> CPU with bfloat16.

        Raises:
            NotImplementedError: If CUDA is unavailable and ``SYSTEM`` is not
                ``"ipex"``.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # Pick the device together with the dtype used when none was requested.
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            fallback_dtype = torch.bfloat16
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                fallback_dtype = torch.float16
            else:
                device = torch.device("cpu")
                fallback_dtype = torch.bfloat16
        else:
            raise NotImplementedError("FlashGemma is only available on GPU")
        if dtype is None:
            dtype = fallback_dtype

        loaded_tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        shard_files = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        weights = Weights(
            shard_files, device, dtype, process_group=self.process_group
        )
        if config.quantize in ("gptq", "awq", "marlin"):
            weights._set_gptq_params(model_id, revision)

        # TODO hardcoded
        root_prefix = ""
        model = FlashGemmaForCausalLM(root_prefix, config, weights, causal=True)

        torch.distributed.barrier(group=self.process_group)

        # The cache geometry is read from the instantiated inner model.
        backbone = model.model
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=loaded_tokenizer,
            num_layers=len(backbone.layers),
            num_kv_heads=backbone.num_key_value_heads,
            head_size=backbone.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

View File

@ -1,83 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import PretrainedConfig, AutoTokenizer
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashGemma2(FlashCausalLM):
    """``FlashCausalLM`` subclass that builds a ``FlashGemma2ForCausalLM`` model."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and safetensors weights, then build the model.

        Device and default dtype (used only when ``dtype`` is ``None``):
        CUDA -> bfloat16; ``SYSTEM == "ipex"`` with XPU available -> float16;
        ``SYSTEM == "ipex"`` without XPU -> CPU with bfloat16. The config is
        read through ``PretrainedConfig.from_pretrained``.

        Raises:
            NotImplementedError: If CUDA is unavailable and ``SYSTEM`` is not
                ``"ipex"``.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # Pick the device together with the dtype used when none was requested.
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            fallback_dtype = torch.bfloat16
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                fallback_dtype = torch.float16
            else:
                device = torch.device("cpu")
                fallback_dtype = torch.bfloat16
        else:
            raise NotImplementedError("FlashGemma2 is only available on GPU")
        if dtype is None:
            dtype = fallback_dtype

        loaded_tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = PretrainedConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        shard_files = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        weights = Weights(
            shard_files, device, dtype, process_group=self.process_group
        )
        if config.quantize in ("gptq", "awq", "marlin"):
            weights._set_gptq_params(model_id, revision)

        # TODO hardcoded
        root_prefix = ""
        model = FlashGemma2ForCausalLM(root_prefix, config, weights, causal=True)

        torch.distributed.barrier(group=self.process_group)

        # The cache geometry is read from the instantiated inner model.
        backbone = model.model
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=loaded_tokenizer,
            num_layers=len(backbone.layers),
            num_kv_heads=backbone.num_key_value_heads,
            head_size=backbone.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

View File

@ -1,82 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from transformers.models.gpt2 import GPT2Tokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashGPT2(FlashCausalLM):
    """``FlashCausalLM`` subclass that builds a ``FlashGPT2ForCausalLM`` model."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and safetensors weights, then build the model.

        Device and default dtype (used only when ``dtype`` is ``None``):
        CUDA -> float16; ``SYSTEM == "ipex"`` with XPU available -> float16;
        ``SYSTEM == "ipex"`` without XPU -> CPU with bfloat16.

        Raises:
            NotImplementedError: If CUDA is unavailable and ``SYSTEM`` is not
                ``"ipex"``.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # Pick the device together with the dtype used when none was requested.
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            fallback_dtype = torch.float16
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                fallback_dtype = torch.float16
            else:
                device = torch.device("cpu")
                fallback_dtype = torch.bfloat16
        else:
            raise NotImplementedError("FlashGPT2 is only available on GPU")
        if dtype is None:
            dtype = fallback_dtype

        loaded_tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        shard_files = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        weights = Weights(
            shard_files, device, dtype, process_group=self.process_group
        )
        if config.quantize in ("gptq", "awq", "marlin"):
            weights._set_gptq_params(model_id, revision)

        root_prefix = ""
        model = FlashGPT2ForCausalLM(root_prefix, config, weights)

        torch.distributed.barrier(group=self.process_group)

        # This model passes the inner model's `num_heads` as the KV-head count.
        backbone = model.model
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=loaded_tokenizer,
            num_layers=len(backbone.layers),
            num_kv_heads=backbone.num_heads,
            head_size=backbone.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

View File

@ -1,171 +0,0 @@
import os
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
hub,
)
tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import SYSTEM
# Layer names exposed through FlashLlama.adapter_layers.
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
# Layer types for which FlashLlama.is_row_parallel returns True.
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashLlama(FlashCausalLM):
    """Llama served through ``FlashCausalLM`` with adapter-loading hooks.

    Besides the constructor, the class exposes the adapter-related hooks
    (``supports_adapter_loading``, ``adapter_target_to_layer``,
    ``adapter_layers``, ``default_traced_adapter_layers``,
    ``get_num_layers_for_type`` and ``is_row_parallel``).
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        lora_adapter_ids: Optional[list] = None,
    ):
        """Load tokenizer, config and weights for ``model_id``.

        Args:
            model_id: Identifier passed to every ``from_pretrained`` call and
                to ``weight_files``.
            revision: Optional revision forwarded to the same loaders.
            quantize: Stored on the config as ``config.quantize``.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Explicit dtype; when ``None`` a per-device default is used.
            trust_remote_code: Forwarded to the tokenizer, generation-config
                and config loaders.
            lora_adapter_ids: Accepted for interface compatibility; it is not
                read by this constructor. Defaults to ``None`` rather than a
                shared mutable ``[]``.

        Raises:
            NotImplementedError: CUDA is unavailable and ``SYSTEM`` is not
                ``"ipex"``.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        elif SYSTEM == "ipex":
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = torch.device(f"xpu:{rank}")
                dtype = torch.float16 if dtype is None else dtype
            else:
                device = torch.device("cpu")
                dtype = torch.bfloat16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashLlama is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        # Best effort: a generation config is optional, so any failure while
        # loading it is deliberately ignored and the tokenizer is left as is.
        try:
            generation_config = GenerationConfig.from_pretrained(
                model_id, revision=revision, trust_remote_code=trust_remote_code
            )
            if isinstance(generation_config.eos_token_id, (list, set)):
                # TODO Huge hack
                tokenizer._eos_token_ids = set(generation_config.eos_token_id)
        except Exception:
            pass

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        prefix = ""
        model = FlashLlamaForCausalLM(prefix, config, weights)
        torch.distributed.barrier(group=self.process_group)
        super(FlashLlama, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def supports_adapter_loading(self) -> bool:
        """Always ``True`` for this model class."""
        return True

    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
        """Map ``(layer_index, layer_type)`` to ``(weight_name, module)``.

        ``q_proj``/``k_proj``/``v_proj`` all point at the layer's
        ``query_key_value`` module, and ``gate_proj``/``up_proj`` both point
        at ``gate_up_proj``. A single ``(0, "lm_head")`` entry is added for
        the output head.
        """
        layer_weights = {}

        prefix = "model.layers"

        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
        # that have a language_model inside of the larger model.
        if hasattr(self.model, "language_model"):
            _model = self.model.language_model
        elif hasattr(self.model, "text_model"):
            _model = self.model.text_model
        else:
            _model = self.model

        for i, layer in enumerate(_model.model.layers):
            layer_weights[(i, "q_proj")] = (
                f"{prefix}.{i}.self_attn.q_proj",
                layer.self_attn.query_key_value,
            )
            layer_weights[(i, "k_proj")] = (
                f"{prefix}.{i}.self_attn.k_proj",
                layer.self_attn.query_key_value,
            )
            layer_weights[(i, "v_proj")] = (
                f"{prefix}.{i}.self_attn.v_proj",
                layer.self_attn.query_key_value,
            )
            layer_weights[(i, "o_proj")] = (
                f"{prefix}.{i}.self_attn.o_proj",
                layer.self_attn.o_proj,
            )

            layer_weights[(i, "gate_proj")] = (
                f"{prefix}.{i}.mlp.gate_proj",
                layer.mlp.gate_up_proj,
            )
            layer_weights[(i, "up_proj")] = (
                f"{prefix}.{i}.mlp.up_proj",
                layer.mlp.gate_up_proj,
            )
            layer_weights[(i, "down_proj")] = (
                f"{prefix}.{i}.mlp.down_proj",
                layer.mlp.down_proj,
            )

        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
        return layer_weights

    @property
    def adapter_layers(self) -> List[str]:
        """Return the module-level ``ADAPTER_LAYERS`` list."""
        return ADAPTER_LAYERS

    @property
    def default_traced_adapter_layers(self) -> List[str]:
        """Return the layer types traced by default: ``q_proj`` and ``v_proj``."""
        return ["q_proj", "v_proj"]

    def get_num_layers_for_type(self, layer_type: str) -> int:
        """Return 1 for ``"lm_head"``, otherwise the number of model layers."""
        return 1 if layer_type == "lm_head" else len(self.model.model.layers)

    def is_row_parallel(self, layer_type: str) -> bool:
        """Return whether ``layer_type`` is listed in ``ROW_PARALLEL``."""
        return layer_type in ROW_PARALLEL

View File

@ -1,24 +1,7 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import set_sliding_window
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
MistralConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
ADAPTER_LAYERS = [
@ -33,88 +16,7 @@ ADAPTER_LAYERS = [
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class BaseFlashMistral(FlashCausalLM):
def __init__(
    self,
    model_cls,
    model_id: str,
    config_cls=AutoConfig,
    revision: Optional[str] = None,
    quantize: Optional[str] = None,
    speculator: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
    trust_remote_code: bool = False,
    tokenizer_class=AutoTokenizer,
):
    """Shared constructor for Mistral-style flash models.

    Args:
        model_cls: Callable invoked as ``model_cls(prefix, config, weights)``.
        model_id: Identifier passed to the loaders and to ``weight_files``.
        config_cls: Class whose ``from_pretrained`` yields the config.
        revision: Optional revision forwarded to the loaders.
        quantize: Stored on the config as ``config.quantize``.
        speculator: Stored on the config as ``config.speculator``.
        dtype: Explicit dtype; when ``None`` a per-device default is used.
        trust_remote_code: Forwarded to the tokenizer and config loaders.
        tokenizer_class: Class whose ``from_pretrained`` yields the tokenizer.

    Raises:
        NotImplementedError: CUDA is unavailable and ``SYSTEM`` is not
            ``"ipex"``.
    """
    self.process_group, rank, world_size = initialize_torch_distributed()

    # Device / default-dtype selection: CUDA first, then the ipex paths.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        fallback_dtype = torch.float16
    elif SYSTEM != "ipex":
        raise NotImplementedError("FlashMistral is only available on GPU")
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        device = torch.device(f"xpu:{rank}")
        fallback_dtype = torch.float16
    else:
        device = torch.device("cpu")
        fallback_dtype = torch.bfloat16
    if dtype is None:
        dtype = fallback_dtype

    tok = tokenizer_class.from_pretrained(
        model_id,
        revision=revision,
        padding_side="left",
        truncation_side="left",
        trust_remote_code=trust_remote_code,
    )

    cfg = config_cls.from_pretrained(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    cfg.quantize = quantize
    cfg.speculator = speculator

    # Set context windows: register the window when the config has one,
    # otherwise make sure the attribute exists and is None.
    if getattr(cfg, "sliding_window", None) is None:
        cfg.sliding_window = None
    else:
        set_sliding_window(cfg.sliding_window)

    torch.distributed.barrier(group=self.process_group)

    shard_files = weight_files(model_id, revision=revision, extension=".safetensors")
    weight_store = Weights(
        shard_files, device, dtype, process_group=self.process_group
    )
    if cfg.quantize in ("gptq", "awq", "marlin"):
        weight_store._set_gptq_params(model_id, revision)

    net = model_cls("", cfg, weight_store)

    self.cuda_graphs = {}

    torch.distributed.barrier(group=self.process_group)

    # Subclasses may override get_layer_config to describe nested models.
    num_layers, num_kv_heads, head_size = self.get_layer_config(net)
    super().__init__(
        model_id=model_id,
        model=net,
        tokenizer=tok,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        dtype=dtype,
        device=device,
        rank=rank,
        world_size=world_size,
        sliding_window=cfg.sliding_window,
    )
def get_layer_config(self, model) -> Tuple[int, int, int]:
    """Return ``(num_layers, num_kv_heads, head_size)`` read from ``model.model``."""
    inner = model.model
    layer_count = len(inner.layers)
    kv_heads = inner.num_key_value_heads
    return (layer_count, kv_heads, inner.head_size)
class FlashMistral(FlashCausalLM):
# Always returns True for this model class.
@property
def supports_adapter_loading(self) -> bool:
return True
@ -126,9 +28,7 @@ class BaseFlashMistral(FlashCausalLM):
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
if hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
@ -183,25 +83,3 @@ class BaseFlashMistral(FlashCausalLM):
# Returns whether `layer_type` is a member of the module-level ROW_PARALLEL set.
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
class FlashMistral(BaseFlashMistral):
    """``BaseFlashMistral`` bound to ``MistralConfig`` and ``FlashMistralForCausalLM``."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Forward every argument to ``BaseFlashMistral.__init__`` unchanged."""
        super().__init__(
            model_cls=FlashMistralForCausalLM,
            model_id=model_id,
            config_cls=MistralConfig,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

View File

@ -1,31 +0,0 @@
import torch
from typing import Optional
from text_generation_server.models.flash_mistral import BaseFlashMistral
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
MixtralConfig,
FlashMixtralForCausalLM,
)
class FlashMixtral(BaseFlashMistral):
    """``BaseFlashMistral`` bound to ``MixtralConfig`` and ``FlashMixtralForCausalLM``."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Forward every argument to ``BaseFlashMistral.__init__`` unchanged."""
        super().__init__(
            model_cls=FlashMixtralForCausalLM,
            model_id=model_id,
            config_cls=MixtralConfig,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

View File

@ -1,82 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashNeoXSharded(FlashCausalLM):
    """GPT-NeoX served through ``FlashCausalLM``."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights, then build ``FlashGPTNeoXForCausalLM``.

        Args:
            model_id: Identifier passed to the loaders and to ``weight_files``.
            revision: Optional revision forwarded to the loaders.
            quantize: Stored on the config as ``config.quantize``.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Explicit dtype; when ``None`` a per-device default is used.
            trust_remote_code: Forwarded to the tokenizer and config loaders.

        Raises:
            NotImplementedError: CUDA is unavailable and ``SYSTEM`` is not
                ``"ipex"``.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # Device / default-dtype selection: CUDA first, then the ipex paths.
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            fallback_dtype = torch.float16
        elif SYSTEM != "ipex":
            raise NotImplementedError("FlashNeoX is only available on GPU")
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")
            fallback_dtype = torch.float16
        else:
            device = torch.device("cpu")
            fallback_dtype = torch.bfloat16
        if dtype is None:
            dtype = fallback_dtype

        tok = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        cfg = AutoConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        cfg.quantize = quantize
        cfg.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        shard_files = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        weight_store = Weights(
            shard_files, device=device, dtype=dtype, process_group=self.process_group
        )
        if cfg.quantize in ("gptq", "marlin"):
            weight_store._set_gptq_params(model_id, revision)

        net = FlashGPTNeoXForCausalLM(cfg, weight_store)

        torch.distributed.barrier(group=self.process_group)

        super().__init__(
            model_id=model_id,
            model=net.to(device),
            tokenizer=tok,
            num_layers=len(net.gpt_neox.layers),
            num_kv_heads=net.gpt_neox.num_heads,
            head_size=net.gpt_neox.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

View File

@ -1,111 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashPhi(FlashCausalLM):
    """Phi served through ``FlashCausalLM``, with an optional Medusa head."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights, then build ``FlashPhiForCausalLM``.

        When ``speculator`` is truthy, the model's ``lm_head`` is replaced by
        a ``MedusaModel`` wrapping the original head. Its config and weights
        come from a local directory or from the Hub.

        Args:
            model_id: Identifier passed to the loaders and to ``weight_files``.
            revision: Optional revision forwarded to the loaders (and to the
                speculator downloads).
            quantize: Stored on the config as ``config.quantize``.
            speculator: Stored on the config; also a local path or Hub id of
                the Medusa files.
            dtype: Explicit dtype; when ``None`` a per-device default is used.
            trust_remote_code: Forwarded to the tokenizer and config loaders.

        Raises:
            NotImplementedError: CUDA is unavailable and ``SYSTEM`` is not
                ``"ipex"``.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # Device / default-dtype selection: CUDA first, then the ipex paths.
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            fallback_dtype = torch.float16
        elif SYSTEM != "ipex":
            raise NotImplementedError("FlashPhi is only available on GPU")
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")
            fallback_dtype = torch.float16
        else:
            device = torch.device("cpu")
            fallback_dtype = torch.bfloat16
        if dtype is None:
            dtype = fallback_dtype

        tok = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        cfg = AutoConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        cfg.quantize = quantize
        cfg.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        shard_files = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        weight_store = Weights(
            shard_files, device, dtype, process_group=self.process_group
        )
        if cfg.quantize in ("gptq", "awq", "marlin"):
            weight_store._set_gptq_params(model_id, revision)

        net = FlashPhiForCausalLM(cfg, weight_store)

        if speculator:
            from text_generation_server.utils.medusa import MedusaModel
            from huggingface_hub import hf_hub_download
            import json
            import os
            from pathlib import Path

            # Local when the path is an existing directory, or whenever the
            # WEIGHTS_CACHE_OVERRIDE environment variable is set.
            is_local_model = (
                Path(speculator).exists() and Path(speculator).is_dir()
            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None

            if is_local_model:
                medusa_config = str(Path(speculator) / "config.json")
                medusa_head = str(Path(speculator) / "medusa_lm_head.pt")
            else:
                medusa_config = hf_hub_download(
                    speculator, revision=revision, filename="config.json"
                )
                medusa_head = hf_hub_download(
                    speculator, revision=revision, filename="medusa_lm_head.pt"
                )

            with open(medusa_config, "r") as handle:
                medusa_cfg = json.load(handle)

            # The safetensors file sits next to the ``.pt`` head.
            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
            medusa_weights = Weights(
                [medusa_sf], device, dtype, process_group=self.process_group
            )
            base_head = net.lm_head
            net.lm_head = MedusaModel(medusa_cfg, medusa_weights, base_head)

        torch.distributed.barrier(group=self.process_group)

        backbone = net.model
        super().__init__(
            model_id=model_id,
            model=net,
            tokenizer=tok,
            num_layers=len(backbone.layers),
            num_kv_heads=backbone.num_key_value_heads,
            head_size=backbone.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

View File

@ -1,93 +0,0 @@
import math
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashQwen2(BaseFlashMistral):
    """Qwen2 model reusing the ``BaseFlashMistral`` class hierarchy.

    The constructor does its own loading and calls
    ``super(BaseFlashMistral, self).__init__``, so the method lookup starts
    after ``BaseFlashMistral`` in the MRO.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights, then build ``Qwen2ForCausalLM``.

        Args:
            model_id: Identifier passed to the loaders and to ``weight_files``.
            revision: Optional revision forwarded to the loaders.
            quantize: Stored on the config as ``config.quantize``.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Explicit dtype; when ``None`` a per-device default is used.
            trust_remote_code: Forwarded to the tokenizer and config loaders.

        Raises:
            NotImplementedError: CUDA is unavailable and ``SYSTEM`` is not
                ``"ipex"``.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # Device / default-dtype selection: CUDA first, then the ipex paths.
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            fallback_dtype = torch.float16
        elif SYSTEM != "ipex":
            raise NotImplementedError("FlashQwen2 is only available on GPU")
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")
            fallback_dtype = torch.float16
        else:
            device = torch.device("cpu")
            fallback_dtype = torch.bfloat16
        if dtype is None:
            dtype = fallback_dtype

        tok = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        cfg = AutoConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        cfg.quantize = quantize
        cfg.speculator = speculator

        # Set context windows
        if cfg.sliding_window is not None:
            set_sliding_window(cfg.sliding_window)

        torch.distributed.barrier(group=self.process_group)

        shard_files = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        weight_store = Weights(
            shard_files, device, dtype, process_group=self.process_group
        )
        if cfg.quantize in ("gptq", "awq", "marlin"):
            weight_store._set_gptq_params(model_id, revision)

        net = Qwen2ForCausalLM(cfg, weight_store)

        self.cuda_graphs = {}

        torch.distributed.barrier(group=self.process_group)

        backbone = net.model
        # Start the lookup after BaseFlashMistral so its __init__ is not run.
        super(BaseFlashMistral, self).__init__(
            model_id=model_id,
            model=net,
            tokenizer=tok,
            num_layers=len(backbone.layers),
            num_kv_heads=backbone.num_key_value_heads,
            head_size=backbone.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=cfg.sliding_window,
        )

View File

@ -1,91 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashRWSharded(FlashCausalLM):
    """RW model served through ``FlashCausalLM``."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, ``RWConfig`` and weights, then build ``FlashRWForCausalLM``.

        Args:
            model_id: Identifier passed to the loaders and to ``weight_files``.
            revision: Optional revision forwarded to the loaders.
            quantize: Stored on the config as ``config.quantize``.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Explicit dtype; when ``None`` a per-device default is used.
            trust_remote_code: Forwarded to the tokenizer and config loaders.

        Raises:
            NotImplementedError: CUDA is unavailable and ``SYSTEM`` is not
                ``"ipex"``.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # Device / default-dtype selection: CUDA first, then the ipex paths.
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            fallback_dtype = torch.float16
        elif SYSTEM != "ipex":
            raise NotImplementedError("FlashRW is only available on GPU")
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")
            fallback_dtype = torch.float16
        else:
            device = torch.device("cpu")
            fallback_dtype = torch.bfloat16
        if dtype is None:
            dtype = fallback_dtype

        tok = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        cfg = RWConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )

        torch.distributed.barrier(group=self.process_group)

        shard_files = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        # The head and the word embeddings are registered as aliases of each
        # other in both directions.
        tied_names = {
            "lm_head.weight": ["transformer.word_embeddings.weight"],
            "transformer.word_embeddings.weight": ["lm_head.weight"],
        }
        weight_store = Weights(
            shard_files,
            device,
            dtype,
            process_group=self.process_group,
            aliases=tied_names,
        )

        cfg.quantize = quantize
        cfg.speculator = speculator
        if cfg.quantize in ("gptq", "marlin"):
            weight_store._set_gptq_params(model_id, revision)

        net = FlashRWForCausalLM(cfg, weight_store)

        torch.distributed.barrier(group=self.process_group)

        super().__init__(
            model_id=model_id,
            model=net.to(device),
            tokenizer=tok,
            num_layers=len(net.transformer.h),
            num_kv_heads=net.transformer.cache_size,
            head_size=net.transformer.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

View File

@ -1,99 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List
import json
import os
from huggingface_hub import hf_hub_download
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashSantacoderSharded(FlashCausalLM):
    """Santacoder served through ``FlashCausalLM``."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights, then build ``FlashSantacoderForCausalLM``.

        Args:
            model_id: Identifier passed to the loaders and to ``weight_files``.
            revision: Optional revision forwarded to the loaders.
            quantize: Stored on the config as ``config.quantize``.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Explicit dtype; when ``None`` a per-device default is used.
            trust_remote_code: Forwarded to the tokenizer loader only; the
                config loader always receives ``True``.

        Raises:
            NotImplementedError: CUDA is unavailable and ``SYSTEM`` is not
                ``"ipex"``.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # Device / default-dtype selection: CUDA first, then the ipex paths.
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            fallback_dtype = torch.float16
        elif SYSTEM != "ipex":
            raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")
            fallback_dtype = torch.float16
        else:
            device = torch.device("cpu")
            fallback_dtype = torch.bfloat16
        if dtype is None:
            dtype = fallback_dtype

        tok = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        # NOTE(review): trust_remote_code is hard-coded to True here, so the
        # caller's ``trust_remote_code`` flag is ignored for the config.
        # Presumably needed for custom-code checkpoints -- confirm before
        # changing.
        cfg = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=True,
        )
        cfg.quantize = quantize
        cfg.speculator = speculator
        # Flag set when the first listed architecture name starts with "GPT2".
        cfg.transpose = cfg.architectures[0].startswith("GPT2")

        torch.distributed.barrier(group=self.process_group)

        shard_files = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        weight_store = Weights(
            shard_files,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            aliases={"transformer.wte.weight": ["lm_head.weight"]},
        )
        if cfg.quantize in ("gptq", "marlin"):
            weight_store._set_gptq_params(model_id, revision)

        net = FlashSantacoderForCausalLM(cfg, weight_store)

        torch.distributed.barrier(group=self.process_group)

        super().__init__(
            model_id=model_id,
            model=net.to(device),
            tokenizer=tok,
            num_layers=len(net.transformer.h),
            num_kv_heads=1,
            head_size=net.transformer.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    def decode(self, generated_ids: List[int]) -> str:
        """Decode ``generated_ids`` while keeping special tokens in the text."""
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        text = self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )
        return text

View File

@ -1,84 +0,0 @@
import math
import torch
from typing import Optional
from transformers.models.gpt2 import GPT2TokenizerFast
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
Starcoder2Config,
FlashStarcoder2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
# Starcoder2 has the same base as Mistral
class FlashStarcoder2(BaseFlashMistral):
    """Starcoder2 model reusing the ``BaseFlashMistral`` class hierarchy.

    The constructor does its own loading and calls
    ``super(BaseFlashMistral, self).__init__``, so the method lookup starts
    after ``BaseFlashMistral`` in the MRO.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights, then build ``FlashStarcoder2ForCausalLM``.

        Args:
            model_id: Identifier passed to the loaders and to ``weight_files``.
            revision: Optional revision forwarded to the loaders.
            quantize: Stored on the config as ``config.quantize``.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Explicit dtype; ``torch.float16`` is used when ``None``.
            trust_remote_code: Forwarded to the tokenizer and config loaders.

        Raises:
            NotImplementedError: CUDA is unavailable.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # Only the CUDA path exists for this model.
        if not torch.cuda.is_available():
            raise NotImplementedError("FlashStarcoder2 is only available on GPU")
        device = torch.device(f"cuda:{rank}")
        if dtype is None:
            dtype = torch.float16

        tok = GPT2TokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        cfg = Starcoder2Config.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        cfg.quantize = quantize
        cfg.speculator = speculator

        # Set context windows
        if cfg.sliding_window is not None:
            set_sliding_window(cfg.sliding_window)

        torch.distributed.barrier(group=self.process_group)

        shard_files = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        weight_store = Weights(
            shard_files, device, dtype, process_group=self.process_group
        )
        if cfg.quantize in ("gptq", "awq", "marlin"):
            weight_store._set_gptq_params(model_id, revision)

        net = FlashStarcoder2ForCausalLM(cfg, weight_store)

        self.cuda_graphs = {}

        torch.distributed.barrier(group=self.process_group)

        backbone = net.model
        # Start the lookup after BaseFlashMistral so its __init__ is not run.
        super(BaseFlashMistral, self).__init__(
            model_id=model_id,
            model=net,
            tokenizer=tok,
            num_layers=len(backbone.layers),
            num_kv_heads=backbone.num_key_value_heads,
            head_size=backbone.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=cfg.sliding_window,
        )

View File

@ -162,83 +162,3 @@ class GalacticaCausalLMBatch(CausalLMBatch):
padding_right_offset=padding_right_offset,
max_tokens=max_tokens,
)
class GalacticaSharded(CausalLM):
    """Galactica model built on ``OPTForCausalLM`` with its own batch type."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights, then build ``OPTForCausalLM``.

        Args:
            model_id: Identifier passed to the loaders and to ``weight_files``.
            revision: Optional revision forwarded to the loaders.
            quantize: Stored on the config as ``config.quantize``.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Explicit dtype; when ``None``, ``float16`` on CUDA and
                ``float32`` on CPU.
            trust_remote_code: Forwarded to the tokenizer and config loaders.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # CUDA when available, otherwise CPU with a wider default dtype.
        on_cuda = torch.cuda.is_available()
        if on_cuda:
            device = torch.device(f"cuda:{rank}")
        else:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float16 if on_cuda else torch.float32

        tok = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        cfg = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            tp_parallel=True,
            trust_remote_code=trust_remote_code,
        )
        cfg.quantize = quantize
        # The tokenizer's pad token id is taken from the model config.
        tok.pad_token_id = cfg.pad_token_id
        cfg.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        shard_files = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        weight_store = Weights(
            shard_files, device=device, dtype=dtype, process_group=self.process_group
        )
        if cfg.quantize in ("gptq", "marlin"):
            weight_store._set_gptq_params(model_id, revision)

        net = OPTForCausalLM(cfg, weight_store)

        torch.distributed.barrier(group=self.process_group)

        # Start the lookup after CausalLM so its __init__ is not run.
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=net,
            tokenizer=tok,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class used with this model: ``GalacticaCausalLMBatch``."""
        return GalacticaCausalLMBatch

    def decode(self, generated_ids: List[int]) -> str:
        """Decode ``generated_ids`` while keeping special tokens in the text."""
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        text = self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )
        return text

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        """Run the model and return ``(logits, speculative_logits, past_key_values)``.

        ``position_ids`` is accepted but not passed on to the model.
        """
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        logits = outputs.logits
        return logits, speculative_logits, outputs.past_key_values

View File

@ -1,89 +0,0 @@
import torch
import torch.distributed
from typing import Optional
from transformers import (
AutoTokenizer,
AutoConfig,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.neox_modeling import (
GPTNeoxForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class GPTNeoxSharded(CausalLM):
    """GPT-NeoX model built on ``GPTNeoxForCausalLM``."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights, then build ``GPTNeoxForCausalLM``.

        Args:
            model_id: Identifier passed to the loaders and to ``weight_files``.
            revision: Optional revision forwarded to the loaders.
            quantize: Stored on the config as ``config.quantize``.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Explicit dtype; when ``None``, ``float16`` on CUDA and
                ``float32`` on CPU.
            trust_remote_code: Forwarded to the tokenizer and config loaders.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        # CUDA when available, otherwise CPU with a wider default dtype.
        on_cuda = torch.cuda.is_available()
        if on_cuda:
            device = torch.device(f"cuda:{rank}")
        else:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float16 if on_cuda else torch.float32

        tok = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        # The end-of-sequence token doubles as the padding token.
        tok.pad_token = tok.eos_token

        cfg = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        cfg.quantize = quantize
        cfg.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        shard_files = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        weight_store = Weights(
            shard_files, device=device, dtype=dtype, process_group=self.process_group
        )
        if cfg.quantize in ("gptq", "marlin"):
            weight_store._set_gptq_params(model_id, revision)

        net = GPTNeoxForCausalLM(cfg, weight_store)

        torch.distributed.barrier(group=self.process_group)

        # Start the lookup after CausalLM so its __init__ is not run.
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=net,
            tokenizer=tok,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        """Run the model and return ``(logits, speculative_logits, past_key_values)``."""
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        logits = outputs.logits
        return logits, speculative_logits, outputs.past_key_values

View File

@ -1,51 +0,0 @@
import torch
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class Idefics2(VlmCausalLM):
    """Idefics2 vision-language model built on ``VlmCausalLM``."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Create the processor, then delegate loading to ``VlmCausalLM.__init__``."""
        # XXX: Extremely important to cap resolution in order to limit
        # VRAM usage.
        capped_size = {"longest_edge": 448, "shortest_edge": 378}
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            size=capped_size,
        )
        super().__init__(
            model_cls=Idefics2ForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        """Return ``(num_layers, num_kv_heads, head_size)`` of ``model.text_model.model``."""
        text_backbone = model.text_model.model
        layer_count = len(text_backbone.layers)
        kv_heads = text_backbone.num_key_value_heads
        return (layer_count, kv_heads, text_backbone.head_size)

    def max_past(self) -> Optional[int]:
        """Return the text model's ``max_past`` attribute, or ``None`` if absent."""
        text_model = self.model.text_model
        return getattr(text_model, "max_past", None)

View File

@ -1,46 +0,0 @@
import torch
from typing import Optional, Tuple
from transformers import (
AutoProcessor,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
class LlavaNext(VlmCausalLM):
    """``VlmCausalLM`` subclass wired to ``LlavaNextForConditionalGeneration``.

    The constructor loads an ``AutoProcessor`` and delegates the rest of the
    setup to the parent constructor.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Create the processor, then initialise the parent with the LlavaNext model class.

        Args:
            model_id: Identifier passed to ``AutoProcessor.from_pretrained`` and the parent.
            revision: Optional revision passed to both.
            quantize: Forwarded unchanged to the parent constructor.
            speculator: Forwarded unchanged to the parent constructor.
            dtype: Forwarded unchanged to the parent constructor.
            trust_remote_code: Passed to both the processor loader and the parent.
        """
        processor = AutoProcessor.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        self.processor = processor

        parent_kwargs = dict(
            model_cls=LlavaNextForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
        super().__init__(**parent_kwargs)

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        """Return ``(layer count, num_key_value_heads, head_size)`` read from ``model.language_model.model``."""
        decoder = model.language_model.model
        layer_count = len(decoder.layers)
        return (layer_count, decoder.num_key_value_heads, decoder.head_size)

    def max_past(self) -> Optional[int]:
        """Return ``self.model.language_model.max_past``, or ``None`` when the attribute is absent."""
        language_model = self.model.language_model
        return getattr(language_model, "max_past", None)

View File

@ -63,7 +63,7 @@ class Model(ABC):
self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
LayerAdapterWeights
)
self.target_to_layer = self.adapter_target_to_layer()
self.target_to_layer = None
self.loaded_adapters = set()
self.static_adapter_id = adapter_id
@ -206,6 +206,8 @@ class Model(ABC):
into model. Otherwise, the adapter weights are applied during the forward
pass and stored separately from the base model parameters.
"""
if self.target_to_layer is None:
self.target_to_layer = self.adapter_target_to_layer()
if adapter_index in self.loaded_adapters:
# Adapter already loaded
return

View File

@ -1,105 +0,0 @@
import torch
import torch.distributed
from pathlib import Path
from typing import Optional, Type
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
from huggingface_hub import hf_hub_download
import json
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
# Module-level OpenTelemetry tracer, requested under this module's name.
tracer = trace.get_tracer(__name__)
class MPTCausalLMBatch(CausalLMBatch):
    """``CausalLMBatch`` variant whose instances have ``keys_head_dim_last`` set to ``False``."""

    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        """Build the batch via the parent ``from_pb`` and clear ``keys_head_dim_last``.

        All arguments are forwarded to the parent implementation by keyword.
        """
        parent_kwargs = {
            "pb": pb,
            "tokenizer": tokenizer,
            "dtype": dtype,
            "device": device,
        }
        mpt_batch = super().from_pb(**parent_kwargs)
        mpt_batch.keys_head_dim_last = False
        return mpt_batch
class MPTSharded(CausalLM):
    """Loader that builds an ``MPTForCausalLM`` from ``.safetensors`` weight files.

    The constructor deliberately calls ``super(CausalLM, self).__init__`` so that
    ``CausalLM.__init__`` itself is skipped and the next class in the MRO receives
    the already-built model and tokenizer.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights, then register the model.

        Args:
            model_id: Hub identifier or local directory containing ``config.json``.
            revision: Optional revision used for every download.
            quantize: Stored on the config; ``"gptq"`` and ``"marlin"`` additionally
                trigger ``weights._set_gptq_params``.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Explicit dtype; when ``None`` it defaults to ``float16`` if CUDA
                is available and ``float32`` otherwise.
            trust_remote_code: Passed to the tokenizer loader.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.pad_token = tokenizer.eos_token

        # If model_id is a local path, load the file directly
        local_path = Path(model_id, "config.json")
        if local_path.exists():
            filename = str(local_path.resolve())
        else:
            filename = hf_hub_download(
                model_id, revision=revision, filename="config.json"
            )
        # JSON is UTF-8 by specification; do not rely on the locale's default encoding.
        with open(filename, "r", encoding="utf-8") as f:
            config = json.load(f)
        config = PretrainedConfig(**config)
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        # The original re-assigned ``config.quantize = quantize`` here; nothing in
        # between touches ``config``, so the duplicate assignment was dropped.
        model = MPTForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        """Batch class used with this model: ``MPTCausalLMBatch``."""
        return MPTCausalLMBatch

View File

@ -1,86 +0,0 @@
import torch
import torch.distributed
from typing import Optional
from transformers import (
AutoTokenizer,
AutoConfig,
)
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models import CausalLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class OPTSharded(CausalLM):
    """Loader that builds an ``OPTForCausalLM`` from ``.safetensors`` weight files.

    ``super(CausalLM, self).__init__`` is used on purpose: ``CausalLM.__init__`` is
    skipped and the next class in the MRO receives the prepared model and tokenizer.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights, then register the model.

        Args:
            model_id: Identifier used for the tokenizer, config and weight files.
            revision: Optional revision used for every lookup.
            quantize: Stored on the config; ``"gptq"`` and ``"marlin"`` additionally
                trigger ``weights._set_gptq_params``.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Explicit dtype; when ``None`` it defaults to ``float16`` if CUDA
                is available and ``float32`` otherwise.
            trust_remote_code: Passed to the tokenizer and config loaders.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        if torch.cuda.is_available():
            device, implicit_dtype = torch.device(f"cuda:{rank}"), torch.float16
        else:
            device, implicit_dtype = torch.device("cpu"), torch.float32
        if dtype is None:
            dtype = implicit_dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator
        # The tokenizer's pad id is taken from the model config.
        tokenizer.pad_token_id = config.pad_token_id

        torch.distributed.barrier(group=self.process_group)

        shard_files = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        weights = Weights(
            shard_files, device=device, dtype=dtype, process_group=self.process_group
        )
        if config.quantize in ["gptq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        model = OPTForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        """Run the wrapped model with caching enabled.

        ``position_ids`` is accepted but is not passed on to ``self.model.forward``.

        Returns:
            The tuple ``(outputs.logits, speculative_logits, outputs.past_key_values)``.
        """
        model_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        outputs, speculative_logits = self.model.forward(**model_kwargs)
        return outputs.logits, speculative_logits, outputs.past_key_values

View File

@ -74,45 +74,3 @@ class PaliGemmaBatch(VlmCausalLMBatch):
else:
image_inputs = None
return batch_tokenized_inputs, image_inputs
class PaliGemma(VlmCausalLM):
    """``VlmCausalLM`` subclass wired to ``PaliGemmaForConditionalGeneration``.

    It uses ``PaliGemmaBatch`` as its batch type and ``AutoConfig`` as the
    config class handed to the parent constructor.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Create the processor, then initialise the parent with the PaliGemma classes.

        Args:
            model_id: Identifier passed to ``AutoProcessor.from_pretrained`` and the parent.
            revision: Optional revision passed to both.
            quantize: Forwarded unchanged to the parent constructor.
            speculator: Forwarded unchanged to the parent constructor.
            dtype: Forwarded unchanged to the parent constructor.
            trust_remote_code: Passed to both the processor loader and the parent.
        """
        processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        self.processor = processor

        parent_kwargs = dict(
            config_cls=AutoConfig,
            model_cls=PaliGemmaForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
        super().__init__(**parent_kwargs)

    @property
    def batch_type(self):
        """Batch class used with this model: ``PaliGemmaBatch``."""
        return PaliGemmaBatch

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        """Return ``(layer count, num_key_value_heads, head_size)`` read from ``model.text_model.model``."""
        decoder = model.text_model.model
        layer_count = len(decoder.layers)
        return (layer_count, decoder.num_key_value_heads, decoder.head_size)

    def max_past(self) -> Optional[int]:
        """Return ``self.model.text_model.max_past``, or ``None`` when the attribute is absent."""
        text_model = self.model.text_model
        return getattr(text_model, "max_past", None)

View File

@ -1,69 +0,0 @@
import torch
import torch.distributed
from transformers import AutoConfig, AutoTokenizer
from typing import Optional, List, Tuple
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.phi_modeling import (
PhiConfig,
PhiForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class Phi(CausalLM):
    """Loader that builds a ``PhiForCausalLM`` from ``.safetensors`` weight files.

    ``super(CausalLM, self).__init__`` is used on purpose: ``CausalLM.__init__`` is
    skipped and the next class in the MRO receives the prepared model and tokenizer.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights, then register the model.

        Args:
            model_id: Identifier used for the tokenizer, config and weight files.
            revision: Optional revision used for every lookup.
            quantize: Stored on the config; rejected when CUDA is unavailable.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Explicit dtype; when ``None`` it defaults to ``float16`` if CUDA
                is available and ``float32`` otherwise.
            trust_remote_code: Passed to the tokenizer and config loaders.

        Raises:
            ValueError: If ``quantize`` is truthy and CUDA is not available.
        """
        # Rank and world size are returned but not used by this loader.
        self.process_group, _rank, _world_size = initialize_torch_distributed()

        if not torch.cuda.is_available():
            if quantize:
                raise ValueError("quantization is not available on CPU")
            device, implicit_dtype = torch.device("cpu"), torch.float32
        else:
            device, implicit_dtype = torch.device("cuda"), torch.float16
        if dtype is None:
            dtype = implicit_dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        config = PhiConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )

        # Align the tokenizer's special tokens with the config; EOS doubles as padding.
        tokenizer.bos_token_id = config.bos_token_id
        tokenizer.eos_token_id = config.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token

        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        shard_files = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        weights = Weights(shard_files, device, dtype, process_group=self.process_group)
        model = PhiForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

View File

@ -1,84 +0,0 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Optional, Tuple
from text_generation_server.models import CausalLM
class RW(CausalLM):
    """Loader that wraps a model obtained from ``AutoModelForCausalLM.from_pretrained``.

    ``super(CausalLM, self).__init__`` is used on purpose: ``CausalLM.__init__`` is
    skipped and the next class in the MRO receives the prepared model and tokenizer.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer and model, make sure a pad token exists, register the model.

        Args:
            model_id: Identifier used for the tokenizer and the model.
            revision: Optional revision used for both.
            quantize: ``"bitsandbytes"`` enables ``load_in_8bit``; any truthy value is
                rejected when CUDA is unavailable.
            speculator: Must be falsy; speculative decoding is not supported here.
            dtype: Explicit dtype; when ``None`` it defaults to ``float16`` if CUDA
                is available and ``float32`` otherwise.
            trust_remote_code: Passed to the tokenizer and model loaders.

        Raises:
            RuntimeError: If ``speculator`` is truthy.
            ValueError: If ``quantize`` is truthy and CUDA is not available.
        """
        if speculator:
            raise RuntimeError("Medusa decoding is not enabled for AutoModel")

        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            # Let the loader place the model only when several CUDA devices exist.
            device_map=(
                "auto"
                if torch.cuda.is_available() and torch.cuda.device_count() > 1
                else None
            ),
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        # With exactly one CUDA device the model is moved there explicitly.
        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
            model = model.cuda()

        # Pad token fallback chain: model config pad id -> model config EOS id ->
        # tokenizer EOS id -> a freshly added "[PAD]" token.
        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None:
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        """Run the wrapped model with caching enabled.

        ``position_ids`` is accepted but is not passed on to ``self.model.forward``.

        Returns:
            The tuple ``(outputs.logits, speculative_logits, outputs.past_key_values)``.
            ``speculative_logits`` is ``None`` unless the model returned a tuple.
        """
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        # The model loaded in ``__init__`` comes from ``AutoModelForCausalLM`` and a
        # speculator is rejected there, so the forward result is normally a single
        # output object rather than an ``(outputs, speculative_logits)`` pair.
        # Unpack only when a tuple is actually returned.
        if isinstance(outputs, tuple):
            outputs, speculative_logits = outputs
        else:
            speculative_logits = None
        return outputs.logits, speculative_logits, outputs.past_key_values

View File

@ -1,77 +0,0 @@
import torch
import torch.distributed
from typing import Optional, List
from transformers import AutoTokenizer, AutoModelForCausalLM
from text_generation_server.models import CausalLM
# Special-token strings that SantaCoder.__init__ registers with the tokenizer as
# "additional_special_tokens"; EOD is also installed as the pad token there.
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
EOD = "<|endoftext|>"
class SantaCoder(CausalLM):
    """Loader that wraps an ``AutoModelForCausalLM`` and registers the FIM special tokens.

    ``super(CausalLM, self).__init__`` is used on purpose: ``CausalLM.__init__`` is
    skipped and the next class in the MRO receives the prepared model and tokenizer.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer and model, add the special tokens, register the model.

        Args:
            model_id: Identifier used for the tokenizer and the model.
            revision: Optional revision used for both.
            quantize: ``"bitsandbytes"`` enables ``load_in_8bit``; any truthy value is
                rejected when CUDA is unavailable.
            speculator: Accepted but not used by this constructor.
            dtype: Explicit dtype; when ``None`` it defaults to ``float16`` if CUDA
                is available and ``float32`` otherwise.
            trust_remote_code: Passed to the tokenizer and model loaders.

        Raises:
            ValueError: If ``quantize`` is truthy and CUDA is not available.
        """
        if not torch.cuda.is_available():
            if quantize:
                raise ValueError("quantization is not available on CPU")
            device, implicit_dtype = torch.device("cpu"), torch.float32
        else:
            device, implicit_dtype = torch.device("cuda"), torch.float16
        if dtype is None:
            dtype = implicit_dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        special_tokens = {
            "additional_special_tokens": [
                EOD,
                FIM_PREFIX,
                FIM_MIDDLE,
                FIM_SUFFIX,
                FIM_PAD,
            ],
            "pad_token": EOD,
        }
        tokenizer.add_special_tokens(special_tokens)

        # The model is instantiated inside the device context manager.
        with device:
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                revision=revision,
                torch_dtype=dtype,
                load_in_8bit=quantize == "bitsandbytes",
                trust_remote_code=trust_remote_code,
            )

        super(CausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )

    def decode(self, generated_ids: List[int]) -> str:
        """Decode ``generated_ids`` to text, keeping special tokens in the output."""
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        decode_options = {
            "skip_special_tokens": False,
            "clean_up_tokenization_spaces": False,
        }
        return self.tokenizer.decode(generated_ids, **decode_options)

View File

@ -1,11 +1,22 @@
import torch
import torch.distributed
import time
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
PreTrainedTokenizerBase,
AutoConfig,
)
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
@ -531,6 +542,80 @@ class Seq2SeqLM(Model):
def __init__(
    self,
    model_id: str,
    model_class,
    revision: Optional[str] = None,
    quantize: Optional[str] = None,
    speculator: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
    default_dtype=torch.float16,
    trust_remote_code: bool = False,
    config_class=AutoConfig,
    tokenizer_class=AutoTokenizer,
    aliases=None,
):
    """Load config, tokenizer and weights, build ``model_class`` and register it.

    Args:
        model_id: Identifier used for the config, tokenizer and weight files.
        model_class: Callable invoked as ``model_class(config, weights)``.
        revision: Optional revision used for every lookup.
        quantize: Stored on the config; ``"awq"``, ``"exl2"``, ``"gptq"`` and
            ``"marlin"`` additionally trigger ``weights._set_gptq_params``.
        speculator: Stored on the config as ``config.speculator``.
        dtype: Explicit dtype; when ``None`` a device-dependent default is chosen.
        default_dtype: Default used on CUDA and XPU devices when ``dtype`` is ``None``.
        trust_remote_code: Passed to the config and tokenizer loaders.
        config_class: Class whose ``from_pretrained`` loads the config.
        tokenizer_class: Class whose ``from_pretrained`` loads the tokenizer.
        aliases: Passed through to ``Weights``.
    """
    self.process_group, rank, world_size = initialize_torch_distributed()

    # Pick the device and the dtype to use when the caller did not specify one.
    if torch.cuda.is_available():
        device, implicit_dtype = torch.device(f"cuda:{rank}"), default_dtype
    elif SYSTEM != "ipex":
        device, implicit_dtype = torch.device("cpu"), torch.float32
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        device, implicit_dtype = torch.device(f"xpu:{rank}"), default_dtype
    else:
        # Float16 doesn't exist on target.
        device, implicit_dtype = torch.device("cpu"), torch.bfloat16
    if dtype is None:
        dtype = implicit_dtype

    config = config_class.from_pretrained(
        model_id,
        revision=revision,
        trust_remote_code=trust_remote_code,
    )
    config.quantize = quantize
    config.speculator = speculator

    tokenizer = tokenizer_class.from_pretrained(
        model_id,
        revision=revision,
        padding_side="left",
        truncation_side="left",
        trust_remote_code=trust_remote_code,
    )
    # The tokenizer's BOS id is taken from the config's decoder start token.
    tokenizer.bos_token_id = config.decoder_start_token_id

    torch.distributed.barrier(group=self.process_group)

    shard_files = weight_files(model_id, revision=revision, extension=".safetensors")
    weights = Weights(
        shard_files,
        device=device,
        dtype=dtype,
        process_group=self.process_group,
        aliases=aliases,
    )
    if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
        weights._set_gptq_params(model_id, revision)

    model = model_class(config, weights)

    torch.distributed.barrier(group=self.process_group)
    super().__init__(
        model_id=model_id,
        model=model,
        tokenizer=tokenizer,
        requires_padding=True,
        dtype=dtype,
        device=device,
        rank=rank,
        world_size=world_size,
    )
@classmethod
def fallback(
cls,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
@ -574,7 +659,11 @@ class Seq2SeqLM(Model):
)
tokenizer.bos_token_id = model.config.decoder_start_token_id
super(Seq2SeqLM, self).__init__(
self = cls.__new__(
cls,
)
super().__init__(
self,
model_id=model_id,
model=model,
tokenizer=tokenizer,
@ -582,16 +671,12 @@ class Seq2SeqLM(Model):
dtype=dtype,
device=device,
)
return self
@property
def batch_type(self) -> Type[Seq2SeqLMBatch]:
return Seq2SeqLMBatch
def decode(self, decoder_ids: List[int]) -> str:
return self.tokenizer.decode(
decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
def forward(
self,
input_ids,

View File

@ -1,115 +0,0 @@
import torch
import torch.distributed
from typing import List, Optional, Tuple
from transformers import (
AutoTokenizer,
AutoConfig,
)
from text_generation_server.models import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
T5ForConditionalGeneration,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
class T5Sharded(Seq2SeqLM):
    """Loader that builds a ``T5ForConditionalGeneration`` from ``.safetensors`` weight files.

    ``super(Seq2SeqLM, self).__init__`` is used on purpose: ``Seq2SeqLM.__init__`` is
    skipped and the next class in the MRO receives the prepared model and tokenizer.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        """Load config, tokenizer and weights, then register the model.

        Args:
            model_id: Identifier used for the config, tokenizer and weight files.
            revision: Optional revision used for every lookup.
            quantize: Stored on the config as ``config.quantize``.
            speculator: Stored on the config as ``config.speculator``.
            dtype: Explicit dtype; when ``None`` it defaults to ``float16`` if CUDA
                is available and ``float32`` otherwise.
            trust_remote_code: Passed to the config and tokenizer loaders.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()

        if torch.cuda.is_available():
            device, implicit_dtype = torch.device(f"cuda:{rank}"), torch.float16
        else:
            device, implicit_dtype = torch.device("cpu"), torch.float32
        if dtype is None:
            dtype = implicit_dtype

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        # The tokenizer's BOS id is taken from the config's decoder start token.
        tokenizer.bos_token_id = config.decoder_start_token_id

        torch.distributed.barrier(group=self.process_group)

        shard_files = weight_files(
            model_id, revision=revision, extension=".safetensors"
        )
        # Alias table handed to ``Weights`` for the "shared.weight" tensor name.
        embedding_aliases = {
            "shared.weight": [
                "encoder.embed_tokens.weight",
                "decoder.embed_tokens.weight",
            ]
        }
        weights = Weights(
            shard_files,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            aliases=embedding_aliases,
        )

        model = T5ForConditionalGeneration(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(Seq2SeqLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional,
        encoder_last_hidden_state: Optional,
        past_key_values: Optional = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        """Run the wrapped model with caching enabled.

        ``encoder_last_hidden_state`` is handed to the model as ``encoder_outputs``.

        Returns:
            ``(logits, speculative_logits, encoder_last_hidden_state, past_key_values)``.
        """
        # NOTE(review): the return annotation lists three elements while four are
        # returned below; the annotation is left as-is to keep the interface unchanged.
        # Model Forward
        model_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )
        outputs, speculative_logits = self.model.forward(**model_kwargs)
        return (
            outputs.logits,
            speculative_logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

View File

@ -9,10 +9,11 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.models.flash_mistral import (
BaseFlashMistral,
from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
)
from transformers import AutoProcessor
tracer = trace.get_tracer(__name__)
@ -239,10 +240,35 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
return batch
class VlmCausalLM(BaseFlashMistral):
class VlmCausalLM(FlashCausalLM):
def __init__(
    self,
    model_id: str,
    *,
    processor_class=AutoProcessor,
    processor_kwargs=None,
    batch_class=VlmCausalLMBatch,
    revision,
    trust_remote_code: bool,
    **kwargs,
):
    """Create the processor, remember the batch class, then initialise the parent.

    Args:
        model_id: Identifier passed to the processor loader and the parent.
        processor_class: Class whose ``from_pretrained`` builds ``self.processor``.
        processor_kwargs: Extra keyword arguments for that call; ``None`` means none.
        batch_class: Stored on ``self.batch_class``.
        revision: Revision passed to the processor loader.
        trust_remote_code: Passed to the processor loader.
        **kwargs: Forwarded unchanged to the parent constructor.
    """
    extra_processor_args = {} if processor_kwargs is None else processor_kwargs
    self.processor = processor_class.from_pretrained(
        model_id,
        revision=revision,
        trust_remote_code=trust_remote_code,
        **extra_processor_args,
    )
    self.batch_class = batch_class
    # NOTE(review): ``revision`` and ``trust_remote_code`` are named keyword-only
    # parameters, so they are not part of ``kwargs`` and are not forwarded to the
    # parent constructor - confirm the parent does not need them.
    super().__init__(model_id=model_id, **kwargs)
@property
def batch_type(self) -> Type[VlmCausalLMBatch]:
return VlmCausalLMBatch
return self.batch_class
def max_past(self) -> Optional[int]:
    """Return ``self.model.text_model.max_past``, or ``None`` when the attribute is absent."""
    text_model = self.model.text_model
    return getattr(text_model, "max_past", None)
def forward(
self,

View File

@ -1,6 +1,8 @@
import subprocess
import argparse
import ast
import json
import os
TEMPLATE = """
# Supported Models and Hardware
@ -122,6 +124,53 @@ def check_supported_models(check: bool):
f.write(final_doc)
def get_openapi_schema():
    """Return the OpenAPI schema printed by ``text-generation-router print-schema``.

    Returns:
        The parsed JSON document emitted on the command's standard output.

    Raises:
        SystemExit: With exit code 1 when the executable cannot be found, when
            the command exits with a non-zero status, or when its output is not
            valid JSON. A short explanation is printed first.
    """
    command = ["text-generation-router", "print-schema"]
    try:
        output = subprocess.check_output(command)
    except FileNotFoundError:
        # Previously an uninstalled router surfaced as a raw traceback; report
        # it the same way as the other failure modes instead.
        print(
            "Error: text-generation-router executable not found, "
            "install it with `cargo install --path router/`"
        )
        raise SystemExit(1)
    except subprocess.CalledProcessError as e:
        print(f"Error running text-generation-router print-schema: {e}")
        raise SystemExit(1)

    try:
        return json.loads(output)
    except json.JSONDecodeError:
        print("Error: Invalid JSON received from text-generation-router print-schema")
        raise SystemExit(1)
def check_openapi(check: bool):
    """Regenerate the OpenAPI document and either verify or update ``docs/openapi.json``.

    The schema from ``get_openapi_schema()`` is written to ``openapi_tmp.json``
    in the current directory. In check mode that file is compared with
    ``docs/openapi.json`` using ``diff`` and then removed; otherwise it is
    renamed over ``docs/openapi.json``.

    Args:
        check: ``True`` to only verify the committed document, ``False`` to
            overwrite it with the freshly generated one.

    Returns:
        ``True`` when the document is up-to-date (check mode) or was updated.

    Raises:
        Exception: In check mode, when the documents differ or when ``diff``
            itself fails (for example because ``docs/openapi.json`` is missing).
    """
    new_openapi_data = get_openapi_schema()
    filename = "docs/openapi.json"
    tmp_filename = "openapi_tmp.json"

    with open(tmp_filename, "w") as f:
        json.dump(new_openapi_data, f, indent=2)

    if not check:
        os.rename(tmp_filename, filename)
        print("OpenAPI documentation updated.")
        return True

    try:
        completed = subprocess.run(
            [
                "diff",
                # allow for trailing whitespace since it's not significant
                # and the precommit hook will remove it
                "--ignore-trailing-space",
                tmp_filename,
                filename,
            ],
            capture_output=True,
        )
    finally:
        # Always clean up the temporary file, even if launching diff failed.
        os.remove(tmp_filename)

    # diff exits with 0 (identical), 1 (differences) or >1 (trouble). In the
    # trouble case stdout is empty, which used to be mistaken for "up-to-date".
    if completed.returncode > 1:
        raise Exception(
            f"Could not compare the OpenAPI documentation with {filename}: "
            f"{completed.stderr.decode().strip()}"
        )

    diff = completed.stdout.decode()
    if diff:
        print(diff)
        raise Exception(
            "OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it"
        )

    return True
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--check", action="store_true")
@ -130,6 +179,7 @@ def main():
check_cli(args.check)
check_supported_models(args.check)
check_openapi(args.check)
if __name__ == "__main__":