Merge branch 'main' into ci_amd3
This commit is contained in:
commit
8c590be463
|
@ -11,10 +11,30 @@ jobs:
|
|||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Rust
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
profile: minimal
|
||||
toolchain: stable
|
||||
|
||||
- name: Install Protocol Buffers compiler
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y protobuf-compiler libprotobuf-dev
|
||||
|
||||
- name: Install Launcher
|
||||
id: install-launcher
|
||||
run: cargo install --path launcher/
|
||||
- name: Check launcher Docs are up-to-date
|
||||
|
||||
- name: Install router
|
||||
id: install-router
|
||||
run: cargo install --path router/
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.x'
|
||||
|
||||
- name: Check that documentation is up-to-date
|
||||
run: |
|
||||
echo text-generation-launcher --help
|
||||
python update_doc.py --check
|
||||
|
|
|
@ -11,6 +11,11 @@ on:
|
|||
# - rocm
|
||||
# - xpu
|
||||
required: true
|
||||
release-tests:
|
||||
description: "Run release integration tests"
|
||||
required: true
|
||||
default: false
|
||||
type: boolean
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
|
@ -195,7 +200,7 @@ jobs:
|
|||
runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
|
||||
if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
|
||||
env:
|
||||
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main') && '--release' || '' }}
|
||||
PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
|
|
@ -20,7 +20,14 @@ on:
|
|||
- "Dockerfile_amd"
|
||||
- "Dockerfile_intel"
|
||||
branches:
|
||||
- 'main'
|
||||
- "main"
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
release-tests:
|
||||
description: "Run release integration tests"
|
||||
required: true
|
||||
default: false
|
||||
type: boolean
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
@ -33,4 +40,6 @@ jobs:
|
|||
uses: ./.github/workflows/build.yaml # calls the one above ^
|
||||
with:
|
||||
hardware: ${{ matrix.hardware }}
|
||||
# https://github.com/actions/runner/issues/2206
|
||||
release-tests: ${{ inputs.release-tests == true }}
|
||||
secrets: inherit
|
||||
|
|
|
@ -9,7 +9,7 @@ members = [
|
|||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "2.1.1-dev0"
|
||||
version = "2.1.2-dev0"
|
||||
edition = "2021"
|
||||
authors = ["Olivier Dehaene"]
|
||||
homepage = "https://github.com/huggingface/text-generation-inference"
|
||||
|
|
32
Dockerfile
32
Dockerfile
|
@ -4,7 +4,7 @@ WORKDIR /usr/src
|
|||
|
||||
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
|
||||
|
||||
FROM chef as planner
|
||||
FROM chef AS planner
|
||||
COPY Cargo.lock Cargo.lock
|
||||
COPY Cargo.toml Cargo.toml
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
|
@ -38,7 +38,7 @@ RUN cargo build --profile release-opt
|
|||
|
||||
# Python builder
|
||||
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
|
||||
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 as pytorch-install
|
||||
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install
|
||||
|
||||
ARG PYTORCH_VERSION=2.3.0
|
||||
ARG PYTHON_VERSION=3.10
|
||||
|
@ -81,7 +81,7 @@ RUN case ${TARGETPLATFORM} in \
|
|||
/opt/conda/bin/conda clean -ya
|
||||
|
||||
# CUDA kernels builder image
|
||||
FROM pytorch-install as kernel-builder
|
||||
FROM pytorch-install AS kernel-builder
|
||||
|
||||
ARG MAX_JOBS=8
|
||||
|
||||
|
@ -90,7 +90,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
|
|||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Build Flash Attention CUDA kernels
|
||||
FROM kernel-builder as flash-att-builder
|
||||
FROM kernel-builder AS flash-att-builder
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
|
@ -100,7 +100,7 @@ COPY server/Makefile-flash-att Makefile
|
|||
RUN make build-flash-attention
|
||||
|
||||
# Build Flash Attention v2 CUDA kernels
|
||||
FROM kernel-builder as flash-att-v2-builder
|
||||
FROM kernel-builder AS flash-att-v2-builder
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
|
@ -110,14 +110,14 @@ COPY server/Makefile-flash-att-v2 Makefile
|
|||
RUN make build-flash-attention-v2-cuda
|
||||
|
||||
# Build Transformers exllama kernels
|
||||
FROM kernel-builder as exllama-kernels-builder
|
||||
FROM kernel-builder AS exllama-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/exllama_kernels/ .
|
||||
|
||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
|
||||
|
||||
# Build Transformers exllama kernels
|
||||
FROM kernel-builder as exllamav2-kernels-builder
|
||||
FROM kernel-builder AS exllamav2-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/exllamav2_kernels/ .
|
||||
|
||||
|
@ -125,42 +125,42 @@ COPY server/exllamav2_kernels/ .
|
|||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
|
||||
|
||||
# Build Transformers awq kernels
|
||||
FROM kernel-builder as awq-kernels-builder
|
||||
FROM kernel-builder AS awq-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/Makefile-awq Makefile
|
||||
# Build specific version of transformers
|
||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq
|
||||
|
||||
# Build eetq kernels
|
||||
FROM kernel-builder as eetq-kernels-builder
|
||||
FROM kernel-builder AS eetq-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/Makefile-eetq Makefile
|
||||
# Build specific version of transformers
|
||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
|
||||
|
||||
# Build marlin kernels
|
||||
FROM kernel-builder as marlin-kernels-builder
|
||||
FROM kernel-builder AS marlin-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/marlin/ .
|
||||
# Build specific version of transformers
|
||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
|
||||
|
||||
# Build Lorax Punica kernels
|
||||
FROM kernel-builder as lorax-punica-builder
|
||||
FROM kernel-builder AS lorax-punica-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/Makefile-lorax-punica Makefile
|
||||
# Build specific version of transformers
|
||||
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-lorax-punica
|
||||
|
||||
# Build Transformers CUDA kernels
|
||||
FROM kernel-builder as custom-kernels-builder
|
||||
FROM kernel-builder AS custom-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/custom_kernels/ .
|
||||
# Build specific version of transformers
|
||||
RUN python setup.py build
|
||||
|
||||
# Build vllm CUDA kernels
|
||||
FROM kernel-builder as vllm-builder
|
||||
FROM kernel-builder AS vllm-builder
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
|
@ -172,13 +172,13 @@ COPY server/Makefile-vllm Makefile
|
|||
RUN make build-vllm-cuda
|
||||
|
||||
# Build mamba kernels
|
||||
FROM kernel-builder as mamba-builder
|
||||
FROM kernel-builder AS mamba-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/Makefile-selective-scan Makefile
|
||||
RUN make build-all
|
||||
|
||||
# Text Generation Inference base image
|
||||
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as base
|
||||
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base
|
||||
|
||||
# Conda env
|
||||
ENV PATH=/opt/conda/bin:$PATH \
|
||||
|
@ -260,7 +260,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
|
|||
|
||||
|
||||
# AWS Sagemaker compatible image
|
||||
FROM base as sagemaker
|
||||
FROM base AS sagemaker
|
||||
|
||||
COPY sagemaker-entrypoint.sh entrypoint.sh
|
||||
RUN chmod +x entrypoint.sh
|
||||
|
|
|
@ -4,7 +4,7 @@ WORKDIR /usr/src
|
|||
|
||||
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
|
||||
|
||||
FROM chef as planner
|
||||
FROM chef AS planner
|
||||
COPY Cargo.lock Cargo.lock
|
||||
COPY Cargo.toml Cargo.toml
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
|
@ -37,7 +37,7 @@ COPY launcher launcher
|
|||
RUN cargo build --profile release-opt
|
||||
|
||||
# Text Generation Inference base image for RoCm
|
||||
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base
|
||||
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
|
@ -118,7 +118,7 @@ ARG BUILD_CAFFE2="0" \
|
|||
|
||||
RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
|
||||
|
||||
# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
|
||||
# Set AS recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
|
||||
ENV HIP_FORCE_DEV_KERNARG=1
|
||||
|
||||
# On MI250 and MI300, performances for flash with Triton FA are slightly better than CK.
|
||||
|
@ -150,26 +150,26 @@ COPY server/Makefile-flash-att-v2 Makefile
|
|||
RUN make build-flash-attention-v2-rocm
|
||||
|
||||
# Build Transformers CUDA kernels (gpt-neox and bloom)
|
||||
FROM kernel-builder as custom-kernels-builder
|
||||
FROM kernel-builder AS custom-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/custom_kernels/ .
|
||||
RUN python setup.py build
|
||||
|
||||
# Build exllama kernels
|
||||
FROM kernel-builder as exllama-kernels-builder
|
||||
FROM kernel-builder AS exllama-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/exllama_kernels/ .
|
||||
|
||||
RUN python setup.py build
|
||||
|
||||
# Build exllama v2 kernels
|
||||
FROM kernel-builder as exllamav2-kernels-builder
|
||||
FROM kernel-builder AS exllamav2-kernels-builder
|
||||
WORKDIR /usr/src
|
||||
COPY server/exllamav2_kernels/ .
|
||||
|
||||
RUN python setup.py build
|
||||
|
||||
FROM base as base-copy
|
||||
FROM base AS base-copy
|
||||
|
||||
# Text Generation Inference base env
|
||||
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||
|
@ -208,7 +208,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
|
|||
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
|
||||
# AWS Sagemaker compatible image
|
||||
FROM base as sagemaker
|
||||
FROM base AS sagemaker
|
||||
|
||||
COPY sagemaker-entrypoint.sh entrypoint.sh
|
||||
RUN chmod +x entrypoint.sh
|
||||
|
|
|
@ -5,7 +5,7 @@ WORKDIR /usr/src
|
|||
|
||||
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
|
||||
|
||||
FROM chef as planner
|
||||
FROM chef AS planner
|
||||
COPY Cargo.lock Cargo.lock
|
||||
COPY Cargo.toml Cargo.toml
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
|
@ -40,7 +40,7 @@ RUN cargo build --profile release-opt
|
|||
|
||||
# Text Generation Inference base image for Intel
|
||||
|
||||
FROM intel/intel-extension-for-pytorch:2.1.30-xpu as xpu
|
||||
FROM intel/intel-extension-for-pytorch:2.1.30-xpu AS xpu
|
||||
|
||||
USER root
|
||||
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
|
||||
|
@ -95,7 +95,7 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/lo
|
|||
|
||||
|
||||
# Text Generation Inference base image for Intel-cpu
|
||||
FROM ubuntu:22.04 as cpu
|
||||
FROM ubuntu:22.04 AS cpu
|
||||
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
|
@ -172,6 +172,6 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
|
|||
# Install launcher
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||
|
||||
FROM ${PLATFORM} as final
|
||||
FROM ${PLATFORM} AS final
|
||||
ENTRYPOINT ["text-generation-launcher"]
|
||||
CMD ["--json-output"]
|
||||
|
|
|
@ -79,7 +79,7 @@ model=HuggingFaceH4/zephyr-7b-beta
|
|||
volume=$PWD/data
|
||||
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:2.1.0 --model-id $model
|
||||
ghcr.io/huggingface/text-generation-inference:2.1.1 --model-id $model
|
||||
```
|
||||
|
||||
And then you can make requests like
|
||||
|
@ -93,7 +93,7 @@ curl 127.0.0.1:8080/generate_stream \
|
|||
|
||||
**Note:** To use NVIDIA GPUs, you need to install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). We also recommend using NVIDIA drivers with CUDA version 12.2 or higher. For running the Docker container on a machine with no GPUs or CUDA support, it is enough to remove the `--gpus all` flag and add `--disable-custom-kernels`, please note CPU is not the intended platform for this project, so performance might be subpar.
|
||||
|
||||
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.1.0-rocm --model-id $model` instead of the command above.
|
||||
**Note:** TGI supports AMD Instinct MI210 and MI250 GPUs. Details can be found in the [Supported Hardware documentation](https://huggingface.co/docs/text-generation-inference/supported_models#supported-hardware). To use AMD GPUs, please use `docker run --device /dev/kfd --device /dev/dri --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.1.1-rocm --model-id $model` instead of the command above.
|
||||
|
||||
To see all options to serve your models (in the [code](https://github.com/huggingface/text-generation-inference/blob/main/launcher/src/main.rs) or in the cli):
|
||||
```
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
"name": "Apache 2.0",
|
||||
"url": "https://www.apache.org/licenses/LICENSE-2.0"
|
||||
},
|
||||
"version": "2.0.1"
|
||||
"version": "2.1.2-dev0"
|
||||
},
|
||||
"paths": {
|
||||
"/": {
|
||||
|
@ -19,7 +19,6 @@
|
|||
"Text Generation Inference"
|
||||
],
|
||||
"summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
|
||||
"description": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
|
||||
"operationId": "compat_generate",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
|
@ -108,7 +107,6 @@
|
|||
"Text Generation Inference"
|
||||
],
|
||||
"summary": "Generate tokens",
|
||||
"description": "Generate tokens",
|
||||
"operationId": "generate",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
|
@ -192,7 +190,6 @@
|
|||
"Text Generation Inference"
|
||||
],
|
||||
"summary": "Generate a stream of token using Server-Sent Events",
|
||||
"description": "Generate a stream of token using Server-Sent Events",
|
||||
"operationId": "generate_stream",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
|
@ -276,7 +273,6 @@
|
|||
"Text Generation Inference"
|
||||
],
|
||||
"summary": "Health check method",
|
||||
"description": "Health check method",
|
||||
"operationId": "health",
|
||||
"responses": {
|
||||
"200": {
|
||||
|
@ -305,7 +301,6 @@
|
|||
"Text Generation Inference"
|
||||
],
|
||||
"summary": "Text Generation Inference endpoint info",
|
||||
"description": "Text Generation Inference endpoint info",
|
||||
"operationId": "get_model_info",
|
||||
"responses": {
|
||||
"200": {
|
||||
|
@ -327,7 +322,6 @@
|
|||
"Text Generation Inference"
|
||||
],
|
||||
"summary": "Prometheus metrics scrape endpoint",
|
||||
"description": "Prometheus metrics scrape endpoint",
|
||||
"operationId": "metrics",
|
||||
"responses": {
|
||||
"200": {
|
||||
|
@ -349,7 +343,6 @@
|
|||
"Text Generation Inference"
|
||||
],
|
||||
"summary": "Tokenize inputs",
|
||||
"description": "Tokenize inputs",
|
||||
"operationId": "tokenize",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
|
@ -394,7 +387,6 @@
|
|||
"Text Generation Inference"
|
||||
],
|
||||
"summary": "Generate tokens",
|
||||
"description": "Generate tokens",
|
||||
"operationId": "chat_completions",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
|
@ -483,7 +475,6 @@
|
|||
"Text Generation Inference"
|
||||
],
|
||||
"summary": "Generate tokens",
|
||||
"description": "Generate tokens",
|
||||
"operationId": "completions",
|
||||
"requestBody": {
|
||||
"content": {
|
||||
|
@ -626,7 +617,6 @@
|
|||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"object",
|
||||
"created",
|
||||
"model",
|
||||
"system_fingerprint",
|
||||
|
@ -653,9 +643,6 @@
|
|||
"type": "string",
|
||||
"example": "mistralai/Mistral-7B-Instruct-v0.2"
|
||||
},
|
||||
"object": {
|
||||
"type": "string"
|
||||
},
|
||||
"system_fingerprint": {
|
||||
"type": "string"
|
||||
},
|
||||
|
@ -697,7 +684,6 @@
|
|||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"object",
|
||||
"created",
|
||||
"model",
|
||||
"system_fingerprint",
|
||||
|
@ -723,9 +709,6 @@
|
|||
"type": "string",
|
||||
"example": "mistralai/Mistral-7B-Instruct-v0.2"
|
||||
},
|
||||
"object": {
|
||||
"type": "string"
|
||||
},
|
||||
"system_fingerprint": {
|
||||
"type": "string"
|
||||
}
|
||||
|
@ -756,34 +739,19 @@
|
|||
"nullable": true
|
||||
},
|
||||
"message": {
|
||||
"$ref": "#/components/schemas/Message"
|
||||
"$ref": "#/components/schemas/OutputMessage"
|
||||
}
|
||||
}
|
||||
},
|
||||
"ChatCompletionDelta": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"role"
|
||||
],
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"example": "What is Deep Learning?",
|
||||
"nullable": true
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/TextMessage"
|
||||
},
|
||||
"role": {
|
||||
"type": "string",
|
||||
"example": "user"
|
||||
},
|
||||
"tool_calls": {
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/DeltaToolCall"
|
||||
}
|
||||
],
|
||||
"nullable": true
|
||||
{
|
||||
"$ref": "#/components/schemas/ToolCallDelta"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"ChatCompletionLogprob": {
|
||||
"type": "object",
|
||||
|
@ -903,6 +871,15 @@
|
|||
"example": 0.1,
|
||||
"nullable": true
|
||||
},
|
||||
"response_format": {
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/GrammarType"
|
||||
}
|
||||
],
|
||||
"default": "null",
|
||||
"nullable": true
|
||||
},
|
||||
"seed": {
|
||||
"type": "integer",
|
||||
"format": "int64",
|
||||
|
@ -969,6 +946,38 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"Chunk": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"created",
|
||||
"choices",
|
||||
"model",
|
||||
"system_fingerprint"
|
||||
],
|
||||
"properties": {
|
||||
"choices": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/CompletionComplete"
|
||||
}
|
||||
},
|
||||
"created": {
|
||||
"type": "integer",
|
||||
"format": "int64",
|
||||
"minimum": 0
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"model": {
|
||||
"type": "string"
|
||||
},
|
||||
"system_fingerprint": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"CompatGenerateRequest": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
|
@ -988,6 +997,55 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"Completion": {
|
||||
"oneOf": [
|
||||
{
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/Chunk"
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"object"
|
||||
],
|
||||
"properties": {
|
||||
"object": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"text_completion"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"allOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/CompletionFinal"
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"object"
|
||||
],
|
||||
"properties": {
|
||||
"object": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"text_completion"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "object"
|
||||
}
|
||||
},
|
||||
"CompletionComplete": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
|
@ -1017,15 +1075,15 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"CompletionCompleteChunk": {
|
||||
"CompletionFinal": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"object",
|
||||
"created",
|
||||
"choices",
|
||||
"model",
|
||||
"system_fingerprint"
|
||||
"system_fingerprint",
|
||||
"choices",
|
||||
"usage"
|
||||
],
|
||||
"properties": {
|
||||
"choices": {
|
||||
|
@ -1037,19 +1095,21 @@
|
|||
"created": {
|
||||
"type": "integer",
|
||||
"format": "int64",
|
||||
"example": "1706270835",
|
||||
"minimum": 0
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"model": {
|
||||
"type": "string"
|
||||
},
|
||||
"object": {
|
||||
"type": "string"
|
||||
"type": "string",
|
||||
"example": "mistralai/Mistral-7B-Instruct-v0.2"
|
||||
},
|
||||
"system_fingerprint": {
|
||||
"type": "string"
|
||||
},
|
||||
"usage": {
|
||||
"$ref": "#/components/schemas/Usage"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -1081,12 +1141,7 @@
|
|||
"example": "mistralai/Mistral-7B-Instruct-v0.2"
|
||||
},
|
||||
"prompt": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "The prompt to generate completions for.",
|
||||
"example": "What is Deep Learning?"
|
||||
"$ref": "#/components/schemas/Prompt"
|
||||
},
|
||||
"repetition_penalty": {
|
||||
"type": "number",
|
||||
|
@ -1100,6 +1155,15 @@
|
|||
"nullable": true,
|
||||
"minimum": 0
|
||||
},
|
||||
"stop": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Up to 4 sequences where the API will stop generating further tokens.",
|
||||
"example": "null",
|
||||
"nullable": true
|
||||
},
|
||||
"stream": {
|
||||
"type": "boolean"
|
||||
},
|
||||
|
@ -1121,15 +1185,6 @@
|
|||
"description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
|
||||
"example": 0.95,
|
||||
"nullable": true
|
||||
},
|
||||
"stop": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Up to 4 sequences where the API will stop generating further tokens.",
|
||||
"example": "null",
|
||||
"nullable": true
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -1272,8 +1327,16 @@
|
|||
"GenerateParameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"adapter_id": {
|
||||
"type": "string",
|
||||
"description": "Lora adapter id",
|
||||
"default": "null",
|
||||
"example": "null",
|
||||
"nullable": true
|
||||
},
|
||||
"best_of": {
|
||||
"type": "integer",
|
||||
"description": "Generate best_of sequences and return the one if the highest token logprobs.",
|
||||
"default": "null",
|
||||
"example": 1,
|
||||
"nullable": true,
|
||||
|
@ -1282,20 +1345,24 @@
|
|||
},
|
||||
"decoder_input_details": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to return decoder input token logprobs and ids.",
|
||||
"default": "false"
|
||||
},
|
||||
"details": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to return generation details.",
|
||||
"default": "true"
|
||||
},
|
||||
"do_sample": {
|
||||
"type": "boolean",
|
||||
"description": "Activate logits sampling.",
|
||||
"default": "false",
|
||||
"example": true
|
||||
},
|
||||
"frequency_penalty": {
|
||||
"type": "number",
|
||||
"format": "float",
|
||||
"description": "The parameter for frequency penalty. 1.0 means no penalty\nPenalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.",
|
||||
"default": "null",
|
||||
"example": 0.1,
|
||||
"nullable": true,
|
||||
|
@ -1313,6 +1380,7 @@
|
|||
"max_new_tokens": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"description": "Maximum number of tokens to generate.",
|
||||
"default": "100",
|
||||
"example": "20",
|
||||
"nullable": true,
|
||||
|
@ -1321,6 +1389,7 @@
|
|||
"repetition_penalty": {
|
||||
"type": "number",
|
||||
"format": "float",
|
||||
"description": "The parameter for repetition penalty. 1.0 means no penalty.\nSee [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.",
|
||||
"default": "null",
|
||||
"example": 1.03,
|
||||
"nullable": true,
|
||||
|
@ -1328,6 +1397,7 @@
|
|||
},
|
||||
"return_full_text": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to prepend the prompt to the generated text",
|
||||
"default": "null",
|
||||
"example": false,
|
||||
"nullable": true
|
||||
|
@ -1335,6 +1405,7 @@
|
|||
"seed": {
|
||||
"type": "integer",
|
||||
"format": "int64",
|
||||
"description": "Random sampling seed.",
|
||||
"default": "null",
|
||||
"example": "null",
|
||||
"nullable": true,
|
||||
|
@ -1346,6 +1417,7 @@
|
|||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "Stop generating tokens if a member of `stop` is generated.",
|
||||
"example": [
|
||||
"photographer"
|
||||
],
|
||||
|
@ -1354,6 +1426,7 @@
|
|||
"temperature": {
|
||||
"type": "number",
|
||||
"format": "float",
|
||||
"description": "The value used to module the logits distribution.",
|
||||
"default": "null",
|
||||
"example": 0.5,
|
||||
"nullable": true,
|
||||
|
@ -1362,6 +1435,7 @@
|
|||
"top_k": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"description": "The number of highest probability vocabulary tokens to keep for top-k-filtering.",
|
||||
"default": "null",
|
||||
"example": 10,
|
||||
"nullable": true,
|
||||
|
@ -1370,6 +1444,7 @@
|
|||
"top_n_tokens": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"description": "The number of highest probability vocabulary tokens to keep for top-n-filtering.",
|
||||
"default": "null",
|
||||
"example": 5,
|
||||
"nullable": true,
|
||||
|
@ -1379,6 +1454,7 @@
|
|||
"top_p": {
|
||||
"type": "number",
|
||||
"format": "float",
|
||||
"description": "Top-p value for nucleus sampling.",
|
||||
"default": "null",
|
||||
"example": 0.95,
|
||||
"nullable": true,
|
||||
|
@ -1387,6 +1463,7 @@
|
|||
},
|
||||
"truncate": {
|
||||
"type": "integer",
|
||||
"description": "Truncate inputs tokens to the given size.",
|
||||
"default": "null",
|
||||
"example": "null",
|
||||
"nullable": true,
|
||||
|
@ -1395,6 +1472,7 @@
|
|||
"typical_p": {
|
||||
"type": "number",
|
||||
"format": "float",
|
||||
"description": "Typical Decoding mass\nSee [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.",
|
||||
"default": "null",
|
||||
"example": 0.95,
|
||||
"nullable": true,
|
||||
|
@ -1403,6 +1481,7 @@
|
|||
},
|
||||
"watermark": {
|
||||
"type": "boolean",
|
||||
"description": "Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).",
|
||||
"default": "false",
|
||||
"example": true
|
||||
}
|
||||
|
@ -1495,13 +1574,14 @@
|
|||
"max_concurrent_requests",
|
||||
"max_best_of",
|
||||
"max_stop_sequences",
|
||||
"max_input_length",
|
||||
"max_input_tokens",
|
||||
"max_total_tokens",
|
||||
"waiting_served_ratio",
|
||||
"max_batch_total_tokens",
|
||||
"max_waiting_tokens",
|
||||
"validation_workers",
|
||||
"max_client_batch_size",
|
||||
"router",
|
||||
"version"
|
||||
],
|
||||
"properties": {
|
||||
|
@ -1538,7 +1618,7 @@
|
|||
"example": "128",
|
||||
"minimum": 0
|
||||
},
|
||||
"max_input_length": {
|
||||
"max_input_tokens": {
|
||||
"type": "integer",
|
||||
"example": "1024",
|
||||
"minimum": 0
|
||||
|
@ -1581,6 +1661,11 @@
|
|||
"example": "e985a63cdc139290c5f700ff1929f0b5942cced2",
|
||||
"nullable": true
|
||||
},
|
||||
"router": {
|
||||
"type": "string",
|
||||
"description": "Router Info",
|
||||
"example": "text-generation-router"
|
||||
},
|
||||
"sha": {
|
||||
"type": "string",
|
||||
"example": "null",
|
||||
|
@ -1593,7 +1678,6 @@
|
|||
},
|
||||
"version": {
|
||||
"type": "string",
|
||||
"description": "Router Info",
|
||||
"example": "0.5.0"
|
||||
},
|
||||
"waiting_served_ratio": {
|
||||
|
@ -1606,13 +1690,12 @@
|
|||
"Message": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"role"
|
||||
"role",
|
||||
"content"
|
||||
],
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"example": "My name is David and I",
|
||||
"nullable": true
|
||||
"$ref": "#/components/schemas/MessageContent"
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
|
@ -1622,13 +1705,6 @@
|
|||
"role": {
|
||||
"type": "string",
|
||||
"example": "user"
|
||||
},
|
||||
"tool_calls": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/ToolCall"
|
||||
},
|
||||
"nullable": true
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -1658,6 +1734,12 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"Prompt": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"SimpleToken": {
|
||||
"type": "object",
|
||||
"required": [
|
||||
|
@ -1817,9 +1899,7 @@
|
|||
"$ref": "#/components/schemas/FunctionDefinition"
|
||||
},
|
||||
"id": {
|
||||
"type": "integer",
|
||||
"format": "int32",
|
||||
"minimum": 0
|
||||
"type": "string"
|
||||
},
|
||||
"type": {
|
||||
"type": "string"
|
||||
|
@ -1830,20 +1910,22 @@
|
|||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"FunctionName"
|
||||
],
|
||||
"properties": {
|
||||
"FunctionName": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
"default": null,
|
||||
"nullable": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"OneOf"
|
||||
]
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"function"
|
||||
],
|
||||
"properties": {
|
||||
"function": {
|
||||
"$ref": "#/components/schemas/FunctionName"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
|
|
|
@ -11,7 +11,7 @@ volume=$PWD/data # share a volume with the Docker container to avoid downloading
|
|||
docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
|
||||
--device=/dev/kfd --device=/dev/dri --group-add video \
|
||||
--ipc=host --shm-size 256g --net host -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:2.1.0-rocm \
|
||||
ghcr.io/huggingface/text-generation-inference:2.1.1-rocm \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
|
|||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run --gpus all --shm-size 64g -p 8080:80 -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:2.1.0 \
|
||||
ghcr.io/huggingface/text-generation-inference:2.1.1 \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ model=teknium/OpenHermes-2.5-Mistral-7B
|
|||
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
|
||||
|
||||
docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:2.1.0 \
|
||||
ghcr.io/huggingface/text-generation-inference:2.1.1 \
|
||||
--model-id $model
|
||||
```
|
||||
|
||||
|
@ -88,7 +88,7 @@ curl 127.0.0.1:8080/generate \
|
|||
To see all possible deploy flags and options, you can use the `--help` flag. It's possible to configure the number of shards, quantization, generation parameters, and more.
|
||||
|
||||
```bash
|
||||
docker run ghcr.io/huggingface/text-generation-inference:2.1.0 --help
|
||||
docker run ghcr.io/huggingface/text-generation-inference:2.1.1 --help
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
|
|
@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware
|
|||
- [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
|
||||
- [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
|
||||
- [Gemma](https://huggingface.co/google/gemma-7b)
|
||||
- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224)
|
||||
- [Gemma2](https://huggingface.co/google/gemma2-9b)
|
||||
- [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
|
||||
- [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
|
||||
|
|
|
@ -5,85 +5,80 @@
|
|||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.7890625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.625,
|
||||
"text": "request"
|
||||
"id": 1715,
|
||||
"logprob": -11.34375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.3359375,
|
||||
"id": 198,
|
||||
"logprob": -2.5742188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.8779297,
|
||||
"id": 262,
|
||||
"logprob": -1.6230469,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -1.2744141,
|
||||
"id": 3270,
|
||||
"logprob": -2.046875,
|
||||
"special": false,
|
||||
"text": " \"\"\"\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -0.015281677,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 422,
|
||||
"logprob": -2.1425781,
|
||||
"special": false,
|
||||
"text": " if"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -0.9238281,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.6933594,
|
||||
"id": 13204,
|
||||
"logprob": -0.076660156,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
"text": ".method"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.4648438,
|
||||
"id": 624,
|
||||
"logprob": -0.021987915,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
"text": " =="
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.15600586,
|
||||
"id": 364,
|
||||
"logprob": -0.39208984,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
"text": " '"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.8027344,
|
||||
"id": 3019,
|
||||
"logprob": -0.10821533,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -0.23022461,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.0069885254,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.02218628,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
"text": "POST"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\nTest request\nTest request\nTest request\n"
|
||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
||||
}
|
||||
|
|
|
@ -5,85 +5,80 @@
|
|||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.84375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.6015625,
|
||||
"text": "request"
|
||||
"id": 1715,
|
||||
"logprob": -11.34375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 29899,
|
||||
"logprob": -1.5625,
|
||||
"id": 13,
|
||||
"logprob": -2.2539062,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 1454,
|
||||
"logprob": -0.20410156,
|
||||
"id": 578,
|
||||
"logprob": -0.15563965,
|
||||
"special": false,
|
||||
"text": "for"
|
||||
"text": " The"
|
||||
},
|
||||
{
|
||||
"id": 29899,
|
||||
"id": 3622,
|
||||
"logprob": -0.8203125,
|
||||
"special": false,
|
||||
"text": " server"
|
||||
},
|
||||
{
|
||||
"id": 706,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
"text": " has"
|
||||
},
|
||||
{
|
||||
"id": 9342,
|
||||
"id": 539,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "comment"
|
||||
"text": " not"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"id": 3686,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
"text": " yet"
|
||||
},
|
||||
{
|
||||
"id": 396,
|
||||
"logprob": -0.27685547,
|
||||
"special": false,
|
||||
"text": " #"
|
||||
},
|
||||
{
|
||||
"id": 29906,
|
||||
"logprob": -0.4970703,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
{
|
||||
"id": 29900,
|
||||
"logprob": -0.80615234,
|
||||
"special": false,
|
||||
"text": "0"
|
||||
},
|
||||
{
|
||||
"id": 29896,
|
||||
"id": 3288,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
"text": " sent"
|
||||
},
|
||||
{
|
||||
"id": 29955,
|
||||
"logprob": -1.0751953,
|
||||
"id": 904,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "7"
|
||||
"text": " any"
|
||||
},
|
||||
{
|
||||
"id": 828,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " data"
|
||||
},
|
||||
{
|
||||
"id": 382,
|
||||
"logprob": -1.5517578,
|
||||
"special": false,
|
||||
"text": ".\n\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request-for-comment: #2017"
|
||||
"generated_text": "Test request. The server has not yet sent any data.\n\n"
|
||||
}
|
||||
|
|
|
@ -6,87 +6,82 @@
|
|||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.828125,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.609375,
|
||||
"text": "request"
|
||||
"id": 1715,
|
||||
"logprob": -11.34375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.3300781,
|
||||
"id": 198,
|
||||
"logprob": -2.5742188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.8740234,
|
||||
"id": 262,
|
||||
"logprob": -1.6220703,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -1.2646484,
|
||||
"id": 3270,
|
||||
"logprob": -2.0410156,
|
||||
"special": false,
|
||||
"text": " \"\"\"\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -0.015281677,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 422,
|
||||
"logprob": -2.1445312,
|
||||
"special": false,
|
||||
"text": " if"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -0.92333984,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.7158203,
|
||||
"id": 13204,
|
||||
"logprob": -0.07672119,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
"text": ".method"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.4667969,
|
||||
"id": 624,
|
||||
"logprob": -0.021987915,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
"text": " =="
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.15344238,
|
||||
"id": 364,
|
||||
"logprob": -0.39208984,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
"text": " '"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.81591797,
|
||||
"id": 3019,
|
||||
"logprob": -0.10638428,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -0.22973633,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.007045746,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.021957397,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
"text": "POST"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\nTest request\nTest request\nTest request\n"
|
||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
|
@ -95,87 +90,82 @@
|
|||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.84375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.59375,
|
||||
"text": "request"
|
||||
"id": 1715,
|
||||
"logprob": -11.34375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.3378906,
|
||||
"id": 198,
|
||||
"logprob": -2.5742188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.8779297,
|
||||
"id": 262,
|
||||
"logprob": -1.6220703,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -1.2636719,
|
||||
"id": 3270,
|
||||
"logprob": -2.0410156,
|
||||
"special": false,
|
||||
"text": " \"\"\"\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -0.015281677,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 422,
|
||||
"logprob": -2.1445312,
|
||||
"special": false,
|
||||
"text": " if"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -0.92333984,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.6992188,
|
||||
"id": 13204,
|
||||
"logprob": -0.07672119,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
"text": ".method"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.4589844,
|
||||
"id": 624,
|
||||
"logprob": -0.021987915,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
"text": " =="
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.15344238,
|
||||
"id": 364,
|
||||
"logprob": -0.39208984,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
"text": " '"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.79052734,
|
||||
"id": 3019,
|
||||
"logprob": -0.10638428,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -0.22937012,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.007041931,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.022140503,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
"text": "POST"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\nTest request\nTest request\nTest request\n"
|
||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
|
@ -184,87 +174,82 @@
|
|||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.84375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.609375,
|
||||
"text": "request"
|
||||
"id": 1715,
|
||||
"logprob": -11.34375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.3261719,
|
||||
"id": 198,
|
||||
"logprob": -2.5742188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.8730469,
|
||||
"id": 262,
|
||||
"logprob": -1.6220703,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -1.2587891,
|
||||
"id": 3270,
|
||||
"logprob": -2.0410156,
|
||||
"special": false,
|
||||
"text": " \"\"\"\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -0.015281677,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 422,
|
||||
"logprob": -2.1445312,
|
||||
"special": false,
|
||||
"text": " if"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -0.92333984,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.6894531,
|
||||
"id": 13204,
|
||||
"logprob": -0.07672119,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
"text": ".method"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.46875,
|
||||
"id": 624,
|
||||
"logprob": -0.021987915,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
"text": " =="
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.1541748,
|
||||
"id": 364,
|
||||
"logprob": -0.39208984,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
"text": " '"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.80322266,
|
||||
"id": 3019,
|
||||
"logprob": -0.10638428,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -0.22912598,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.0070495605,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.021606445,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
"text": "POST"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\nTest request\nTest request\nTest request\n"
|
||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
|
@ -273,86 +258,81 @@
|
|||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"id": 2323,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.84375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.6015625,
|
||||
"text": "request"
|
||||
"id": 1715,
|
||||
"logprob": -11.34375,
|
||||
"text": " request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.3320312,
|
||||
"id": 198,
|
||||
"logprob": -2.5742188,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.875,
|
||||
"id": 262,
|
||||
"logprob": -1.6220703,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -1.2646484,
|
||||
"id": 3270,
|
||||
"logprob": -2.0410156,
|
||||
"special": false,
|
||||
"text": " \"\"\"\n"
|
||||
},
|
||||
{
|
||||
"id": 262,
|
||||
"logprob": -0.015281677,
|
||||
"special": false,
|
||||
"text": " "
|
||||
},
|
||||
{
|
||||
"id": 422,
|
||||
"logprob": -2.1445312,
|
||||
"special": false,
|
||||
"text": " if"
|
||||
},
|
||||
{
|
||||
"id": 1715,
|
||||
"logprob": -0.92333984,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.6884766,
|
||||
"id": 13204,
|
||||
"logprob": -0.07672119,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
"text": ".method"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.4589844,
|
||||
"id": 624,
|
||||
"logprob": -0.021987915,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
"text": " =="
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.15185547,
|
||||
"id": 364,
|
||||
"logprob": -0.39208984,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
"text": " '"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.79833984,
|
||||
"id": 3019,
|
||||
"logprob": -0.10638428,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -0.22827148,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.006996155,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.021560669,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
"text": "POST"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\nTest request\nTest request\nTest request\n"
|
||||
"generated_text": "\n \"\"\"\n if request.method == 'POST"
|
||||
}
|
||||
]
|
||||
|
|
|
@ -1,130 +1,124 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 20,
|
||||
"finish_reason": "eos_token",
|
||||
"generated_tokens": 19,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 415,
|
||||
"logprob": -0.039886475,
|
||||
"logprob": -0.03665161,
|
||||
"special": false,
|
||||
"text": " The"
|
||||
},
|
||||
{
|
||||
"id": 12072,
|
||||
"logprob": -0.1430664,
|
||||
"logprob": -0.13549805,
|
||||
"special": false,
|
||||
"text": " cow"
|
||||
},
|
||||
{
|
||||
"id": 349,
|
||||
"logprob": -0.056488037,
|
||||
"logprob": -0.05819702,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 6328,
|
||||
"logprob": -0.6855469,
|
||||
"logprob": -0.6826172,
|
||||
"special": false,
|
||||
"text": " standing"
|
||||
},
|
||||
{
|
||||
"id": 356,
|
||||
"logprob": -0.1685791,
|
||||
"logprob": -0.1607666,
|
||||
"special": false,
|
||||
"text": " on"
|
||||
},
|
||||
{
|
||||
"id": 272,
|
||||
"logprob": -0.50097656,
|
||||
"logprob": -0.5073242,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 10305,
|
||||
"logprob": -0.017303467,
|
||||
"logprob": -0.016418457,
|
||||
"special": false,
|
||||
"text": " beach"
|
||||
},
|
||||
{
|
||||
"id": 304,
|
||||
"logprob": -1.3564453,
|
||||
"logprob": -1.3916016,
|
||||
"special": false,
|
||||
"text": " and"
|
||||
},
|
||||
{
|
||||
"id": 272,
|
||||
"logprob": -0.017868042,
|
||||
"logprob": -0.020217896,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 13088,
|
||||
"logprob": -0.0027103424,
|
||||
"logprob": -0.0028133392,
|
||||
"special": false,
|
||||
"text": " chicken"
|
||||
},
|
||||
{
|
||||
"id": 349,
|
||||
"logprob": -0.003156662,
|
||||
"logprob": -0.003145218,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 6398,
|
||||
"logprob": -0.37304688,
|
||||
"logprob": -0.37060547,
|
||||
"special": false,
|
||||
"text": " sitting"
|
||||
},
|
||||
{
|
||||
"id": 356,
|
||||
"logprob": -0.034576416,
|
||||
"logprob": -0.034851074,
|
||||
"special": false,
|
||||
"text": " on"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.29418945,
|
||||
"logprob": -0.2878418,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 17972,
|
||||
"logprob": -0.042877197,
|
||||
"logprob": -0.046051025,
|
||||
"special": false,
|
||||
"text": " pile"
|
||||
},
|
||||
{
|
||||
"id": 302,
|
||||
"logprob": -0.00028443336,
|
||||
"logprob": -0.00028848648,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 2445,
|
||||
"logprob": -0.023223877,
|
||||
"logprob": -0.025772095,
|
||||
"special": false,
|
||||
"text": " money"
|
||||
},
|
||||
{
|
||||
"id": 28723,
|
||||
"logprob": -0.018157959,
|
||||
"logprob": -0.018127441,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 32002,
|
||||
"logprob": -0.00018393993,
|
||||
"logprob": -0.00019824505,
|
||||
"special": true,
|
||||
"text": "<end_of_utterance>"
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"logprob": -1.1920929e-07,
|
||||
"special": true,
|
||||
"text": "</s>"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.8359375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.6171875,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.3417969,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.8730469,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -1.2626953,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.7060547,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.4482422,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.15246582,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.796875,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -0.22766113,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.007045746,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.021759033,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\nTest request\nTest request\nTest request\n"
|
||||
}
|
|
@ -0,0 +1,89 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.7890625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.625,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 29899,
|
||||
"logprob": -1.4980469,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 1454,
|
||||
"logprob": -0.19433594,
|
||||
"special": false,
|
||||
"text": "for"
|
||||
},
|
||||
{
|
||||
"id": 29899,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "-"
|
||||
},
|
||||
{
|
||||
"id": 9342,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "comment"
|
||||
},
|
||||
{
|
||||
"id": 29901,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 396,
|
||||
"logprob": -0.27392578,
|
||||
"special": false,
|
||||
"text": " #"
|
||||
},
|
||||
{
|
||||
"id": 29906,
|
||||
"logprob": -0.49389648,
|
||||
"special": false,
|
||||
"text": "2"
|
||||
},
|
||||
{
|
||||
"id": 29900,
|
||||
"logprob": -0.81103516,
|
||||
"special": false,
|
||||
"text": "0"
|
||||
},
|
||||
{
|
||||
"id": 29896,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "1"
|
||||
},
|
||||
{
|
||||
"id": 29955,
|
||||
"logprob": -1.0800781,
|
||||
"special": false,
|
||||
"text": "7"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test request-for-comment: #2017"
|
||||
}
|
|
@ -0,0 +1,358 @@
|
|||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.8828125,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.5859375,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.3359375,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.8623047,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -1.2451172,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.6923828,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.4492188,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.15197754,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.8022461,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -0.22583008,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.007095337,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.021652222,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\nTest request\nTest request\nTest request\n"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.796875,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.625,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.3476562,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.8789062,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -1.2734375,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.703125,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.4677734,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.15454102,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.7973633,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -0.23278809,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.006980896,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.022033691,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\nTest request\nTest request\nTest request\n"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.9296875,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.5703125,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.3203125,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.8486328,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -1.2480469,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.7060547,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.4511719,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.1529541,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.81396484,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -0.22180176,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.007133484,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.021835327,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\nTest request\nTest request\nTest request\n"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 4321,
|
||||
"logprob": -9.84375,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -9.6171875,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -2.3261719,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.8691406,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -1.2597656,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -1.7070312,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -1.4550781,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.1538086,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.79345703,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 3057,
|
||||
"logprob": -0.22924805,
|
||||
"special": false,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2009,
|
||||
"logprob": -0.0070266724,
|
||||
"special": false,
|
||||
"text": " request"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.021942139,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\nTest request\nTest request\nTest request\n"
|
||||
}
|
||||
]
|
|
@ -5,7 +5,9 @@ from testing_utils import is_flaky_async, SYSTEM, require_backend_async

@pytest.fixture(scope="module")
def flash_llama_gptq_handle(launcher):
with launcher("huggingface/llama-7b-gptq", num_shard=2, quantize="gptq") as handle:
with launcher(
"astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit", num_shard=2, quantize="gptq"
) as handle:
yield handle

@ -62,7 +62,7 @@ async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot)
response.generated_text
== " The cow is standing on the beach and the chicken is sitting on a pile of money."
), f"{repr(response.generated_text)}"
assert response.details.generated_tokens == 20
assert response.details.generated_tokens == 19
assert response == response_snapshot

@ -433,8 +433,17 @@ pub struct CompletionRequest {
pub stop: Option<Vec<String>>,
}

#[derive(Clone, Serialize, ToSchema)]
#[serde(tag = "object")]
enum Completion {
#[serde(rename = "text_completion")]
Chunk(Chunk),
#[serde(rename = "text_completion")]
Final(CompletionFinal),
}

#[derive(Clone, Deserialize, Serialize, ToSchema, Default)]
pub(crate) struct Completion {
pub(crate) struct CompletionFinal {
pub id: String,
#[schema(example = "1706270835")]
pub created: u64,

@ -453,6 +462,15 @@ pub(crate) struct CompletionComplete {
pub finish_reason: String,
}

#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct Chunk {
pub id: String,
pub created: u64,
pub choices: Vec<CompletionComplete>,
pub model: String,
pub system_fingerprint: String,
}

#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletion {
pub id: String,

@ -614,15 +632,6 @@ impl ChatCompletion {
}
}
}
#[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct CompletionCompleteChunk {
pub id: String,
pub created: u64,
pub choices: Vec<CompletionComplete>,
pub model: String,
pub system_fingerprint: String,
}

#[derive(Clone, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk {
pub id: String,

@ -1,5 +1,6 @@
use axum::http::HeaderValue;
use clap::Parser;
use clap::Subcommand;
use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
use hf_hub::{Cache, Repo, RepoType};
use opentelemetry::sdk::propagation::TraceContextPropagator;

@ -27,6 +28,9 @@ use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[command(subcommand)]
command: Option<Commands>,

#[clap(default_value = "128", long, env)]
max_concurrent_requests: usize,
#[clap(default_value = "2", long, env)]

@ -85,10 +89,15 @@ struct Args {
max_client_batch_size: usize,
}

#[derive(Debug, Subcommand)]
enum Commands {
PrintSchema,
}

#[tokio::main]
async fn main() -> Result<(), RouterError> {
// Get args
let args = Args::parse();

// Pattern match configuration
let Args {
max_concurrent_requests,

@ -119,10 +128,17 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
command,
} = args;

// Launch Tokio runtime
init_logging(otlp_endpoint, otlp_service_name, json_output);
let print_schema_command = match command {
Some(Commands::PrintSchema) => true,
None => {
// only init logging if we are not running the print schema command
init_logging(otlp_endpoint, otlp_service_name, json_output);
false
}
};

// Validate args
if max_input_tokens >= max_total_tokens {

@ -388,6 +404,7 @@ async fn main() -> Result<(), RouterError> {
messages_api_enabled,
disable_grammar_support,
max_client_batch_size,
print_schema_command,
)
.await?;
Ok(())

@ -19,8 +19,8 @@ use crate::{
use crate::{
ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
ChatRequest, CompatGenerateRequest, Completion, CompletionComplete, CompletionCompleteChunk,
CompletionRequest, CompletionType, DeltaToolCall, Function, Tool, VertexRequest,
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest,
VertexResponse,
};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType};

@ -705,7 +705,7 @@ async fn completions(
.as_secs();

event
.json_data(CompletionCompleteChunk {
.json_data(Completion::Chunk(Chunk {
id: "".to_string(),
created: current_time,

@ -718,7 +718,7 @@ async fn completions(

model: model_id.clone(),
system_fingerprint: system_fingerprint.clone(),
})
}))
.unwrap_or_else(|_e| Event::default())
};

@ -931,7 +931,7 @@ async fn completions(
.collect::<Result<Vec<_>, _>>()
.map_err(|(status, Json(err))| (status, Json(err)))?;

let response = Completion {
let response = Completion::Final(CompletionFinal {
id: "".to_string(),
created: current_time,
model: info.model_id.clone(),

@ -946,7 +946,7 @@ async fn completions(
completion_tokens,
total_tokens,
},
};
});

// headers similar to `generate` but aggregated
let mut headers = HeaderMap::new();

@ -1387,10 +1387,10 @@ async fn tokenize(

/// Prometheus metrics scrape endpoint
#[utoipa::path(
get,
tag = "Text Generation Inference",
path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String))
get,
tag = "Text Generation Inference",
path = "/metrics",
responses((status = 200, description = "Prometheus Metrics", body = String))
)]
async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
prom_handle.render()

@ -1430,6 +1430,7 @@ pub async fn run(
messages_api_enabled: bool,
grammar_support: bool,
max_client_batch_size: usize,
print_schema_command: bool,
) -> Result<(), WebServerError> {
// OpenAPI documentation
#[derive(OpenApi)]

@ -1463,7 +1464,10 @@ pub async fn run(
ChatCompletion,
CompletionRequest,
CompletionComplete,
CompletionCompleteChunk,
Chunk,
Completion,
CompletionFinal,
Prompt,
GenerateParameters,
PrefillToken,
Token,

@ -1500,6 +1504,12 @@ pub async fn run(
struct ApiDoc;

// Create state
if print_schema_command {
let api_doc = ApiDoc::openapi();
let api_doc = serde_json::to_string_pretty(&api_doc).unwrap();
println!("{}", api_doc);
std::process::exit(0);
}

// Open connection, get model info and warmup
let (scheduler, health_ext, shard_info, max_batch_total_tokens): (

@ -8,6 +8,9 @@ from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.utils import weight_hub_files, download_weights
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)


@pytest.fixture(scope="session")

@ -16,7 +19,10 @@ def default_bloom():
revision = "main"
filenames = weight_hub_files(model_id, revision, ".safetensors")
download_weights(filenames, model_id, revision)
return BLOOMSharded(model_id)
return BLOOMSharded(
model_id,
model_class=BloomForCausalLM,
)


@pytest.fixture(scope="session")

@ -10,7 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch

@pytest.fixture(scope="session")
def default_causal_lm():
return CausalLM("gpt2")
return CausalLM.fallback("gpt2")


@pytest.fixture(scope="session")

@ -1,13 +1,12 @@
import pytest

from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM


@pytest.fixture(scope="session")
def default_santacoder():
return SantaCoder("bigcode/santacoder")
return CausalLM.fallback(model_id="bigcode/santacoder")


@pytest.fixture

@ -20,7 +20,7 @@ def mt0_small_tokenizer():

@pytest.fixture(scope="session")
def default_seq2seq_lm():
return Seq2SeqLM("bigscience/mt0-small")
return Seq2SeqLM.fallback("bigscience/mt0-small")


@pytest.fixture

@ -110,7 +110,7 @@ class PositionRotaryEmbedding(nn.Module):
beta_fast=32,
beta_slow=1,
)
elif rope_scaling["type"] == "su":
elif rope_scaling["type"] in ["su", "longrope"]:
short_factor = torch.tensor(
rope_scaling["short_factor"], dtype=torch.float32, device=device
)

@ -11,17 +11,27 @@ from pathlib import Path

from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models.custom_modeling.mpt_modeling import (
MPTForCausalLM,
)
from text_generation_server.models.bloom import BloomCausalLMBatch
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi
from text_generation_server.models.galactica import GalacticaCausalLMBatch
from text_generation_server.models.custom_modeling.neox_modeling import (
GPTNeoxForCausalLM,
)
from text_generation_server.models.custom_modeling.phi_modeling import (
PhiConfig,
PhiForCausalLM,
)
from text_generation_server.models.custom_modeling.t5_modeling import (
T5ForConditionalGeneration,
)

from text_generation_server.utils.import_utils import SYSTEM

@ -41,9 +51,6 @@ __all__ = [
"CausalLM",
"GalacticaSharded",
"Seq2SeqLM",
"SantaCoder",
"OPTSharded",
"T5Sharded",
"get_model",
]

@ -53,38 +60,65 @@ FLASH_ATTENTION = True
|
|||
|
||||
try:
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||
from text_generation_server.models.flash_rw import FlashRWSharded
|
||||
from text_generation_server.models.flash_gpt2 import FlashGPT2
|
||||
from text_generation_server.models.flash_neox import FlashNeoXSharded
|
||||
from text_generation_server.models.flash_llama import (
|
||||
FlashLlama,
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
FlashLlamaForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.flash_qwen2 import (
|
||||
FlashQwen2,
|
||||
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
|
||||
FlashCohereForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.flash_cohere import (
|
||||
FlashCohere,
|
||||
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
||||
FlashGemmaForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.flash_gemma import (
|
||||
FlashGemma,
|
||||
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
||||
FlashGemma2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.flash_gemma2 import (
|
||||
FlashGemma2,
|
||||
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
|
||||
FlashDbrxForCausalLM,
|
||||
DbrxConfig,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
||||
RWConfig,
|
||||
FlashRWForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
||||
FlashGPTNeoXForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.pali_gemma import (
|
||||
PaliGemma,
|
||||
PaliGemmaBatch,
|
||||
)
|
||||
from text_generation_server.models.flash_santacoder import (
|
||||
FlashSantacoderSharded,
|
||||
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
|
||||
PaliGemmaForConditionalGeneration,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
|
||||
FlashPhiForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.idefics import IDEFICSSharded
|
||||
from text_generation_server.models.llava_next import LlavaNext
|
||||
from text_generation_server.models.idefics2 import Idefics2
|
||||
from text_generation_server.models.flash_mistral import FlashMistral
|
||||
from text_generation_server.models.flash_mixtral import FlashMixtral
|
||||
from text_generation_server.models.flash_phi import FlashPhi
|
||||
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
|
||||
from text_generation_server.models.flash_dbrx import FlashDbrx
|
||||
from text_generation_server.models.custom_modeling.llava_next import (
|
||||
LlavaNextForConditionalGeneration,
|
||||
)
|
||||
|
||||
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
||||
FlashSantacoderForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
|
||||
FlashStarcoder2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
||||
Qwen2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||
FlashMistralForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
|
||||
FlashMixtralForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
|
||||
FlashGPT2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.idefics2 import (
|
||||
Idefics2ForConditionalGeneration,
|
||||
)
|
||||
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not import Flash Attention enabled models: {e}")
|
||||
|
@ -93,21 +127,7 @@ except ImportError as e:
|
|||
|
||||
if FLASH_ATTENTION:
|
||||
__all__.append(FlashCausalLM)
|
||||
__all__.append(FlashGPT2)
|
||||
__all__.append(FlashNeoXSharded)
|
||||
__all__.append(FlashRWSharded)
|
||||
__all__.append(FlashSantacoderSharded)
|
||||
__all__.append(FlashLlama)
|
||||
__all__.append(IDEFICSSharded)
|
||||
__all__.append(FlashMistral)
|
||||
__all__.append(FlashMixtral)
|
||||
__all__.append(FlashDbrx)
|
||||
__all__.append(FlashPhi)
|
||||
__all__.append(FlashQwen2)
|
||||
__all__.append(FlashStarcoder2)
|
||||
__all__.append(FlashGemma)
|
||||
__all__.append(FlashGemma2)
|
||||
__all__.append(FlashCohere)
|
||||
|
||||
MAMBA_AVAILABLE = True
|
||||
MAMBA_IMPORT_ERROR = None
|
||||
|
@ -150,6 +170,11 @@ class ModelType(enum.Enum):
"name": "Gemma",
"url": "https://huggingface.co/google/gemma-7b",
}
PALIGEMMA = {
"type": "paligemma",
"name": "PaliGemma",
"url": "https://huggingface.co/google/paligemma-3b-pt-224",
}
GEMMA2 = {
"type": "gemma2",
"name": "Gemma2",

@ -452,13 +477,16 @@ def get_model(
|
|||
)
|
||||
|
||||
if model_id.startswith("facebook/galactica"):
|
||||
return GalacticaSharded(
|
||||
model_id,
|
||||
revision,
|
||||
return CausalLM(
|
||||
model_id=model_id,
|
||||
# Yes galactica is just an OPT model.
|
||||
model_class=OPTForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
batch_class=GalacticaCausalLMBatch,
|
||||
)
|
||||
|
||||
if (
|
||||
|
@ -467,22 +495,26 @@ def get_model(
|
|||
and model_id.startswith("bigcode/")
|
||||
):
|
||||
if FLASH_ATTENTION:
|
||||
return FlashSantacoderSharded(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashSantacoderForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
aliases={"transformer.wte.weight": ["lm_head.weight"]},
|
||||
num_kv_heads=1,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(
|
||||
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
|
||||
)
|
||||
else:
|
||||
return SantaCoder(
|
||||
model_id,
|
||||
revision,
|
||||
return CausalLM.fallback(
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
|
@ -490,38 +522,44 @@ def get_model(
|
|||
)
|
||||
|
||||
if model_type == BLOOM:
|
||||
return BLOOMSharded(
|
||||
model_id,
|
||||
revision,
|
||||
return CausalLM(
|
||||
model_id=model_id,
|
||||
model_class=BloomForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
batch_class=BloomCausalLMBatch,
|
||||
)
|
||||
elif model_type == MPT:
|
||||
return MPTSharded(
|
||||
model_id,
|
||||
revision,
|
||||
return CausalLM(
|
||||
model_id=model_id,
|
||||
model_class=MPTForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
batch_class=CausalLMBatchKeysLast,
|
||||
)
|
||||
elif model_type == GPT2:
|
||||
if FLASH_ATTENTION:
|
||||
try:
|
||||
return FlashGPT2(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashGPT2ForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
# Lots of legacy models with various weight names.
|
||||
logger.warning(f"Couldn't load flash gpt2 variant: {e}")
|
||||
return CausalLM(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -532,7 +570,7 @@ def get_model(
|
|||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
|
||||
else:
|
||||
return CausalLM(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -542,25 +580,28 @@ def get_model(
|
|||
)
|
||||
elif model_type == GPT_NEOX:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashNeoXSharded(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashGPTNeoXForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
return GPTNeoxSharded(
|
||||
model_id,
|
||||
revision,
|
||||
return CausalLM(
|
||||
model_id=model_id,
|
||||
model_class=GPTNeoxForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
return CausalLM(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -571,16 +612,18 @@ def get_model(
|
|||
|
||||
elif model_type == PHI:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashPhi(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashPhiForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
else:
|
||||
return CausalLM(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -595,9 +638,11 @@ def get_model(
|
|||
"Legacy phi-msft is not supported with Flash Attention"
|
||||
)
|
||||
else:
|
||||
return Phi(
|
||||
model_id,
|
||||
revision,
|
||||
return CausalLM(
|
||||
model_id=model_id,
|
||||
model_class=PhiForCausalLM,
|
||||
config_class=PhiConfig,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
|
@ -606,9 +651,10 @@ def get_model(
|
|||
|
||||
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashLlama(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashLlamaForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
|
@ -618,7 +664,7 @@ def get_model(
|
|||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
|
||||
else:
|
||||
return CausalLM(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -628,18 +674,22 @@ def get_model(
|
|||
)
|
||||
if model_type == GEMMA:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashGemma(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashGemmaForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
# Works better for these models
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
|
||||
else:
|
||||
return CausalLM(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -649,18 +699,22 @@ def get_model(
|
|||
)
|
||||
elif model_type == GEMMA2:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashGemma2(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashGemma2ForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
# Works better for these models
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
|
||||
else:
|
||||
return CausalLM(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -671,18 +725,20 @@ def get_model(
|
|||
|
||||
if model_type == COHERE:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashCohere(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashCohereForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
|
||||
else:
|
||||
return CausalLM(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -693,18 +749,23 @@ def get_model(
|
|||
|
||||
if model_type == DBRX:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashDbrx(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashDbrxForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
# Dbrx works better in bfloat16.
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=DbrxConfig,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
|
||||
else:
|
||||
return CausalLM(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -718,27 +779,37 @@ def get_model(
|
|||
if FLASH_ATTENTION:
|
||||
if config_dict.get("alibi", False):
|
||||
raise NotImplementedError("sharded is not supported for this model")
|
||||
return FlashRWSharded(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashRWForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
aliases={
|
||||
"lm_head.weight": ["transformer.word_embeddings.weight"],
|
||||
"transformer.word_embeddings.weight": ["lm_head.weight"],
|
||||
},
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=RWConfig,
|
||||
)
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
|
||||
else:
|
||||
if FLASH_ATTENTION and not config_dict.get("alibi", False):
|
||||
return FlashRWSharded(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashRWForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=RWConfig,
|
||||
)
|
||||
else:
|
||||
return RW(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -749,18 +820,20 @@ def get_model(
|
|||
|
||||
if model_type == MISTRAL:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashMistral(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashMistralForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
|
||||
else:
|
||||
return CausalLM(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -771,18 +844,20 @@ def get_model(
|
|||
|
||||
if model_type == MIXTRAL:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashMixtral(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashMixtralForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
|
||||
else:
|
||||
return CausalLM(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -793,19 +868,22 @@ def get_model(
|
|||
|
||||
if model_type == STARCODER2:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashStarcoder2(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashStarcoder2ForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(
|
||||
FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
|
||||
)
|
||||
else:
|
||||
return CausalLM(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -816,17 +894,20 @@ def get_model(
|
|||
|
||||
if model_type == QWEN2:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashQwen2(
|
||||
model_id,
|
||||
revision,
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=Qwen2ForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
|
||||
else:
|
||||
return CausalLM(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -836,9 +917,10 @@ def get_model(
|
|||
)
|
||||
|
||||
if model_type == OPT:
|
||||
return OPTSharded(
|
||||
model_id,
|
||||
revision,
|
||||
return CausalLM(
|
||||
model_id=model_id,
|
||||
model_class=OPTForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
|
@ -846,13 +928,20 @@ def get_model(
|
|||
)
|
||||
|
||||
if model_type == T5:
|
||||
return T5Sharded(
|
||||
model_id,
|
||||
revision,
|
||||
return Seq2SeqLM(
|
||||
model_id=model_id,
|
||||
model_class=T5ForConditionalGeneration,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
aliases={
|
||||
"shared.weight": [
|
||||
"encoder.embed_tokens.weight",
|
||||
"decoder.embed_tokens.weight",
|
||||
]
|
||||
},
|
||||
)
|
||||
if model_type == IDEFICS:
|
||||
if FLASH_ATTENTION:
|
||||
|
@ -868,34 +957,45 @@ def get_model(
|
|||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
if model_type == IDEFICS2:
|
||||
if FLASH_ATTENTION:
|
||||
return Idefics2(
|
||||
model_id,
|
||||
revision,
|
||||
return VlmCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=Idefics2ForConditionalGeneration,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
# XXX: Extremely important to cap resolution in order to limit
|
||||
# VRAM usage.
|
||||
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
if model_type == "paligemma":
|
||||
if model_type == PALIGEMMA:
|
||||
if FLASH_ATTENTION:
|
||||
return PaliGemma(
|
||||
model_id,
|
||||
revision,
|
||||
return VlmCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=PaliGemmaForConditionalGeneration,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
# Works better for these models
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
batch_class=PaliGemmaBatch,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
|
||||
if model_type == LLAVA_NEXT:
|
||||
if FLASH_ATTENTION:
|
||||
return LlavaNext(
|
||||
model_id,
|
||||
revision,
|
||||
return VlmCausalLM(
|
||||
model_class=LlavaNextForConditionalGeneration,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
|
@ -919,7 +1019,7 @@ def get_model(
|
|||
elif quantize == "exl2":
|
||||
raise NotImplementedError("exl2 quantization is not supported for AutoModel")
|
||||
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||
return CausalLM(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -928,7 +1028,7 @@ def get_model(
|
|||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
|
||||
return Seq2SeqLM(
|
||||
return Seq2SeqLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -940,7 +1040,7 @@ def get_model(
|
|||
auto_map = config_dict.get("auto_map", None)
|
||||
if trust_remote_code and auto_map is not None:
|
||||
if "AutoModelForCausalLM" in auto_map.keys():
|
||||
return CausalLM(
|
||||
return CausalLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -949,7 +1049,7 @@ def get_model(
|
|||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if "AutoModelForSeq2SeqLM" in auto_map.keys():
|
||||
return Seq2SeqLM(
|
||||
return Seq2SeqLM.fallback(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
|
|
@ -4,22 +4,12 @@ import torch.distributed
|
|||
from typing import Optional, Type
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
|
||||
from text_generation_server.models.custom_modeling.bloom_modeling import (
|
||||
BloomForCausalLM,
|
||||
)
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
|
||||
class BloomCausalLMBatch(CausalLMBatch):
|
||||
|
@ -37,69 +27,6 @@ class BloomCausalLMBatch(CausalLMBatch):
|
|||
|
||||
|
||||
class BLOOMSharded(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
slow_but_exact=False,
|
||||
tp_parallel=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.pad_token_id = 3
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
process_group=self.process_group,
|
||||
prefix="transformer",
|
||||
)
|
||||
if config.quantize in ["gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = BloomForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return BloomCausalLMBatch
|
||||
|
|
|
@ -1,13 +1,25 @@
|
|||
import torch
|
||||
import time
|
||||
import torch.distributed
|
||||
|
||||
from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
from typing import Optional, Tuple, List, Type, Dict
|
||||
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.models import Model
|
||||
from text_generation_server.utils.chunks import concat_text_chunks
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.tokens import batch_top_tokens
|
||||
from text_generation_server.models.types import (
|
||||
Batch,
|
||||
|
@ -478,10 +490,88 @@ class CausalLMBatch(Batch):
|
|||
return len(self.requests)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CausalLMBatchKeysLast(Batch):
|
||||
keys_head_dim_last: bool = False
|
||||
|
||||
|
||||
class CausalLM(Model):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
model_class,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
default_dtype=torch.float16,
|
||||
trust_remote_code: bool = False,
|
||||
tokenizer_class=AutoTokenizer,
|
||||
config_class=AutoConfig,
|
||||
batch_class=CausalLMBatch,
|
||||
):
|
||||
self.batch_class = batch_class
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = default_dtype if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = default_dtype if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
# Float16 doesn't exist on target.
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = config_class.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
)
|
||||
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
prefix = ""
|
||||
model = model_class(prefix, config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super().__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def fallback(
|
||||
cls,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
|
@ -537,7 +627,12 @@ class CausalLM(Model):
|
|||
else:
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
|
||||
super(CausalLM, self).__init__(
|
||||
self = cls.__new__(
|
||||
cls,
|
||||
)
|
||||
self.batch_class = CausalLMBatch
|
||||
super().__init__(
|
||||
self,
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
|
@ -545,15 +640,11 @@ class CausalLM(Model):
|
|||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
return self
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return CausalLMBatch
|
||||
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
return self.tokenizer.decode(
|
||||
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
return self.batch_class
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
|
|
|
@ -815,7 +815,7 @@ class BloomModel(BloomPreTrainedModel):
|
|||
|
||||
|
||||
class BloomForCausalLM(BloomPreTrainedModel):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix: str, config, weights):
|
||||
super().__init__(config)
|
||||
self.transformer = BloomModel(config, weights)
|
||||
|
||||
|
|
|
@ -446,7 +446,7 @@ class CLIPEncoder(nn.Module):
|
|||
|
||||
|
||||
class CLIPTextTransformer(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
def __init__(self, prefix: str, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
@ -536,9 +536,9 @@ class CLIPTextModel(CLIPPreTrainedModel):
|
|||
|
||||
_no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
|
||||
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
def __init__(self, prefix, config: CLIPTextConfig):
|
||||
super().__init__(config)
|
||||
self.text_model = CLIPTextTransformer(config)
|
||||
self.text_model = CLIPTextTransformer(prefix, config)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
|
|
|
@ -363,9 +363,9 @@ class CohereMLP(nn.Module):
|
|||
|
||||
|
||||
class FlashCohereLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
def __init__(self, prefix: str, layer_id, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
prefix = f"{prefix}.layers.{layer_id}"
|
||||
self.self_attn = FlashCohereAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
|
@ -416,18 +416,19 @@ class FlashCohereLayer(nn.Module):
|
|||
|
||||
|
||||
class FlashCohereModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix: str, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashCohereLayer(
|
||||
prefix,
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
|
@ -436,7 +437,7 @@ class FlashCohereModel(torch.nn.Module):
|
|||
]
|
||||
)
|
||||
self.norm = FastLayerNorm.load_no_bias(
|
||||
prefix="model.norm", weights=weights, eps=config.layer_norm_eps
|
||||
prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
@ -486,10 +487,15 @@ class FlashCohereModel(torch.nn.Module):
|
|||
|
||||
|
||||
class FlashCohereForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix: str, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = FlashCohereModel(config, weights)
|
||||
if not prefix:
|
||||
prefix = "model"
|
||||
else:
|
||||
prefix = f"{prefix}.model"
|
||||
|
||||
self.model = FlashCohereModel(prefix, config, weights)
|
||||
try:
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
|
@ -499,7 +505,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
|
|||
except RuntimeError:
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="model.embed_tokens",
|
||||
prefix=f"{prefix}.embed_tokens",
|
||||
weights=weights,
|
||||
)
|
||||
self.logit_scale = config.logit_scale
|
||||
|
|
|
@ -593,9 +593,9 @@ class DenseMoE(nn.Module):
|
|||
|
||||
|
||||
class DbrxLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
def __init__(self, prefix: str, layer_id, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"transformer.blocks.{layer_id}"
|
||||
prefix = f"{prefix}.blocks.{layer_id}"
|
||||
|
||||
self.attn = DbrxNormAttentionNorm(
|
||||
prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
|
||||
|
@ -637,16 +637,17 @@ class DbrxLayer(nn.Module):
|
|||
|
||||
|
||||
class DbrxModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix: str, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="transformer.wte", weights=weights
|
||||
prefix=f"{prefix}.wte", weights=weights
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
DbrxLayer(
|
||||
prefix,
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
|
@ -655,7 +656,7 @@ class DbrxModel(torch.nn.Module):
|
|||
]
|
||||
)
|
||||
self.norm = FastLayerNorm.load_no_bias(
|
||||
prefix="transformer.norm_f", weights=weights, eps=1e-5
|
||||
prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5
|
||||
)
|
||||
|
||||
self.head_size = self.layers[0].attn.self_attn.head_size
|
||||
|
@ -702,10 +703,15 @@ class DbrxModel(torch.nn.Module):
|
|||
|
||||
|
||||
class FlashDbrxForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
def __init__(self, prefix: str, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = DbrxModel(config, weights)
|
||||
if not prefix:
|
||||
prefix = "transformer"
|
||||
else:
|
||||
prefix = f"{prefix}.transformer"
|
||||
|
||||
self.model = DbrxModel(prefix, config, weights)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
|
|
|
@ -102,7 +102,7 @@ class Gemma2Config(PretrainedConfig):
|
|||
|
||||
class Gemma2FastRMSNorm(FastRMSNorm):
|
||||
@classmethod
|
||||
def load(cls, prefix, weights, eps=1e-6):
|
||||
def load(cls, prefix: str, weights, eps=1e-6):
|
||||
dtype = weights.dtype
|
||||
weights.dtype = torch.float32
|
||||
weight = weights.get_tensor(f"{prefix}.weight") + 1
|
||||
|
@ -123,7 +123,7 @@ class Gemma2FastRMSNorm(FastRMSNorm):
|
|||
return hidden_states.to(self.dtype), residual
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
def load_attention(config, prefix: str, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
|
@ -305,7 +305,7 @@ class Gemma2MLP(nn.Module):
|
|||
|
||||
|
||||
class FlashGemma2Layer(nn.Module):
|
||||
def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool):
|
||||
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
|
||||
super().__init__()
|
||||
self.self_attn = FlashGemma2Attention(
|
||||
prefix=f"{prefix}.self_attn",
|
||||
|
@ -376,7 +376,7 @@ class FlashGemma2Layer(nn.Module):
|
|||
|
||||
|
||||
class FlashGemma2Model(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights, causal: bool):
|
||||
def __init__(self, prefix: str, config, weights, causal: bool):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
|
@ -442,7 +442,7 @@ class FlashGemma2Model(torch.nn.Module):
|
|||
|
||||
|
||||
class FlashGemma2ForCausalLM(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights, causal: bool):
|
||||
def __init__(self, prefix: str, config, weights, *, causal: bool = True):
|
||||
super().__init__()
|
||||
|
||||
embed_norm = config.hidden_size**0.5
|
||||
|
|
|
@ -102,7 +102,7 @@ class GemmaConfig(PretrainedConfig):
|
|||
|
||||
class GemmaFastRMSNorm(FastRMSNorm):
|
||||
@classmethod
|
||||
def load(cls, prefix, weights, eps=1e-6):
|
||||
def load(cls, prefix: str, weights, eps=1e-6):
|
||||
dtype = weights.dtype
|
||||
weights.dtype = torch.float32
|
||||
weight = weights.get_tensor(f"{prefix}.weight") + 1
|
||||
|
@ -123,7 +123,7 @@ class GemmaFastRMSNorm(FastRMSNorm):
|
|||
return hidden_states.to(self.dtype), residual
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
def load_attention(config, prefix: str, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
|
@ -261,7 +261,7 @@ class FlashGemmaAttention(torch.nn.Module):
|
|||
|
||||
|
||||
class GemmaMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
def __init__(self, prefix: str, config, weights):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.act = (
|
||||
|
@ -299,7 +299,7 @@ class GemmaMLP(nn.Module):
|
|||
|
||||
|
||||
class FlashGemmaLayer(nn.Module):
|
||||
def __init__(self, prefix, config, weights, causal: bool):
|
||||
def __init__(self, prefix: str, config, weights, causal: bool):
|
||||
super().__init__()
|
||||
self.self_attn = FlashGemmaAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
|
||||
|
@ -354,7 +354,7 @@ class FlashGemmaLayer(nn.Module):
|
|||
|
||||
|
||||
class FlashGemmaModel(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights, causal: bool):
|
||||
def __init__(self, prefix: str, config, weights, causal: bool):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
|
@ -419,7 +419,7 @@ class FlashGemmaModel(torch.nn.Module):
|
|||
|
||||
|
||||
class FlashGemmaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights, causal: bool):
|
||||
def __init__(self, prefix: str, config, weights, *, causal: bool = True):
|
||||
super().__init__()
|
||||
|
||||
embed_norm = config.hidden_size**0.5
|
||||
|
|
|
@ -261,7 +261,7 @@ class FlashGPT2Attention(torch.nn.Module):
class GPT2MLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
act = config.activation_function
self.act = (
@ -298,7 +298,7 @@ class GPT2MLP(nn.Module):
class FlashGPT2Layer(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.self_attn = FlashGPT2Attention(
prefix=f"{prefix}.attn", config=config, weights=weights
@ -350,7 +350,7 @@ class FlashGPT2Layer(nn.Module):
class FlashGPT2Model(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
process_group = weights.process_group
@ -414,7 +414,7 @@ class FlashGPT2Model(torch.nn.Module):
class FlashGPT2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
@ -54,7 +54,7 @@ if SYSTEM == "rocm":
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
def load_attention(config, prefix, weights, layer_id):
def load_attention(config, prefix: str, weights, layer_id):
# Only defined in granite.
bias = getattr(config, "attention_bias", False)
head_size = config.hidden_size // config.num_attention_heads
@ -467,7 +467,7 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
@ -248,7 +248,7 @@ class MistralAttention(torch.nn.Module):
class MistralMLP(nn.Module):
def __init__(self, prefix, config, weights, layer_id):
def __init__(self, prefix: str, config, weights, layer_id):
super().__init__()
self.hidden_act = config.hidden_act
self.act = (
@ -328,7 +328,7 @@ class MistralMLP(nn.Module):
class MistralLayer(nn.Module):
def __init__(self, prefix, config, weights, layer_id):
def __init__(self, prefix: str, config, weights, layer_id):
super().__init__()
self.self_attn = MistralAttention(
prefix=f"{prefix}.self_attn",
@ -392,7 +392,7 @@ class MistralLayer(nn.Module):
class MistralModel(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
process_group = weights.process_group
@ -462,7 +462,7 @@ class MistralModel(torch.nn.Module):
class FlashMistralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, name=None):
def __init__(self, prefix: str, config, weights, name=None):
if name is None:
name = "model"
super().__init__()
@ -116,7 +116,7 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:
return x.view(1) if len(x.size()) == 0 else x
def load_attention(config, prefix, weights):
def load_attention(config, prefix: str, weights):
if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights)
else:
@ -155,7 +155,7 @@ def _load_gqa(config, prefix: str, weights):
)
def _load_experts(config, prefix, mat, weights):
def _load_experts(config, prefix: str, mat, weights):
if config.quantize is not None:
raise NotImplementedError("Mixtral does not support weight quantization yet.")
@ -475,7 +475,7 @@ class DenseMoE(nn.Module):
class MixtralLayer(nn.Module):
def __init__(self, prefix, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
@ -536,7 +536,7 @@ class MixtralLayer(nn.Module):
class MixtralModel(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
@ -610,7 +610,7 @@ class MixtralModel(torch.nn.Module):
class FlashMixtralForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = MixtralModel(prefix, config, weights)
@ -305,12 +305,12 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.config = config
self.embed_in = TensorParallelEmbedding(
prefix="gpt_neox.embed_in", weights=weights
prefix=f"{prefix}.embed_in", weights=weights
)
self.layers = nn.ModuleList(
@ -320,7 +320,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
]
)
self.final_layer_norm = FastLayerNorm.load(
prefix="gpt_neox.final_layer_norm",
prefix=f"{prefix}.final_layer_norm",
weights=weights,
eps=config.layer_norm_eps,
)
@ -370,9 +370,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix, config, weights):
super().__init__(config)
self.gpt_neox = FlashGPTNeoXModel(config, weights)
if not prefix:
prefix = "gpt_neox"
else:
prefix = f"{prefix}.gpt_neox"
self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights)
self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights
@ -258,9 +258,9 @@ class PhiMLP(nn.Module):
class FlashPhiLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = FlashPhiAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
@ -307,18 +307,19 @@ class FlashPhiLayer(nn.Module):
class FlashPhiModel(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
FlashPhiLayer(
prefix,
layer_id,
config,
weights,
@ -378,10 +379,15 @@ class FlashPhiModel(torch.nn.Module):
class FlashPhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = FlashPhiModel(config, weights)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = FlashPhiModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head",
@ -203,9 +203,9 @@ class Qwen2MLP(nn.Module):
class Qwen2Layer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = Qwen2Attention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
@ -260,17 +260,18 @@ class Qwen2Layer(nn.Module):
class Qwen2Model(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
Qwen2Layer(
prefix,
layer_id,
config,
weights,
@ -279,7 +280,7 @@ class Qwen2Model(torch.nn.Module):
]
)
self.norm = FastRMSNorm.load(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.gradient_checkpointing = False
@ -331,10 +332,15 @@ class Qwen2Model(torch.nn.Module):
class Qwen2ForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = Qwen2Model(config, weights)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = Qwen2Model(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head",
@ -127,7 +127,7 @@ class FlashRWAttention(torch.nn.Module):
def __init__(
self,
config,
prefix,
prefix: str,
weights,
):
super().__init__()
@ -236,7 +236,7 @@ class FlashRWLargeAttention(torch.nn.Module):
def __init__(
self,
config,
prefix,
prefix: str,
weights,
):
super().__init__()
@ -358,7 +358,7 @@ class FlashRWLargeAttention(torch.nn.Module):
class FlashMLP(nn.Module):
def __init__(self, config, prefix, weights):
def __init__(self, config, prefix: str, weights):
super().__init__()
self.act = torch.nn.functional.gelu
@ -380,6 +380,7 @@ class FlashRWLayer(nn.Module):
def __init__(
self,
layer_id,
prefix: str,
config,
weights,
):
@ -388,7 +389,7 @@ class FlashRWLayer(nn.Module):
parallel_attn = config.parallel_attn
self.parallel_attn = parallel_attn
prefix = f"transformer.h.{layer_id}"
prefix = f"{prefix}.h.{layer_id}"
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
@ -479,7 +480,7 @@ class FlashRWLayer(nn.Module):
class FlashRWLayerNorm(nn.Module):
def __init__(self, config, prefix, weights):
def __init__(self, config, prefix: str, weights):
super().__init__()
self.num_ln = config.num_ln_in_parallel_attn
@ -518,9 +519,9 @@ class FlashRWLayerNorm(nn.Module):
class FlashRWLargeLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, layer_id, prefix: str, config, weights):
super().__init__()
prefix = f"transformer.h.{layer_id}"
prefix = f"{prefix}.h.{layer_id}"
self.ln_layer = FlashRWLayerNorm(config, prefix, weights)
@ -580,18 +581,18 @@ class FlashRWPreTrainedModel(PreTrainedModel):
class FlashRWModel(FlashRWPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.config = config
self.word_embeddings = TensorParallelEmbedding(
prefix="transformer.word_embeddings", weights=weights
prefix=f"{prefix}.word_embeddings", weights=weights
)
if config.new_decoder_architecture:
self.h = nn.ModuleList(
[
FlashRWLargeLayer(layer_id, config, weights)
FlashRWLargeLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
@ -599,14 +600,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
else:
self.h = nn.ModuleList(
[
FlashRWLayer(layer_id, config, weights)
FlashRWLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = self.h[0].self_attention.num_heads_kv
self.ln_f = FastLayerNorm.load(
prefix="transformer.ln_f",
prefix=f"{prefix}.ln_f",
weights=weights,
eps=config.layer_norm_epsilon,
)
@ -653,10 +654,15 @@ class FlashRWModel(FlashRWPreTrainedModel):
class FlashRWForCausalLM(FlashRWPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.transformer = FlashRWModel(config, weights)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.transformer = FlashRWModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)
@ -346,16 +346,16 @@ class MLP(nn.Module):
class Block(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"transformer.h.{layer_id}"
prefix = f"{prefix}.h.{layer_id}"
self.ln_1 = FastLayerNorm.load(
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
)
self.ln_2 = FastLayerNorm.load(
prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
)
self.attn = FlashMQAttention(
self.self_attn = FlashMQAttention(
prefix=f"{prefix}.attn",
config=config,
weights=weights,
@ -378,7 +378,7 @@ class Block(nn.Module):
max_s,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn(
hidden_states = self.self_attn(
hidden_states,
cu_seqlen_prefill,
kv_cache,
@ -396,25 +396,26 @@ class Block(nn.Module):
class FlashSantacoderModel(nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.config = config
self.process_group = weights.process_group
self.wte = TensorParallelEmbedding(
prefix="transformer.wte",
prefix=f"{prefix}.wte",
weights=weights,
reduce=False,
)
self.wpe = TensorParallelEmbedding(
prefix="transformer.wpe",
prefix=f"{prefix}.wpe",
weights=weights,
reduce=False,
)
self.h = nn.ModuleList(
self.layers = nn.ModuleList(
[
Block(
prefix,
layer_id,
config,
weights,
@ -426,8 +427,8 @@ class FlashSantacoderModel(nn.Module):
prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
)
self.head_size = self.h[0].attn.head_size
self.num_heads = self.h[0].attn.num_heads
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
def forward(
self,
@ -446,7 +447,7 @@ class FlashSantacoderModel(nn.Module):
torch.distributed.all_reduce(hidden_states, group=self.process_group)
residual = None
for i, layer in enumerate(self.h):
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
@ -464,11 +465,18 @@ class FlashSantacoderModel(nn.Module):
class FlashSantacoderForCausalLM(nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix, config, weights):
super().__init__()
self.transformer = FlashSantacoderModel(config, weights)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
config.transpose = config.architectures[0].startswith("GPT2")
self.model = FlashSantacoderModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config, prefix="transformer.wte", weights=weights
config, prefix=f"{prefix}.wte", weights=weights
)
def forward(
@ -485,7 +493,7 @@ class FlashSantacoderForCausalLM(nn.Module):
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.transformer(
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
@ -417,14 +417,14 @@ class Starcoder2Layer(nn.Module):
class Starcoder2Model(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
@ -437,7 +437,7 @@ class Starcoder2Model(torch.nn.Module):
]
)
self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
prefix="model.norm", weights=weights, eps=config.norm_epsilon
prefix=f"{prefix}.norm", weights=weights, eps=config.norm_epsilon
)
self.gradient_checkpointing = False
@ -489,10 +489,15 @@ class Starcoder2Model(torch.nn.Module):
class FlashStarcoder2ForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix, config, weights):
super().__init__()
self.model = Starcoder2Model(config, weights)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = Starcoder2Model(prefix, config, weights)
try:
self.lm_head = SpeculativeHead.load(
config,
@ -502,7 +507,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
except RuntimeError:
self.lm_head = SpeculativeHead.load(
config,
prefix="model.embed_tokens",
prefix=f"{prefix}.embed_tokens",
weights=weights,
)
@ -136,7 +136,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
self.config = config
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
self.language_model = load_text_model(
self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
@ -180,7 +180,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
):
inputs_embeds = self.language_model.embed_tokens(input_ids)
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
@ -269,7 +269,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
input_ids, inputs_embeds, image_features
)
hidden_states = self.language_model.model(
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
@ -283,5 +283,5 @@ class LlavaNextForConditionalGeneration(nn.Module):
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.language_model.lm_head(hidden_states)
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits
@ -783,7 +783,7 @@ class MPTPreTrainedModel(PreTrainedModel):
class MPTModel(MPTPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
# config._validate_config()
super().__init__(config)
self.world_size = weights.process_group.size()
@ -809,13 +809,13 @@ class MPTModel(MPTPreTrainedModel):
f"Requested norm type ({config.norm_type}) is not implemented within this repo."
)
self.wte = TensorParallelEmbedding("transformer.wte", weights)
self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights)
if not self.alibi:
self.wpe = TensorParallelEmbedding("transformer.wpe", weights)
self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights)
self.blocks = nn.ModuleList(
[
MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights)
MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights)
for i in range(config.n_layers)
]
)
@ -1085,13 +1085,19 @@ class MPTModel(MPTPreTrainedModel):
class MPTForCausalLM(MPTPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
if not config.tie_word_embeddings:
raise ValueError("MPTForCausalLM only supports tied word embeddings")
self.transformer = MPTModel(config, weights)
self.transformer = MPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config, prefix="transformer.wte", weights=weights
config, prefix=f"{prefix}.wte", weights=weights
)
self.logit_scale = None
if config.logit_scale is not None:
@ -404,24 +404,24 @@ class GPTNeoXMLP(nn.Module):
class GPTNeoXLayer(nn.Module):
def __init__(self, layer_id, config, weights):
def __init__(self, layer_id, prefix: str, config, weights):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.input_layernorm",
prefix=f"{prefix}.layers.{layer_id}.input_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm",
prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.attention = GPTNeoXAttention(
config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights
config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights
)
self.mlp = GPTNeoXMLP(
config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights
config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights
)
def forward(
@ -472,23 +472,23 @@ class GPTNeoXLayer(nn.Module):
class GPTNeoXModel(GPTNeoXPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.embed_in = TensorParallelEmbedding(
prefix="gpt_neox.embed_in", weights=weights
prefix=f"{prefix}.embed_in", weights=weights
)
self.layers = nn.ModuleList(
[
GPTNeoXLayer(layer_id, config, weights)
GPTNeoXLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.final_layer_norm = nn.LayerNorm.load(
prefix="gpt_neox.final_layer_norm",
prefix=f"{prefix}.final_layer_norm",
weights=weights,
eps=config.layer_norm_eps,
)
@ -640,9 +640,15 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.gpt_neox = GPTNeoXModel(config, weights)
if not prefix:
prefix = "gpt_neox"
else:
prefix = f"{prefix}.gpt_neox"
self.gpt_neox = GPTNeoXModel(prefix, config, weights)
self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights
)
@ -94,11 +94,11 @@ class OPTLearnedPositionalEmbedding(nn.Module):
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, weights):
def __init__(self, prefix: str, weights):
super().__init__()
self.offset = 2
self.weight = nn.Parameter(
weights.get_tensor("model.decoder.embed_positions.weight")
weights.get_tensor(f"{prefix}.decoder.embed_positions.weight")
)
def forward(
@ -311,11 +311,11 @@ class OPTAttention(nn.Module):
class OPTDecoderLayer(nn.Module):
def __init__(self, layer_id: int, config: OPTConfig, weights):
def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights):
super().__init__()
self.process_group = weights.process_group
self.hidden_size = config.hidden_size
prefix = f"model.decoder.layers.{layer_id}"
prefix = f"{prefix}.decoder.layers.{layer_id}"
self.self_attn = OPTAttention(
config,
prefix=f"{prefix}.self_attn",
@ -429,7 +429,7 @@ class OPTPreTrainedModel(PreTrainedModel):
class OPTDecoder(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights):
def __init__(self, prefix: str, config: OPTConfig, weights):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.layerdrop
@ -438,20 +438,26 @@ class OPTDecoder(OPTPreTrainedModel):
self.vocab_size = config.vocab_size
self.embed_tokens = TensorParallelEmbedding(
prefix="model.decoder.embed_tokens", weights=weights
prefix=f"{prefix}.decoder.embed_tokens", weights=weights
)
self.embed_positions = OPTLearnedPositionalEmbedding(weights)
self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = FastLinear.load(
config, prefix="model.decoder.project_out", weights=weights, bias=False
config,
prefix=f"{prefix}.decoder.project_out",
weights=weights,
bias=False,
)
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = FastLinear.load(
config, prefix="model.decoder.project_in", weights=weights, bias=False
config,
prefix=f"{prefix}.decoder.project_in",
weights=weights,
bias=False,
)
else:
self.project_in = None
@ -461,14 +467,14 @@ class OPTDecoder(OPTPreTrainedModel):
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm.load(
prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS
prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS
)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList(
[
OPTDecoderLayer(layer_id, config, weights)
OPTDecoderLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
@ -686,9 +692,9 @@ class OPTDecoder(OPTPreTrainedModel):
class OPTModel(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights):
def __init__(self, prefix: str, config: OPTConfig, weights):
super().__init__(config)
self.decoder = OPTDecoder(config, weights)
self.decoder = OPTDecoder(prefix, config, weights)
# Initialize weights and apply final processing
def forward(
@ -743,13 +749,18 @@ class OPTModel(OPTPreTrainedModel):
class OPTForCausalLM(OPTPreTrainedModel):
def __init__(self, config, weights):
def __init__(self, prefix, config, weights):
super().__init__(config)
self.model = OPTModel(config, weights)
if not prefix:
prefix = "model"
else:
prefix = f"{prefix}.model"
self.model = OPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config, prefix="model.decoder.embed_tokens", weights=weights
config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights
)
def forward(
@ -248,16 +248,16 @@ class PhiBlock(nn.Module):
# PhiModel implements the embedding layer and the transformer blocks.
class PhiModel(nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.tp_rank = weights.process_group.rank()
self.tp_world_size = weights.process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="transformer.embd.wte", weights=weights
prefix=f"{prefix}.embd.wte", weights=weights
)
self.blocks = nn.ModuleList(
[
PhiBlock(f"transformer.h.{layer_id}", config, weights)
PhiBlock(f"{prefix}.h.{layer_id}", config, weights)
for layer_id in range(config.n_layer)
]
)
@ -289,9 +289,15 @@ class PhiModel(nn.Module):
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
class PhiForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.model = PhiModel(config, weights)
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.model = PhiModel(prefix, config, weights)
self.lm_head = PhiCausalLMHead(config, weights)
def forward(
@ -10,7 +10,12 @@ import numpy as np
from loguru import logger
from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers import (
PreTrainedTokenizerBase,
AutoConfig,
AutoTokenizer,
GenerationConfig,
)
from typing import Iterable, Optional, Tuple, List, Type, Dict
from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata
@ -21,6 +26,12 @@ from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.dist import RANK
from text_generation_server.utils.speculate import get_speculate
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
hub,
)
from text_generation_server.models.types import (
Batch,
Tokens,
@ -798,29 +809,120 @@ class FlashCausalLMBatch(Batch):
return len(self.requests)
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashCausalLM(Model):
def __init__(
self,
model_id: str,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
num_layers: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
rank: int = 0,
world_size: int = 1,
sliding_window: Optional[int] = None,
model_class,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
lora_adapter_ids: Optional[list] = [],
tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer,
config_class: PreTrainedTokenizerBase = AutoConfig,
default_dtype=torch.float16,
aliases=None,
# Used for Santacoder override of config
num_kv_heads=None,
skip_special_tokens: bool = True,
):
self.num_layers = num_layers
self.num_kv_heads = num_kv_heads
self.head_size = head_size
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError(f"{model_class} is only available on GPU")
tokenizer = tokenizer_class.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = config_class.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device, dtype, process_group=self.process_group, aliases=aliases
)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = model_class(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
# VLM models define the config we care about in their text_config
text_config = getattr(config, "text_config", None)
if text_config is not None:
config = text_config
if getattr(config, "sliding_window", None) is not None:
set_sliding_window(config.sliding_window)
else:
config.sliding_window = None
self.num_layers = config.num_hidden_layers
# Validation is done in the model itself
if num_kv_heads is None:
# Order is important here.
for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]:
num_kv_heads = getattr(config, attr, None)
if num_kv_heads is not None:
break
if num_kv_heads is None:
raise ValueError("Cannot get the number of key/value heads")
self.num_kv_heads = (
num_kv_heads // self.process_group.size()
if num_kv_heads > 1
else num_kv_heads
)
assert self.num_kv_heads > 0
self.head_size = config.hidden_size // config.num_attention_heads
self.cuda_graphs = {}
self.kv_cache = []
super(FlashCausalLM, self).__init__(
super().__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
@ -829,7 +931,7 @@ class FlashCausalLM(Model):
device=device,
rank=rank,
world_size=world_size,
sliding_window=sliding_window,
sliding_window=config.sliding_window,
)
@property
@ -1577,3 +1679,72 @@ class FlashCausalLM(Model):
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns)
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
# TODO: this is a hack to avoid the gate_proj for
# FlashStarcoder2 that doesnt have these layers
if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"):
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
@ -1,75 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoTokenizer, AutoConfig
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashCohereForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashCohere(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashCohere is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashCohereForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashCohere, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

@ -1,100 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
FlashDbrxForCausalLM,
DbrxConfig,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashDbrx(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashDBRX is only available on GPU")
try:
tokenizer = GPT2TokenizerFast.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
except:
try:
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
except:
# FIXME: change back to model id once the tokenizer.json is merged
tokenizer = GPT2TokenizerFast.from_pretrained(
"Xenova/dbrx-instruct-tokenizer",
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
use_fast=True,
from_slow=False,
)
config = DbrxConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
model = FlashDbrxForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashDbrx, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

@ -1,83 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import AutoConfig, AutoTokenizer
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashGemma(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGemma is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
# TODO hardcoded
prefix = ""
model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
torch.distributed.barrier(group=self.process_group)
super(FlashGemma, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

@ -1,83 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from typing import Optional
from transformers import PretrainedConfig, AutoTokenizer
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashGemma2(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGemma2 is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = PretrainedConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
# TODO hardcoded
prefix = ""
model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True)
torch.distributed.barrier(group=self.process_group)
super(FlashGemma2, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

@ -1,82 +0,0 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from transformers.models.gpt2 import GPT2Tokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
class FlashGPT2(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGPT2 is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = FlashGPT2ForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashGPT2, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)

@ -1,171 +0,0 @@
import os
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
hub,
)
tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import SYSTEM
ADAPTER_LAYERS = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
class FlashLlama(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
lora_adapter_ids: Optional[list] = [],
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
raise NotImplementedError("FlashLlama is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
if isinstance(generation_config.eos_token_id, (list, set)):
# TODO Huge hack
tokenizer._eos_token_ids = set(generation_config.eos_token_id)
except Exception:
pass
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = FlashLlamaForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
# This accounts for VLMs (e.g. LlavaNext, Idefics2)
# that have a language_model inside of the larger model.
if hasattr(self.model, "language_model"):
_model = self.model.language_model
elif hasattr(self.model, "text_model"):
_model = self.model.text_model
else:
_model = self.model
for i, layer in enumerate(_model.model.layers):
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return ["q_proj", "v_proj"]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == "lm_head" else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
|
||||
return layer_type in ROW_PARALLEL
|
|
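The removed `FlashLlama.adapter_target_to_layer` above is mostly bookkeeping: the separate LoRA target names `q_proj`/`k_proj`/`v_proj` all resolve to the fused `query_key_value` module, and `gate_proj`/`up_proj` both resolve to the fused `gate_up_proj`. A condensed sketch of that per-layer mapping (the function name is illustrative):

```python
from typing import Any, Dict, Tuple


def llama_adapter_targets(prefix: str, i: int, layer: Any) -> Dict[Tuple[int, str], Tuple[str, Any]]:
    # Per-layer view of the dict built by the removed adapter_target_to_layer;
    # note that q/k/v share one fused module, as do gate/up.
    return {
        (i, "q_proj"): (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value),
        (i, "k_proj"): (f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value),
        (i, "v_proj"): (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value),
        (i, "o_proj"): (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj),
        (i, "gate_proj"): (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj),
        (i, "up_proj"): (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj),
        (i, "down_proj"): (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj),
    }
```

The only entry in the removed code that is not per-layer is `(0, "lm_head")`, which points at `_model.lm_head`.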
@@ -1,24 +1,7 @@
import torch
import torch.distributed

from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, Tuple, Dict, List

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import set_sliding_window
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
    FlashMistralForCausalLM,
    MistralConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)


ADAPTER_LAYERS = [
@@ -33,88 +16,7 @@ ADAPTER_LAYERS = [
|
|||
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
|
||||
|
||||
|
||||
class BaseFlashMistral(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_cls,
|
||||
model_id: str,
|
||||
config_cls=AutoConfig,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
tokenizer_class=AutoTokenizer,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashMistral is only available on GPU")
|
||||
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = config_cls.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
# Set context windows
|
||||
if getattr(config, "sliding_window", None) is not None:
|
||||
set_sliding_window(config.sliding_window)
|
||||
else:
|
||||
config.sliding_window = None
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
prefix = ""
|
||||
model = model_cls(prefix, config, weights)
|
||||
|
||||
self.cuda_graphs = {}
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
num_layers, num_kv_heads, head_size = self.get_layer_config(model)
|
||||
super().__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=num_layers,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
sliding_window=config.sliding_window,
|
||||
)
|
||||
|
||||
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
||||
return (
|
||||
len(model.model.layers),
|
||||
model.model.num_key_value_heads,
|
||||
model.model.head_size,
|
||||
)
|
||||
|
||||
class FlashMistral(FlashCausalLM):
|
||||
@property
|
||||
def supports_adapter_loading(self) -> bool:
|
||||
return True
|
||||
|
@@ -126,9 +28,7 @@ class BaseFlashMistral(FlashCausalLM):

        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
        # that have a language_model inside of the larger model.
        if hasattr(self.model, "language_model"):
            _model = self.model.language_model
        elif hasattr(self.model, "text_model"):
        if hasattr(self.model, "text_model"):
            _model = self.model.text_model
        else:
            _model = self.model

@@ -183,25 +83,3 @@ class BaseFlashMistral(FlashCausalLM):

    def is_row_parallel(self, layer_type: str) -> bool:
        return layer_type in ROW_PARALLEL


class FlashMistral(BaseFlashMistral):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        super(FlashMistral, self).__init__(
            config_cls=MistralConfig,
            model_cls=FlashMistralForCausalLM,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -1,31 +0,0 @@
import torch

from typing import Optional

from text_generation_server.models.flash_mistral import BaseFlashMistral
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
    MixtralConfig,
    FlashMixtralForCausalLM,
)


class FlashMixtral(BaseFlashMistral):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        super(FlashMixtral, self).__init__(
            config_cls=MixtralConfig,
            model_cls=FlashMixtralForCausalLM,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
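As the deleted file above shows, `FlashMixtral` only pins `config_cls`/`model_cls` and forwards everything else to `BaseFlashMistral`, exactly like `FlashMistral` in the previous hunk. A hedged sketch of how such thin wrappers can be expressed generically (the `mistral_like` factory is illustrative and assumes `BaseFlashMistral` is importable):

```python
def mistral_like(config_cls, model_cls):
    # Returns a constructor that fixes the config/model classes and forwards the rest,
    # which is all the removed FlashMistral/FlashMixtral wrappers did.
    def build(model_id, **kwargs):
        return BaseFlashMistral(
            config_cls=config_cls,
            model_cls=model_cls,
            model_id=model_id,
            **kwargs,
        )

    return build


# Usage sketch with the class names from the removed file:
# FlashMixtral = mistral_like(MixtralConfig, FlashMixtralForCausalLM)
```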
@@ -1,82 +0,0 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
||||
FlashGPTNeoXForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashNeoXSharded(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashNeoX is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
)
|
||||
if config.quantize in ["gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashGPTNeoXForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashNeoXSharded, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.gpt_neox.layers),
|
||||
num_kv_heads=model.gpt_neox.num_heads,
|
||||
head_size=model.gpt_neox.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
|
@@ -1,111 +0,0 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
|
||||
FlashPhiForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashPhi(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashPhi is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashPhiForCausalLM(config, weights)
|
||||
if speculator:
|
||||
from text_generation_server.utils.medusa import MedusaModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
is_local_model = (
|
||||
Path(speculator).exists() and Path(speculator).is_dir()
|
||||
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
||||
|
||||
if not is_local_model:
|
||||
medusa_config = hf_hub_download(
|
||||
speculator, revision=revision, filename="config.json"
|
||||
)
|
||||
medusa_head = hf_hub_download(
|
||||
speculator, revision=revision, filename="medusa_lm_head.pt"
|
||||
)
|
||||
else:
|
||||
medusa_config = str(Path(speculator) / "config.json")
|
||||
medusa_head = str(Path(speculator) / "medusa_lm_head.pt")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
|
||||
weights = Weights(
|
||||
[medusa_sf], device, dtype, process_group=self.process_group
|
||||
)
|
||||
lm_head = model.lm_head
|
||||
model.lm_head = MedusaModel(config, weights, lm_head)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashPhi, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
|
@@ -1,93 +0,0 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models.flash_mistral import (
|
||||
BaseFlashMistral,
|
||||
set_sliding_window,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
||||
Qwen2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashQwen2(BaseFlashMistral):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashQwen2 is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
# Set context windows
|
||||
if config.sliding_window is not None:
|
||||
set_sliding_window(config.sliding_window)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = Qwen2ForCausalLM(config, weights)
|
||||
|
||||
self.cuda_graphs = {}
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(BaseFlashMistral, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
sliding_window=config.sliding_window,
|
||||
)
|
|
@@ -1,91 +0,0 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
||||
RWConfig,
|
||||
FlashRWForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashRWSharded(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashRW is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = RWConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames,
|
||||
device,
|
||||
dtype,
|
||||
process_group=self.process_group,
|
||||
aliases={
|
||||
"lm_head.weight": ["transformer.word_embeddings.weight"],
|
||||
"transformer.word_embeddings.weight": ["lm_head.weight"],
|
||||
},
|
||||
)
|
||||
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
if config.quantize in ["gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashRWForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashRWSharded, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.transformer.h),
|
||||
num_kv_heads=model.transformer.cache_size,
|
||||
head_size=model.transformer.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
|
@@ -1,99 +0,0 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional, List
|
||||
import json
|
||||
import os
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
||||
FlashSantacoderForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashSantacoderSharded(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
config.transpose = config.architectures[0].startswith("GPT2")
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
process_group=self.process_group,
|
||||
aliases={"transformer.wte.weight": ["lm_head.weight"]},
|
||||
)
|
||||
if config.quantize in ["gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashSantacoderForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashSantacoderSharded, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.transformer.h),
|
||||
num_kv_heads=1,
|
||||
head_size=model.transformer.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
||||
return self.tokenizer.decode(
|
||||
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
|
||||
)
|
|
@@ -1,84 +0,0 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from transformers.models.gpt2 import GPT2TokenizerFast
|
||||
|
||||
from text_generation_server.models.flash_mistral import (
|
||||
BaseFlashMistral,
|
||||
set_sliding_window,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
|
||||
Starcoder2Config,
|
||||
FlashStarcoder2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
|
||||
# Starcoder2 has the same base as Mistral
|
||||
class FlashStarcoder2(BaseFlashMistral):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashStarcoder2 is only available on GPU")
|
||||
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = Starcoder2Config.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
# Set context windows
|
||||
if config.sliding_window is not None:
|
||||
set_sliding_window(config.sliding_window)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = FlashStarcoder2ForCausalLM(config, weights)
|
||||
|
||||
self.cuda_graphs = {}
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(BaseFlashMistral, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
sliding_window=config.sliding_window,
|
||||
)
|
|
@@ -162,83 +162,3 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
|||
padding_right_offset=padding_right_offset,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
|
||||
class GalacticaSharded(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
tp_parallel=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.quantize = quantize
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
config.speculator = speculator
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
)
|
||||
if config.quantize in ["gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = OPTForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return GalacticaCausalLMBatch
|
||||
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
||||
return self.tokenizer.decode(
|
||||
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
outputs, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
return outputs.logits, speculative_logits, outputs.past_key_values
|
||||
|
|
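The `decode` override kept in the hunk above deliberately leaves special tokens in the output, since Galactica's markup (for example citation and working-memory tokens) is expressed through them and is parsed downstream. A minimal sketch of that behaviour, detached from the class:

```python
def decode_keep_special(tokenizer, generated_ids):
    # Same arguments as the removed GalacticaSharded.decode: special tokens are kept
    # so downstream parsing rules can see them.
    return tokenizer.decode(
        generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
    )
```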
|
@@ -1,89 +0,0 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
)
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.models.custom_modeling.neox_modeling import (
|
||||
GPTNeoxForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
|
||||
class GPTNeoxSharded(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
)
|
||||
if config.quantize in ["gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = GPTNeoxForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
outputs, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
return outputs.logits, speculative_logits, outputs.past_key_values
|
|
@@ -1,51 +0,0 @@
import torch

from typing import Optional, Tuple

from transformers import (
    AutoProcessor,
)
from text_generation_server.models.custom_modeling.idefics2 import (
    Idefics2ForConditionalGeneration,
)

from text_generation_server.models.vlm_causal_lm import VlmCausalLM


class Idefics2(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            # XXX: Extremely important to cap resolution in order to limit
            # VRAM usage.
            size={"longest_edge": 448, "shortest_edge": 378},
        )
        super().__init__(
            model_cls=Idefics2ForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.text_model.model.layers),
            model.text_model.model.num_key_value_heads,
            model.text_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)
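The deleted Idefics2 wrapper above caps the processor resolution to bound VRAM use for vision inputs. A usage sketch of the same idea outside TGI; the model id is only an example and the sizes are the ones from the removed file, not a general recommendation:

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2-8b",  # example checkpoint, not taken from this diff
    # Cap image resolution so the vision tower does not dominate VRAM.
    size={"longest_edge": 448, "shortest_edge": 378},
)
```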
@@ -1,46 +0,0 @@
import torch

from typing import Optional, Tuple

from transformers import (
    AutoProcessor,
)
from text_generation_server.models.custom_modeling.llava_next import (
    LlavaNextForConditionalGeneration,
)

from text_generation_server.models.vlm_causal_lm import VlmCausalLM


class LlavaNext(VlmCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.processor = AutoProcessor.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        super().__init__(
            model_cls=LlavaNextForConditionalGeneration,
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    def get_layer_config(self, model) -> Tuple[int, int, int]:
        return (
            len(model.language_model.model.layers),
            model.language_model.model.num_key_value_heads,
            model.language_model.model.head_size,
        )

    def max_past(self) -> Optional[int]:
        return getattr(self.model.language_model, "max_past", None)
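The two VLM wrappers above differ only in where the language model lives inside the composite model (`model.text_model` for Idefics2, `model.language_model` for LlavaNext). A small sketch of resolving that in one place (both helper names are illustrative):

```python
def inner_language_model(model):
    # Mirrors the hasattr chain used elsewhere in this diff for VLM-aware code paths.
    if hasattr(model, "language_model"):
        return model.language_model
    if hasattr(model, "text_model"):
        return model.text_model
    return model


def layer_config(model):
    lm = inner_language_model(model)
    return (len(lm.model.layers), lm.model.num_key_value_heads, lm.model.head_size)
```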
@@ -63,7 +63,7 @@ class Model(ABC):
        self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
            LayerAdapterWeights
        )
        self.target_to_layer = self.adapter_target_to_layer()
        self.target_to_layer = None
        self.loaded_adapters = set()
        self.static_adapter_id = adapter_id

@@ -206,6 +206,8 @@ class Model(ABC):
        into model. Otherwise, the adapter weights are applied during the forward
        pass and stored separately from the base model parameters.
        """
        if self.target_to_layer is None:
            self.target_to_layer = self.adapter_target_to_layer()
        if adapter_index in self.loaded_adapters:
            # Adapter already loaded
            return

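The hunk above makes `target_to_layer` lazy: it is only built the first time an adapter is actually loaded instead of in `Model.__init__`. A generic, self-contained sketch of that pattern (the `AdapterHost` class is illustrative, not TGI code):

```python
from typing import Dict, Optional


class AdapterHost:
    def __init__(self) -> None:
        self.target_to_layer: Optional[Dict[str, tuple]] = None
        self.loaded_adapters = set()

    def adapter_target_to_layer(self) -> Dict[str, tuple]:
        # Stand-in for the per-model implementation; potentially expensive to build.
        return {}

    def load_adapter(self, adapter_index: int) -> None:
        if self.target_to_layer is None:
            self.target_to_layer = self.adapter_target_to_layer()
        if adapter_index in self.loaded_adapters:
            return  # Adapter already loaded.
        self.loaded_adapters.add(adapter_index)
```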
@@ -1,105 +0,0 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
|
||||
from huggingface_hub import hf_hub_download
|
||||
import json
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.custom_modeling.mpt_modeling import (
|
||||
MPTForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class MPTCausalLMBatch(CausalLMBatch):
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "CausalLMBatch":
|
||||
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
|
||||
batch.keys_head_dim_last = False
|
||||
return batch
|
||||
|
||||
|
||||
class MPTSharded(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# If model_id is a local path, load the file directly
|
||||
local_path = Path(model_id, "config.json")
|
||||
if local_path.exists():
|
||||
filename = str(local_path.resolve())
|
||||
else:
|
||||
filename = hf_hub_download(
|
||||
model_id, revision=revision, filename="config.json"
|
||||
)
|
||||
with open(filename, "r") as f:
|
||||
config = json.load(f)
|
||||
config = PretrainedConfig(**config)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
config.quantize = quantize
|
||||
model = MPTForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return MPTCausalLMBatch
|
|
@@ -1,86 +0,0 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
|
||||
class OPTSharded(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
)
|
||||
if config.quantize in ["gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = OPTForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
outputs, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
return outputs.logits, speculative_logits, outputs.past_key_values
|
|
@@ -74,45 +74,3 @@ class PaliGemmaBatch(VlmCausalLMBatch):
|
|||
else:
|
||||
image_inputs = None
|
||||
return batch_tokenized_inputs, image_inputs
|
||||
|
||||
|
||||
class PaliGemma(VlmCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
config_cls=AutoConfig,
|
||||
model_cls=PaliGemmaForConditionalGeneration,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self):
|
||||
return PaliGemmaBatch
|
||||
|
||||
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
||||
return (
|
||||
len(model.text_model.model.layers),
|
||||
model.text_model.model.num_key_value_heads,
|
||||
model.text_model.model.head_size,
|
||||
)
|
||||
|
||||
def max_past(self) -> Optional[int]:
|
||||
return getattr(self.model.text_model, "max_past", None)
|
||||
|
|
|
@@ -1,69 +0,0 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.models.custom_modeling.phi_modeling import (
|
||||
PhiConfig,
|
||||
PhiForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
|
||||
class Phi(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, _rank, _world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config = PhiConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
tokenizer.bos_token_id = config.bos_token_id
|
||||
tokenizer.eos_token_id = config.eos_token_id
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
model = PhiForCausalLM(config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
|
@@ -1,84 +0,0 @@
|
|||
import torch
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
|
||||
|
||||
class RW(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
if speculator:
|
||||
raise RuntimeError("Medusa decoding is not enabled for AutoModel")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
device_map=(
|
||||
"auto"
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
||||
else None
|
||||
),
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||
model = model.cuda()
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if model.config.pad_token_id is not None:
|
||||
tokenizer.pad_token_id = model.config.pad_token_id
|
||||
elif model.config.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = model.config.eos_token_id
|
||||
elif tokenizer.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
else:
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
# Model Forward
|
||||
outputs, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
return outputs.logits, speculative_logits, outputs.past_key_values
|
|
@@ -1,77 +0,0 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from typing import Optional, List
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
|
||||
FIM_PREFIX = "<fim-prefix>"
|
||||
FIM_MIDDLE = "<fim-middle>"
|
||||
FIM_SUFFIX = "<fim-suffix>"
|
||||
FIM_PAD = "<fim-pad>"
|
||||
EOD = "<|endoftext|>"
|
||||
|
||||
|
||||
class SantaCoder(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
tokenizer.add_special_tokens(
|
||||
{
|
||||
"additional_special_tokens": [
|
||||
EOD,
|
||||
FIM_PREFIX,
|
||||
FIM_MIDDLE,
|
||||
FIM_SUFFIX,
|
||||
FIM_PAD,
|
||||
],
|
||||
"pad_token": EOD,
|
||||
}
|
||||
)
|
||||
with device:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
super(CausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
||||
return self.tokenizer.decode(
|
||||
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
|
||||
)
|
|
@@ -1,11 +1,22 @@
import torch
import torch.distributed
import time

from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    PreTrainedTokenizerBase,
    AutoConfig,
)
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)
from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.models import Model
@@ -531,6 +542,80 @@ class Seq2SeqLM(Model):
|
|||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
model_class,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
default_dtype=torch.float16,
|
||||
trust_remote_code: bool = False,
|
||||
config_class=AutoConfig,
|
||||
tokenizer_class=AutoTokenizer,
|
||||
aliases=None,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = default_dtype if dtype is None else dtype
|
||||
elif SYSTEM == "ipex":
|
||||
if hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = default_dtype if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
# Float16 doesn't exist on target.
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
||||
config = config_class.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
tokenizer.bos_token_id = config.decoder_start_token_id
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
process_group=self.process_group,
|
||||
aliases=aliases,
|
||||
)
|
||||
if config.quantize in ["awq", "exl2", "gptq", "marlin"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
model = model_class(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super().__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def fallback(
|
||||
cls,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
|
@@ -574,7 +659,11 @@
        )
        tokenizer.bos_token_id = model.config.decoder_start_token_id

        super(Seq2SeqLM, self).__init__(
        self = cls.__new__(
            cls,
        )
        super().__init__(
            self,
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
@@ -582,16 +671,12 @@ class Seq2SeqLM(Model):
            dtype=dtype,
            device=device,
        )
        return self

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        return Seq2SeqLMBatch

    def decode(self, decoder_ids: List[int]) -> str:
        return self.tokenizer.decode(
            decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    def forward(
        self,
        input_ids,
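Note: the `fallback` classmethod above allocates the object with `cls.__new__(cls)` and then runs the base `Model` initializer directly on that instance, so the new sharded `__init__` is bypassed entirely. A minimal, self-contained sketch of that construction pattern (the `Base`/`Child` names are hypothetical, purely illustrative):

class Base:
    def __init__(self, name):
        self.name = name


class Child(Base):
    def __init__(self, name):
        # Normal path: the expensive distributed/sharded setup would live here.
        super().__init__(name)

    @classmethod
    def fallback(cls, name):
        # Allocate the instance without running Child.__init__ ...
        self = cls.__new__(cls)
        # ... then initialize only the base class on the fresh instance.
        super().__init__(self, name)
        return self


obj = Child.fallback("example")
assert obj.name == "example"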
@@ -1,115 +0,0 @@
import torch
import torch.distributed

from typing import List, Optional, Tuple

from transformers import (
    AutoTokenizer,
    AutoConfig,
)

from text_generation_server.models import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


class T5Sharded(Seq2SeqLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )
        config.quantize = quantize
        config.speculator = speculator

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        tokenizer.bos_token_id = config.decoder_start_token_id

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(
            filenames,
            device=device,
            dtype=dtype,
            process_group=self.process_group,
            aliases={
                "shared.weight": [
                    "encoder.embed_tokens.weight",
                    "decoder.embed_tokens.weight",
                ]
            },
        )

        model = T5ForConditionalGeneration(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(Seq2SeqLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional,
        encoder_last_hidden_state: Optional,
        past_key_values: Optional = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )

        return (
            outputs.logits,
            speculative_logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )
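With the dedicated `T5Sharded` class deleted above, sharded T5 loading presumably goes through the generic `Seq2SeqLM` constructor introduced in the earlier hunk. A minimal sketch of what such a call site could look like, assuming the same import paths the deleted file used and a distributed environment already set up by the server; the model id is illustrative and the `aliases` mapping is copied verbatim from the deleted code:

import torch

from text_generation_server.models import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
    T5ForConditionalGeneration,
)

# Hypothetical replacement for T5Sharded: the generic constructor shards the
# weights and ties the shared embedding through the `aliases` argument.
model = Seq2SeqLM(
    model_id="google/flan-t5-base",  # illustrative model id
    model_class=T5ForConditionalGeneration,
    dtype=torch.float16,
    aliases={
        "shared.weight": [
            "encoder.embed_tokens.weight",
            "decoder.embed_tokens.weight",
        ]
    },
)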
@@ -9,10 +9,11 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.models.flash_mistral import (
    BaseFlashMistral,
from text_generation_server.models.flash_causal_lm import (
    FlashCausalLMBatch,
    FlashCausalLM,
)
from transformers import AutoProcessor

tracer = trace.get_tracer(__name__)

@@ -239,10 +240,35 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
        return batch


class VlmCausalLM(BaseFlashMistral):
class VlmCausalLM(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        *,
        processor_class=AutoProcessor,
        processor_kwargs=None,
        batch_class=VlmCausalLMBatch,
        revision,
        trust_remote_code: bool,
        **kwargs,
    ):
        if processor_kwargs is None:
            processor_kwargs = {}
        self.processor = processor_class.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **processor_kwargs,
        )
        self.batch_class = batch_class
        super().__init__(model_id=model_id, **kwargs)

    @property
    def batch_type(self) -> Type[VlmCausalLMBatch]:
        return VlmCausalLMBatch
        return self.batch_class

    def max_past(self) -> Optional[int]:
        return getattr(self.model.text_model, "max_past", None)

    def forward(
        self,
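Since `VlmCausalLM` now extends `FlashCausalLM` and takes its processor and batch classes as parameters, a concrete vision-language model becomes a thin call site rather than a `BaseFlashMistral` subclass. A minimal sketch under that assumption; the module path and model id are illustrative, and any extra arguments required by `FlashCausalLM.__init__` are forwarded through `**kwargs`:

from text_generation_server.models.vlm_causal_lm import (
    VlmCausalLM,
    VlmCausalLMBatch,
)

# Hypothetical call site: everything model-specific is a keyword argument,
# while the heavy lifting stays in FlashCausalLM.__init__.
model = VlmCausalLM(
    model_id="llava-hf/llava-v1.6-mistral-7b-hf",  # illustrative model id
    revision=None,
    trust_remote_code=False,
    batch_class=VlmCausalLMBatch,
    # processor_class defaults to AutoProcessor; pass processor_kwargs or a
    # custom processor class for models that need one.
)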
@@ -1,6 +1,8 @@
import subprocess
import argparse
import ast
import json
import os

TEMPLATE = """
# Supported Models and Hardware
@@ -122,6 +124,53 @@ def check_supported_models(check: bool):
        f.write(final_doc)


def get_openapi_schema():
    try:
        output = subprocess.check_output(["text-generation-router", "print-schema"])
        return json.loads(output)
    except subprocess.CalledProcessError as e:
        print(f"Error running text-generation-router print-schema: {e}")
        raise SystemExit(1)
    except json.JSONDecodeError:
        print("Error: Invalid JSON received from text-generation-router print-schema")
        raise SystemExit(1)


def check_openapi(check: bool):
    new_openapi_data = get_openapi_schema()
    filename = "docs/openapi.json"
    tmp_filename = "openapi_tmp.json"

    with open(tmp_filename, "w") as f:
        json.dump(new_openapi_data, f, indent=2)

    if check:
        diff = subprocess.run(
            [
                "diff",
                # allow for trailing whitespace since it's not significant
                # and the precommit hook will remove it
                "--ignore-trailing-space",
                tmp_filename,
                filename,
            ],
            capture_output=True,
        ).stdout.decode()
        os.remove(tmp_filename)

        if diff:
            print(diff)
            raise Exception(
                "OpenAPI documentation is not up-to-date, run `python update_doc.py` in order to update it"
            )

        return True
    else:
        os.rename(tmp_filename, filename)
        print("OpenAPI documentation updated.")
        return True


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--check", action="store_true")
@@ -130,6 +179,7 @@ def main():

    check_cli(args.check)
    check_supported_models(args.check)
    check_openapi(args.check)


if __name__ == "__main__":
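For completeness, a hedged sketch of exercising the new OpenAPI check locally; it assumes this script is importable as `update_doc` and that a built `text-generation-router` binary is on PATH so `print-schema` succeeds:

# Hypothetical local invocation of the new OpenAPI check.
from update_doc import check_openapi

check_openapi(check=True)   # raises if docs/openapi.json is stale
check_openapi(check=False)  # regenerates docs/openapi.json in place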