feat(server): Rework model loading (#344)

# What does this PR do?

Reworked the loading logic. Idea is to use cleaner loading code:

- Remove need for `no_init_weights`
- Remove all weird `bnb_linear` and `load_weights` and
`post_load_weights`.

New code layout:

- New class `Weights` in charge of handling loading the weights from
multiple files into appropiate tensors (potentially sharded)
- TP layers now are "shells", they contain the code to know what kind of
sharding we need + eventual `all_reduce`. They do not inherit from
linear, but they contain some kind of Linear instead
- the contained linear can be either FastLinear, BnbLinear or GPTq
Linear next.
- All modeling code is explictly made for sharding, process group is
just no-ops for non sharded code (removes a lot of test cases)

![Screenshot from 2023-05-19
23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f)

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.taildb5d.ts.net>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
This commit is contained in:
Nicolas Patry 2023-06-08 14:51:52 +02:00 committed by GitHub
parent 19c41824cb
commit abd58ff82c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
43 changed files with 6806 additions and 2793 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.idea
target
router/tokenizer.json
*__pycache__*

View File

@ -2,6 +2,8 @@
FROM lukemathwalker/cargo-chef:latest-rust-1.69 AS chef
WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
FROM chef as planner
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
@ -98,14 +100,14 @@ COPY server/Makefile-flash-att Makefile
RUN make build-flash-attention
# Build Transformers CUDA kernels
FROM kernel-builder as transformers-builder
FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-transformers Makefile
COPY server/custom_kernels/ .
# Build specific version of transformers
RUN BUILD_EXTENSIONS="True" make build-transformers
RUN python setup.py build
# Text Generation Inference base image
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base
@ -136,11 +138,10 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from transformers builder
COPY --from=transformers-builder /usr/src/transformers /usr/src/transformers
COPY --from=transformers-builder /usr/src/transformers/build/lib.linux-x86_64-cpython-39/transformers /usr/src/transformers/src/transformers
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
# Install transformers dependencies
RUN cd /usr/src/transformers && pip install -e . --no-cache-dir && pip install einops --no-cache-dir
# Install flash-attention dependencies
RUN pip install einops --no-cache-dir
# Install server
COPY proto proto

View File

@ -1,6 +1,9 @@
install-server:
cd server && make install
install-custom-kernels:
if [ "$$BUILD_EXTENSIONS" == "True" ]; then cd server/custom_kernels && python setup.py install; else echo "Custom kernels are disabled, you need set to BUILD_EXTENSION environment variable to 'True' in order to build them. (Please read the docs, kernels might not work on all hardware)"; fi
install-integration-tests:
cd integration-tests && pip install -r requirements.txt
cd clients/python && pip install .
@ -14,7 +17,7 @@ install-launcher:
install-benchmark:
cd benchmark && cargo install --path .
install: install-server install-router install-launcher
install: install-server install-router install-launcher install-custom-kernels
server-dev:
cd server && make run-dev

View File

@ -209,6 +209,7 @@ def launcher(event_loop):
num_shard: Optional[int] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
use_flash_attention: bool = True,
):
port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000)
@ -240,6 +241,9 @@ def launcher(event_loop):
env = os.environ
env["LOG_LEVEL"] = "info,text_generation_router=debug"
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
with subprocess.Popen(
args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
) as process:
@ -254,12 +258,16 @@ def launcher(event_loop):
process.stdout.close()
process.stderr.close()
if not use_flash_attention:
del env["USE_FLASH_ATTENTION"]
@contextlib.contextmanager
def docker_launcher(
model_id: str,
num_shard: Optional[int] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
use_flash_attention: bool = True,
):
port = random.randint(8000, 10_000)
@ -287,6 +295,9 @@ def launcher(event_loop):
gpu_count = num_shard if num_shard is not None else 1
env = {"LOG_LEVEL": "info,text_generation_router=debug"}
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
if HUGGING_FACE_HUB_TOKEN is not None:
env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN
@ -310,6 +321,9 @@ def launcher(event_loop):
yield ContainerLauncherHandle(client, container.name, port)
if not use_flash_attention:
del env["USE_FLASH_ATTENTION"]
try:
container.stop()
container.wait()

View File

@ -11,17 +11,17 @@
},
{
"id": 1459,
"logprob": -5.6289062,
"logprob": -5.6328125,
"text": " print"
},
{
"id": 81,
"logprob": -1.6005859,
"logprob": -1.6035156,
"text": "_"
},
{
"id": 7656,
"logprob": -5.9921875,
"logprob": -5.9882812,
"text": "hello"
}
],
@ -59,19 +59,19 @@
},
{
"id": 10896,
"logprob": -0.3659668,
"logprob": -0.38549805,
"special": false,
"text": " World"
},
{
"id": 657,
"logprob": -0.49804688,
"logprob": -0.5229492,
"special": false,
"text": "\")"
},
{
"id": 203,
"logprob": -0.11279297,
"logprob": -0.10632324,
"special": false,
"text": "\n"
},
@ -113,7 +113,7 @@
},
{
"id": 426,
"logprob": -0.051635742,
"logprob": 0.0,
"special": false,
"text": "name"
},

View File

@ -0,0 +1,113 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1992188,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8984375,
"text": " mood"
},
{
"id": 3063,
"logprob": -4.0976562,
"text": " today"
},
{
"id": 32,
"logprob": -0.14562988,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26733398,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.86279297,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.94921875,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1835938,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.074035645,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.86376953,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.2070312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4365234,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.109375,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -0.93408203,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.8808594,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
}

View File

@ -0,0 +1,454 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
}
]

View File

@ -0,0 +1,163 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.4179688,
"text": " is"
},
{
"id": 247,
"logprob": -2.1542969,
"text": " a"
},
{
"id": 1167,
"logprob": -5.359375,
"text": " mem"
},
{
"id": 70,
"logprob": -0.006038666,
"text": "e"
},
{
"id": 13,
"logprob": -7.328125,
"text": ","
},
{
"id": 285,
"logprob": -0.3173828,
"text": " and"
},
{
"id": 752,
"logprob": -2.0625,
"text": " what"
},
{
"id": 434,
"logprob": -5.7734375,
"text": "'s"
},
{
"id": 253,
"logprob": -0.74072266,
"text": " the"
},
{
"id": 2892,
"logprob": -6.5898438,
"text": " history"
},
{
"id": 3212,
"logprob": -2.2949219,
"text": " behind"
},
{
"id": 436,
"logprob": -11.40625,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1113281,
"text": " word"
},
{
"id": 32,
"logprob": -0.008056641,
"text": "?"
},
{
"id": 0,
"logprob": -2.3300781,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.28125,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.5878906,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.5449219,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.05038452,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.002292633,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -1.3828278e-05,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.0010242462,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.090270996,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12719727,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.016571045,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.43432617,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
}

View File

@ -0,0 +1,654 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.4179688,
"text": " is"
},
{
"id": 247,
"logprob": -2.1542969,
"text": " a"
},
{
"id": 1167,
"logprob": -5.359375,
"text": " mem"
},
{
"id": 70,
"logprob": -0.006038666,
"text": "e"
},
{
"id": 13,
"logprob": -7.328125,
"text": ","
},
{
"id": 285,
"logprob": -0.3173828,
"text": " and"
},
{
"id": 752,
"logprob": -2.0625,
"text": " what"
},
{
"id": 434,
"logprob": -5.7734375,
"text": "'s"
},
{
"id": 253,
"logprob": -0.74072266,
"text": " the"
},
{
"id": 2892,
"logprob": -6.5898438,
"text": " history"
},
{
"id": 3212,
"logprob": -2.2949219,
"text": " behind"
},
{
"id": 436,
"logprob": -11.40625,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1113281,
"text": " word"
},
{
"id": 32,
"logprob": -0.008056641,
"text": "?"
},
{
"id": 0,
"logprob": -2.3300781,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.28125,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.5878906,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.5498047,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.04815674,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.002313614,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -1.2636185e-05,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.0010147095,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.0859375,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12609863,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.016601562,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.38256836,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.421875,
"text": " is"
},
{
"id": 247,
"logprob": -2.1640625,
"text": " a"
},
{
"id": 1167,
"logprob": -5.40625,
"text": " mem"
},
{
"id": 70,
"logprob": -0.005420685,
"text": "e"
},
{
"id": 13,
"logprob": -7.2226562,
"text": ","
},
{
"id": 285,
"logprob": -0.26879883,
"text": " and"
},
{
"id": 752,
"logprob": -2.1992188,
"text": " what"
},
{
"id": 434,
"logprob": -5.46875,
"text": "'s"
},
{
"id": 253,
"logprob": -0.8017578,
"text": " the"
},
{
"id": 2892,
"logprob": -6.6796875,
"text": " history"
},
{
"id": 3212,
"logprob": -2.1972656,
"text": " behind"
},
{
"id": 436,
"logprob": -11.4453125,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1933594,
"text": " word"
},
{
"id": 32,
"logprob": -0.007858276,
"text": "?"
},
{
"id": 0,
"logprob": -2.328125,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.21875,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.6201172,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.546875,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.051879883,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.0020179749,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -9.059906e-06,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.00096797943,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.07940674,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12182617,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.017227173,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.44482422,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.421875,
"text": " is"
},
{
"id": 247,
"logprob": -2.1640625,
"text": " a"
},
{
"id": 1167,
"logprob": -5.40625,
"text": " mem"
},
{
"id": 70,
"logprob": -0.005420685,
"text": "e"
},
{
"id": 13,
"logprob": -7.2226562,
"text": ","
},
{
"id": 285,
"logprob": -0.26879883,
"text": " and"
},
{
"id": 752,
"logprob": -2.1992188,
"text": " what"
},
{
"id": 434,
"logprob": -5.46875,
"text": "'s"
},
{
"id": 253,
"logprob": -0.8017578,
"text": " the"
},
{
"id": 2892,
"logprob": -6.6796875,
"text": " history"
},
{
"id": 3212,
"logprob": -2.1972656,
"text": " behind"
},
{
"id": 436,
"logprob": -11.4453125,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1933594,
"text": " word"
},
{
"id": 32,
"logprob": -0.007858276,
"text": "?"
},
{
"id": 0,
"logprob": -2.328125,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.21875,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.6201172,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.546875,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.051879883,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.0020179749,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -9.059906e-06,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.00096797943,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.07940674,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12182617,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.017227173,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.44482422,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.421875,
"text": " is"
},
{
"id": 247,
"logprob": -2.1640625,
"text": " a"
},
{
"id": 1167,
"logprob": -5.40625,
"text": " mem"
},
{
"id": 70,
"logprob": -0.005420685,
"text": "e"
},
{
"id": 13,
"logprob": -7.2226562,
"text": ","
},
{
"id": 285,
"logprob": -0.26879883,
"text": " and"
},
{
"id": 752,
"logprob": -2.1992188,
"text": " what"
},
{
"id": 434,
"logprob": -5.46875,
"text": "'s"
},
{
"id": 253,
"logprob": -0.8017578,
"text": " the"
},
{
"id": 2892,
"logprob": -6.6796875,
"text": " history"
},
{
"id": 3212,
"logprob": -2.1972656,
"text": " behind"
},
{
"id": 436,
"logprob": -11.4453125,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1933594,
"text": " word"
},
{
"id": 32,
"logprob": -0.007858276,
"text": "?"
},
{
"id": 0,
"logprob": -2.328125,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.21875,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.6201172,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.546875,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.051879883,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.0020179749,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -1.04904175e-05,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.0009560585,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.08557129,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12084961,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.01737976,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.4025879,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
}
]

View File

@ -37,8 +37,8 @@ async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
generated_texts = [r.generated_text for r in responses]
assert len(generated_texts) == 4
assert generated_texts, all(
assert all(
[text == generated_texts[0] for text in generated_texts]
)
), generated_texts
assert responses == response_snapshot

View File

@ -0,0 +1,48 @@
import pytest
@pytest.fixture(scope="module")
def neox_handle(launcher):
with launcher(
"stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False
) as handle:
yield handle
@pytest.fixture(scope="module")
async def neox(neox_handle):
await neox_handle.health(300)
return neox_handle.client
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox(neox, response_snapshot):
response = await neox.generate(
"<|USER|>What's your mood today?<|ASSISTANT|>",
max_new_tokens=10,
decoder_input_details=True,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox_load(neox, generate_load, response_snapshot):
responses = await generate_load(
neox,
"<|USER|>What's your mood today?<|ASSISTANT|>",
max_new_tokens=10,
n=4,
)
generated_texts = [r.generated_text for r in responses]
assert len(generated_texts) == 4
assert generated_texts, all(
[text == generated_texts[0] for text in generated_texts]
)
assert responses == response_snapshot

View File

@ -0,0 +1,44 @@
import pytest
@pytest.fixture(scope="module")
def neox_sharded_handle(launcher):
with launcher(
"OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False
) as handle:
yield handle
@pytest.fixture(scope="module")
async def neox_sharded(neox_sharded_handle):
await neox_sharded_handle.health(300)
return neox_sharded_handle.client
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox(neox_sharded, response_snapshot):
response = await neox_sharded.generate(
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
max_new_tokens=10,
decoder_input_details=True,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox_load(neox_sharded, generate_load, response_snapshot):
responses = await generate_load(
neox_sharded,
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
max_new_tokens=10,
n=4,
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@ -1,4 +1,5 @@
[pytest]
addopts = --snapshot-warn-unused
asyncio_mode = auto
markers =
private: marks tests as requiring an admin hf token (deselect with '-m "not private"')

View File

@ -1,4 +1,3 @@
include Makefile-transformers
include Makefile-flash-att
unit-tests:
@ -17,7 +16,7 @@ install-torch:
# Install specific version of torch
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
install: gen-server install-torch install-transformers
install: gen-server install-torch
pip install pip --upgrade
pip install -r requirements.txt
pip install -e ".[bnb, accelerate]"

View File

@ -1,13 +0,0 @@
transformers_commit := 69009822aa7897ffab97afb814e38126b83f639e
transformers:
# Clone fork of transformers with custom CUDA kernels and sharding logic
pip install --upgrade setuptools
git clone https://github.com/OlivierDehaene/transformers.git
build-transformers: transformers
cd transformers && git fetch && git checkout $(transformers_commit) && python setup.py build
install-transformers: build-transformers
pip uninstall transformers -y || true
cd transformers && python setup.py install

View File

@ -0,0 +1,250 @@
#include <ATen/Dispatch.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <torch/torch.h>
#include <vector>
#include <optional>
/**
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
**/
// Available in pytorch main
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
/*
* Forward passes
*/
/**
* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype
**/
template<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>
__global__ void forward_masked_softmax_kernel(
const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]
const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]
torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]
const int64_t effective_kv_length,
const dim3 blockDim,
const int64_t rows_per_block,
const int64_t kv_length,
const int64_t batch_size
) {
const auto row_id = threadIdx.x / effective_kv_length;
const auto effective_kv_length_id = threadIdx.x % effective_kv_length;
const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;
auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;
kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_;
const auto kv_length_end = kv_length_end_;
const auto batch_id = blockIdx.x * rows_per_block + row_id;
// We need 2 float storage for each row, one for max computation, the other for normalizing exponential
extern __shared__ float temp_storage[];
const auto row_id_mem_offset = row_id * 2;
if (effective_kv_length_id == 0) {
temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();
temp_storage[row_id_mem_offset + 1] = 0;
}
__syncthreads();
// Compute mask and max
if (batch_id < batch_size) {
float thread_max = -std::numeric_limits<float>::infinity();
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
const float candidate = attention_scores[batch_id][kv_length_id];
thread_max = (thread_max < candidate) ? candidate : thread_max;
}
}
if (thread_max != -std::numeric_limits<float>::infinity()) {
// TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);
}
}
__syncthreads();
// Compute exp(elt - max) masked
float exponential[min_kv_length_shard_size_per_thread];
if (batch_id < batch_size) {
float thread_add = 0;
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);
thread_add = thread_add + exponential[kv_length_id - kv_length_start];
} else {
exponential[kv_length_id - kv_length_start] = 0.;
}
}
if (thread_add > 0) {
// TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);
}
}
__syncthreads();
// Compute softmax
if (batch_id < batch_size) {
// If sum of all exponential is 0, we set the softmax values to 0
if (temp_storage[row_id_mem_offset + 1] == 0.) {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = 0.;
}
} else {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);
}
}
}
}
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(
const at::Tensor query,
const at::Tensor key,
const at::Tensor value,
const std::optional<std::vector<at::Tensor>> layer_past,
const at::Tensor attention_mask,
const std::optional<at::Tensor> head_mask,
const float inv_norm_factor,
const int num_heads,
const bool use_cache
) {
auto query_layer = query;
auto key_layer = key;
auto value_layer = value;
if (layer_past) {
const auto past_key = (*layer_past).at(0);
const auto past_value = (*layer_past).at(1);
key_layer = at::cat({past_key, key_layer}, 2);
value_layer = at::cat({past_value, value_layer}, 2);
}
std::optional<std::vector<at::Tensor>> present;
if (use_cache) {
present = {key_layer, value_layer};
} else {
present = {};
}
const auto batch_size = query_layer.size(0);
const auto q_length = query_layer.size(2);
const auto attn_head_size = query_layer.size(3);
const auto batch_size_times_num_heads = batch_size * num_heads;
const auto kv_length = key_layer.size(2);
const auto query_view = query_layer.reshape({batch_size_times_num_heads, q_length, attn_head_size});
auto key_view = key_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}).transpose(1, 2);
auto value_view = value_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size});
auto query_scaled = query_view * inv_norm_factor;
auto attention_scores = at::bmm(query_scaled, key_view);
// Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype`
at::Tensor attention_probs;
if (true) {
// TODO @thomasw21: it's easier to think of attention_scores as 2D tensors
const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length});
const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length});
// Custom kernel
attention_probs = at::empty_like(attention_scores_2d);
// Check that inputs and contiguous + cuda tensors
CHECK_INPUT(attention_scores_2d);
CHECK_INPUT(attention_mask_2d);
// TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out
// DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] {
/*
* Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
* A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
* - SMs: 108
* - TPCs: 56 (What's that?)
* - Memory size: 40 GB
* - L2 Cache size: 40960 KB (shared across all SMs)
* - L1/Shared memory size: 192 KB (shared across all threads within a SM)
* - Max Threads / SM: 2048
* - Max Thread Blocks / SM: 32
*/
/*
* We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block
* with multiple threads as we need to `sync_threads` to run exponential sum.
* We maximise the usage of threads within a single block
*/
// TODO @thomasw21 figure out everything warp related:
// - why do they have to be power of 2
// TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
const auto MAX_THREADS_PER_SM = 1024;
// TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`
const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;
// `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`
const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;
const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;
const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;
const dim3 gridDim(num_blocks); // Number of blocks that run
const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block
const int shared_mem_forward = rows_per_block * 2 * sizeof(float);
// 192 * 2 ** 10
// const auto MAX_L1_MEMORY = 196608;
// const auto MAX_SMs = 108;
// TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation.");
// TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger.");
// TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher.");
forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(
attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
effective_kv_length,
blockDim,
rows_per_block,
kv_length,
batch_size_times_num_heads * q_length
);
});
attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});
} else {
// Pytorch C++ API
auto input_dtype = attention_scores.scalar_type();
if (input_dtype == at::ScalarType::Float) {
attention_scores = attention_scores.to(at::ScalarType::Float);
};
// TODO @thomasw21 Figure out how to get minimum value
auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34);
attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);
}
auto context_layer = attention_probs.bmm(value_view);
// `_merge_heads`
context_layer = context_layer.view({batch_size, num_heads, q_length, attn_head_size});
context_layer = context_layer.permute({0, 2, 1, 3});
context_layer = context_layer.reshape({batch_size, q_length, attn_head_size * num_heads});
return std::make_tuple(context_layer, present, attention_probs);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"forward",
&forward,
"GPT-Neox attention mechanism forward (CUDA)"
);
}

View File

@ -0,0 +1,250 @@
#include <ATen/Dispatch.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <torch/torch.h>
#include <vector>
#include <optional>
/**
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
**/
// Available in pytorch main
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
/*
* Forward passes
*/
/**
* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype
**/
template<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>
__global__ void forward_masked_softmax_kernel(
const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]
const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]
torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]
const int64_t effective_kv_length,
const dim3 blockDim,
const int64_t rows_per_block,
const int64_t kv_length,
const int64_t batch_size
) {
const auto row_id = threadIdx.x / effective_kv_length;
const auto effective_kv_length_id = threadIdx.x % effective_kv_length;
const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;
auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;
kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_;
const auto kv_length_end = kv_length_end_;
const auto batch_id = blockIdx.x * rows_per_block + row_id;
// We need 2 float storage for each row, one for max computation, the other for normalizing exponential
extern __shared__ float temp_storage[];
const auto row_id_mem_offset = row_id * 2;
if (effective_kv_length_id == 0) {
temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();
temp_storage[row_id_mem_offset + 1] = 0;
}
__syncthreads();
// Compute mask and max
if (batch_id < batch_size) {
float thread_max = -std::numeric_limits<float>::infinity();
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
const float candidate = attention_scores[batch_id][kv_length_id];
thread_max = (thread_max < candidate) ? candidate : thread_max;
}
}
if (thread_max != -std::numeric_limits<float>::infinity()) {
// TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);
}
}
__syncthreads();
// Compute exp(elt - max) masked
float exponential[min_kv_length_shard_size_per_thread];
if (batch_id < batch_size) {
float thread_add = 0;
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);
thread_add = thread_add + exponential[kv_length_id - kv_length_start];
} else {
exponential[kv_length_id - kv_length_start] = 0.;
}
}
if (thread_add > 0) {
// TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);
}
}
__syncthreads();
// Compute softmax
if (batch_id < batch_size) {
// If sum of all exponential is 0, we set the softmax values to 0
if (temp_storage[row_id_mem_offset + 1] == 0.) {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = 0.;
}
} else {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);
}
}
}
}
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(
const at::Tensor fused_qkv,
const std::optional<std::vector<at::Tensor>> layer_past,
const at::Tensor alibi,
const at::Tensor attention_mask,
const std::optional<at::Tensor> head_mask,
const float beta,
const float inv_norm_factor,
const int num_heads,
const bool use_cache
) {
const auto batch_size = fused_qkv.size(0);
const auto q_length = fused_qkv.size(1);
const auto three_times_hidden_size = fused_qkv.size(2);
const auto head_dim = three_times_hidden_size / (3 * num_heads);
const auto batch_size_times_num_heads = batch_size * num_heads;
// `split_heads`
const auto fused_qkv_view = fused_qkv.view({batch_size, q_length, num_heads, 3 * head_dim});
const auto tensor_list = fused_qkv_view.split(head_dim, -1);
const auto query_layer = tensor_list[0].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});
auto key_layer = tensor_list[1].permute({0, 2, 3, 1}).reshape({batch_size_times_num_heads, head_dim, q_length});
auto value_layer = tensor_list[2].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});
if (layer_past) {
const auto past_key = (*layer_past).at(0);
const auto past_value = (*layer_past).at(1);
key_layer = at::cat({past_key, key_layer}, 2);
value_layer = at::cat({past_value, value_layer}, 1);
}
std::optional<std::vector<at::Tensor>> present;
if (use_cache) {
present = {key_layer, value_layer};
} else {
present = {};
}
auto attention_scores = alibi.baddbmm(query_layer, key_layer, beta, inv_norm_factor);
// Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype`
at::Tensor attention_probs;
if (true) {
const auto kv_length = key_layer.size(2);
// TODO @thomasw21: it's easier to think of attention_scores as 2D tensors
const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length});
const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length});
// Custom kernel
attention_probs = at::empty_like(attention_scores_2d);
// Check that inputs and contiguous + cuda tensors
CHECK_INPUT(attention_scores_2d);
CHECK_INPUT(attention_mask_2d);
// TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out
// DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] {
/*
* Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
* A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
* - SMs: 108
* - TPCs: 56 (What's that?)
* - Memory size: 40 GB
* - L2 Cache size: 40960 KB (shared across all SMs)
* - L1/Shared memory size: 192 KB (shared across all threads within a SM)
* - Max Threads / SM: 2048
* - Max Thread Blocks / SM: 32
*/
/*
* We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block
* with multiple threads as we need to `sync_threads` to run exponential sum.
* We maximise the usage of threads within a single block
*/
// TODO @thomasw21 figure out everything warp related:
// - why do they have to be power of 2
// TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
const auto MAX_THREADS_PER_SM = 1024;
// TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`
const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;
// `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`
const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;
const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;
const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;
const dim3 gridDim(num_blocks); // Number of blocks that run
const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block
const int shared_mem_forward = rows_per_block * 2 * sizeof(float);
// 192 * 2 ** 10
// const auto MAX_L1_MEMORY = 196608;
// const auto MAX_SMs = 108;
// TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation.");
// TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger.");
// TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher.");
forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(
attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
effective_kv_length,
blockDim,
rows_per_block,
kv_length,
batch_size_times_num_heads * q_length
);
});
attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});
} else {
// Pytorch C++ API
auto input_dtype = attention_scores.scalar_type();
if (input_dtype == at::ScalarType::Float) {
attention_scores = attention_scores.to(at::ScalarType::Float);
};
// TODO @thomasw21 Figure out how to get minimum value
auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34);
attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);
}
auto context_layer = attention_probs.bmm(value_layer);
// `_merge_heads`
context_layer = context_layer.view({batch_size, num_heads, q_length, head_dim});
context_layer = context_layer.permute({0, 2, 1, 3});
context_layer = context_layer.reshape({batch_size, q_length, three_times_hidden_size / 3});
return std::make_tuple(context_layer, present, attention_probs);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"forward",
&forward,
"Bloom attention mechanism forward (CUDA)"
);
}

View File

@ -0,0 +1,19 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name="custom_kernels",
ext_modules=[
CUDAExtension(
name="custom_kernels.fused_bloom_attention_cuda",
sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
extra_compile_args=["-arch=compute_80", "-std=c++17"],
),
CUDAExtension(
name="custom_kernels.fused_attention_cuda",
sources=["custom_kernels/fused_attention_cuda.cu"],
extra_compile_args=["-arch=compute_80", "-std=c++17"],
),
],
cmdclass={"build_ext": BuildExtension},
)

View File

@ -25,7 +25,8 @@ opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "0.13.3"
huggingface-hub = "0.14.0"
huggingface-hub = "^0.14.1"
transformers = "^4.29.2"
[tool.poetry.extras]
accelerate = ["accelerate"]

View File

@ -13,8 +13,8 @@ grpcio-reflection==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
grpcio-status==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
grpcio==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
huggingface-hub==0.14.0 ; python_version >= "3.9" and python_version < "4.0"
idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
idna==3.4 ; python_version >= "3.9" and python_version < "4"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
@ -33,6 +33,7 @@ safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
typing-extensions==4.6.0 ; python_version >= "3.9" and python_version < "4.0"

View File

@ -6,12 +6,17 @@ from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM
from text_generation_server.utils import weight_hub_files, download_weights
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
@pytest.fixture(scope="session")
def default_bloom():
return BLOOM("bigscience/bloom-560m")
model_id = "bigscience/bloom-560m"
revision = "main"
filenames = weight_hub_files(model_id, revision, ".safetensors")
download_weights(filenames, model_id, revision)
return BLOOMSharded(model_id)
@pytest.fixture(scope="session")

View File

@ -1,3 +1,4 @@
import os
import torch
from loguru import logger
@ -8,17 +9,20 @@ from typing import Optional
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPT, OPTSharded
from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
try:
if torch.cuda.is_available():
if (
torch.cuda.is_available()
and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
):
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
@ -30,14 +34,12 @@ try:
f"GPU with CUDA capability {major} {minor} is not supported"
)
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded
from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_neox import FlashNeoXSharded
from text_generation_server.models.flash_llama import (
FlashLlama,
FlashLlamaSharded,
)
from text_generation_server.models.flash_santacoder import (
FlashSantacoder,
FlashSantacoderSharded,
)
@ -52,30 +54,22 @@ except ImportError:
__all__ = [
"Model",
"BLOOM",
"BLOOMSharded",
"CausalLM",
"FlashCausalLM",
"Galactica",
"GalacticaSharded",
"GPTNeoxSharded",
"Seq2SeqLM",
"SantaCoder",
"OPT",
"OPTSharded",
"T5Sharded",
"get_model",
]
if FLASH_ATTENTION:
__all__.append(FlashNeoX)
__all__.append(FlashNeoXSharded)
__all__.append(FlashRW)
__all__.append(FlashRWSharded)
__all__.append(FlashSantacoder)
__all__.append(FlashSantacoderSharded)
__all__.append(FlashLlama)
__all__.append(FlashLlamaSharded)
FLASH_ATT_ERROR_MESSAGE = (
"{} requires Flash Attention CUDA kernels to be installed.\n"
@ -102,36 +96,24 @@ def get_model(
trust_remote_code: bool,
) -> Model:
if "facebook/galactica" in model_id:
if sharded:
return GalacticaSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return Galactica(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
return GalacticaSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
if model_id.startswith("bigcode/"):
if sharded:
if not FLASH_ATTENTION:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
)
if FLASH_ATTENTION:
return FlashSantacoderSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
return santacoder_cls(
return SantaCoder(
model_id,
revision,
quantize=quantize,
@ -144,20 +126,19 @@ def get_model(
model_type = config_dict["model_type"]
if model_type == "gpt_bigcode":
if sharded:
if not FLASH_ATTENTION:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
)
if FLASH_ATTENTION:
return FlashSantacoderSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
return santacoder_cls(
return SantaCoder(
model_id,
revision,
quantize=quantize,
@ -165,33 +146,45 @@ def get_model(
)
if model_type == "bloom":
if sharded:
return BLOOMSharded(
return BLOOMSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
elif model_type == "gpt_neox":
if FLASH_ATTENTION:
return FlashNeoXSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
elif sharded:
return GPTNeoxSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return BLOOM(
return CausalLM(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "gpt_neox":
if sharded:
neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
return neox_cls(
elif model_type == "llama":
if FLASH_ATTENTION:
return FlashLlama(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
else:
neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
return neox_cls(
return CausalLM(
model_id,
revision,
quantize=quantize,
@ -217,7 +210,7 @@ def get_model(
)
else:
if FLASH_ATTENTION and not config_dict.get("alibi", False):
return FlashRW(
return FlashRWSharded(
model_id,
revision,
quantize=quantize,
@ -231,42 +224,12 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == "llama":
if sharded:
if FLASH_ATTENTION:
return FlashLlamaSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama"))
else:
llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
return llama_cls(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
elif model_type == "opt":
return OPTSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
if model_type == "opt":
if sharded:
return OPTSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return OPT(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "t5":
elif model_type == "t5":
if sharded:
return T5Sharded(
model_id,

View File

@ -1,37 +1,26 @@
import torch
import torch.distributed
from typing import List, Optional, Type
from typing import Optional, Type
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoConfig,
PreTrainedTokenizerBase,
)
from transformers.models.bloom.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
class BloomCausalLMBatch(CausalLMBatch):
@classmethod
@ -42,34 +31,12 @@ class BloomCausalLMBatch(CausalLMBatch):
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super(BloomCausalLMBatch, cls).from_pb(
pb=pb, tokenizer=tokenizer, dtype=dtype, device=device
)
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
batch.keys_head_dim_last = False
return batch
class BLOOM(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
super(BLOOM, self).__init__(
model_id=model_id,
revision=revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch
class BLOOMSharded(BLOOM):
class BLOOMSharded(CausalLM):
def __init__(
self,
model_id: str,
@ -101,25 +68,16 @@ class BLOOMSharded(BLOOM):
trust_remote_code=trust_remote_code,
)
config.pad_token_id = 3
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
model = BloomForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
@ -131,132 +89,9 @@ class BLOOMSharded(BLOOM):
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
if name.startswith("transformer.") or name.startswith("lm_head."):
full_name = name
else:
full_name = f"transformer.{name}"
module_name, param_name = full_name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_tensor = parameters[full_name]
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif (
isinstance(module, TensorParallelEmbedding)
or name == "lm_head.weight"
):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
tensor = slice_[:]
if current_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "word_embeddings.weight":
model.lm_head._parameters["weight"] = tensor
@property
def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
@ -269,9 +104,5 @@ class BLOOMSharded(BLOOM):
use_cache=True,
)
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
logits = outputs.logits
return logits, outputs.past_key_values

View File

@ -0,0 +1,912 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BLOOM model."""
import math
import os
import warnings
from typing import Optional, Tuple, Union
import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import LayerNorm
from torch.nn import functional as F
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
)
from transformers import BloomConfig, PreTrainedModel
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelHead,
)
CUSTOM_KERNELS_ENABLED = False
if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
try:
from custom_kernels import fused_bloom_attention_cuda
CUSTOM_KERNELS_ENABLED = True
except ImportError:
pass
_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
_CONFIG_FOR_DOC = "BloomConfig"
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"bigscience/bigscience-small-testing",
"bigscience/bloom-560m",
"bigscience/bloom-1b1",
"bigscience/bloom-1b7",
"bigscience/bloom-3b",
"bigscience/bloom-7b1",
"bigscience/bloom",
]
def _make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
"""
Make causal mask used for self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.ones(
(target_length, target_length + past_key_values_length),
dtype=torch.bool,
device=device,
)
mask = mask.triu(1 + past_key_values_length)
expanded_mask = mask.unsqueeze(0).expand(
batch_size, target_length, target_length + past_key_values_length
)
return expanded_mask
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
"""
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
"""
batch_size, src_length = mask.shape
tgt_length = tgt_length if tgt_length is not None else src_length
expanded_mask = ~(mask[:, None, :].to(torch.bool))
return expanded_mask.expand(batch_size, tgt_length, src_length)
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
`softmax(l+a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
Args:
Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
attention_mask (`torch.Tensor`):
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
num_heads (`int`, *required*):
number of heads
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype of the output tensor
"""
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32,
)
powers = torch.arange(
1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32
)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(
1,
1 + 2 * num_remaining_heads,
2,
device=attention_mask.device,
dtype=torch.int32,
)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
# Note: alibi will added to the attention bias that will be applied to the query, key product of attention
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
return alibi
# @torch.jit.script
def dropout_add(
x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool
) -> torch.Tensor:
"""
Dropout add function
Args:
x (`torch.tensor`, *required*):
input tensor
residual (`torch.tensor`, *required*):
esidual tensor
prob (`float`, *required*):
dropout probability
training (`bool`, *required*):
training mode
"""
out = F.dropout(x, p=prob, training=training)
out = residual + out
return out
# @torch.jit.script # this is shit for unknow reasons.
def _split_heads(
fused_qkv: torch.Tensor, num_heads: int, head_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
storage as `fused_qkv`
Args:
fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
Returns:
query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
value: [batch_size, seq_length, num_heads, head_dim]
"""
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1)
query_layer = query_layer.transpose(1, 2).reshape(
batch_size * num_heads, seq_length, head_dim
)
key_layer = key_layer.permute(0, 2, 3, 1).reshape(
batch_size * num_heads, head_dim, seq_length
)
value_layer = value_layer.transpose(1, 2).reshape(
batch_size * num_heads, seq_length, head_dim
)
return query_layer, key_layer, value_layer
# @torch.jit.script
def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
"""
Merge heads together over the last dimenstion
Args:
x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
Returns:
torch.tensor: [batch_size, seq_length, num_heads * head_dim]
"""
# What we want to achieve is:
# batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
batch_size_and_num_heads, seq_length, _ = x.shape
batch_size = batch_size_and_num_heads // num_heads
# First view to decompose the batch size
# batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
x = x.view(batch_size, num_heads, seq_length, head_dim)
# batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
x = x.permute(0, 2, 1, 3)
# batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
return x.reshape(batch_size, seq_length, num_heads * head_dim)
class BloomAttention(nn.Module):
def __init__(self, prefix, config: BloomConfig, weights):
super().__init__()
self.pretraining_tp = config.pretraining_tp
self.slow_but_exact = config.slow_but_exact
self.process_group = weights.process_group
self.hidden_size = config.hidden_size
self.num_heads = config.n_head
self.head_dim = self.hidden_size // self.num_heads
self.split_size = self.hidden_size
self.hidden_dropout = config.hidden_dropout
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
f" {self.num_heads})."
)
# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.beta = 1.0
process_group = weights.process_group
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear.load(
config=config,
prefix=f"{prefix}.query_key_value",
weights=weights,
bias=True,
)
self.dense = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
self.attention_dropout = nn.Dropout(config.attention_dropout)
@staticmethod
def compute_attention(
fused_qkv: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]],
alibi: torch.Tensor,
attention_mask: torch.Tensor,
head_mask: Optional[torch.Tensor],
beta: float,
inv_norm_factor: float,
num_heads: int,
use_cache: bool,
):
batch_size, q_length, three_times_hidden_size = fused_qkv.shape
head_dim = three_times_hidden_size // (3 * num_heads)
batch_size * num_heads
### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that?
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = _split_heads(
fused_qkv, num_heads=num_heads, head_dim=head_dim
)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, head_dim, kv_length]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
past_key = past_key.view(-1, *past_key.shape[-2:])
key_layer = torch.cat((past_key, key_layer), dim=2)
past_value = past_value.view(-1, *past_value.shape[-2:])
value_layer = torch.cat((past_value, value_layer), dim=1)
_, _, kv_length = key_layer.shape
if use_cache is True:
present = (key_layer, value_layer)
else:
present = None
###
# [batch_size * num_heads, q_length, kv_length]
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
attention_scores = alibi.baddbmm(
batch1=query_layer,
batch2=key_layer,
beta=beta,
alpha=inv_norm_factor,
)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16:
attention_scores = attention_scores.to(torch.float)
# torch.finfo not supported by torch.jit, we temporarily remplace with `-1e34`
attn_weights = attention_scores.masked_fill_(
attention_mask, torch.finfo(attention_scores.dtype).min
)
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
input_dtype
)
# # [batch_size, num_heads, q_length, kv_length]
# attention_probs = self.attention_dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(attention_probs, value_layer, out=query_layer)
# change view [batch_size, num_heads, q_length, head_dim]
context_layer = _merge_heads(
context_layer, num_heads=num_heads, head_dim=head_dim
)
return context_layer, present, attention_probs
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(
hidden_states
) # [batch_size, seq_length, 3 x hidden_size]
batch_size, q_length, _ = fused_qkv.shape
if layer_past is not None:
past_key, past_value = layer_past
layer_past = (
past_key.view(-1, *past_key.shape[-2:]),
past_value.view(-1, *past_value.shape[-2:]),
)
if CUSTOM_KERNELS_ENABLED:
assert self.training is False, "Only foward pass was implemented"
assert (
attention_mask.shape[-1] < 4096
), "Custom kernel support only up to 4096 tokens"
(
context_layer,
present,
attention_probs,
) = fused_bloom_attention_cuda.forward(
fused_qkv,
layer_past,
alibi,
attention_mask,
head_mask,
self.beta,
self.inv_norm_factor,
self.num_heads,
use_cache,
)
else:
context_layer, present, attention_probs = self.compute_attention(
fused_qkv=fused_qkv,
layer_past=layer_past,
alibi=alibi,
attention_mask=attention_mask,
head_mask=head_mask,
beta=self.beta,
inv_norm_factor=self.inv_norm_factor,
num_heads=self.num_heads,
use_cache=use_cache,
)
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
# output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
output_tensor += residual
outputs = (output_tensor, present)
if output_attentions:
outputs += (attention_probs,)
return outputs
class BloomMLP(nn.Module):
def __init__(self, prefix, config: BloomConfig, weights):
super().__init__()
self.pretraining_tp = config.pretraining_tp
self.slow_but_exact = config.slow_but_exact
self.dense_h_to_4h = TensorParallelColumnLinear.load(
config=config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
)
self.dense_4h_to_h = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
)
self.gelu_impl = torch.nn.GELU(approximate="tanh")
self.hidden_dropout = config.hidden_dropout
def forward(
self, hidden_states: torch.Tensor, residual: torch.Tensor
) -> torch.Tensor:
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
if self.pretraining_tp > 1 and self.slow_but_exact:
intermediate_output = torch.zeros_like(residual)
slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
for i in range(self.pretraining_tp):
intermediate_output = intermediate_output + F.linear(
hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense_4h_to_h.weight[
:, int(i * slices) : int((i + 1) * slices)
],
)
else:
intermediate_output = self.dense_4h_to_h(hidden_states)
# output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
intermediate_output += residual
return intermediate_output
class BloomBlock(nn.Module):
def __init__(self, layer_id: int, config: BloomConfig, weights):
super().__init__()
prefix = f"h.{layer_id}"
self.input_layernorm = LayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.num_heads = config.n_head
self.self_attention = BloomAttention(
prefix=f"{prefix}.self_attention", config=config, weights=weights
)
self.post_attention_layernorm = LayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.mlp = BloomMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm
)
self.hidden_dropout = config.hidden_dropout
def forward(
self,
hidden_states: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
# hidden_states: [batch_size, seq_length, hidden_size]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Layer norm post the self attention.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self attention.
attn_outputs = self.self_attention(
layernorm_output,
residual,
layer_past=layer_past,
attention_mask=attention_mask,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attention_output = attn_outputs[0]
outputs = attn_outputs[1:]
layernorm_output = self.post_attention_layernorm(attention_output)
# Get residual
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# MLP.
output = self.mlp(layernorm_output, residual)
if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions
class BloomPreTrainedModel(PreTrainedModel):
config_class = BloomConfig
base_model_prefix = "transformer"
_no_split_modules = ["BloomBlock"]
@staticmethod
def _convert_to_standard_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
num_heads, ...]))
"""
batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
num_heads = batch_size_times_num_heads // batch_size
# key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)
@staticmethod
def _convert_to_bloom_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
"""
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
batch_size_times_num_heads = batch_size * num_heads
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)
class BloomModel(BloomPreTrainedModel):
def __init__(self, config: BloomConfig, weights):
super().__init__(config)
self.embed_dim = config.hidden_size
self.num_heads = config.n_head
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.word_embeddings = TensorParallelEmbedding(
prefix="word_embeddings", weights=weights
)
self.word_embeddings_layernorm = LayerNorm.load(
prefix="word_embeddings_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
# Transformer blocks
self.h = nn.ModuleList(
[
BloomBlock(layer_id=layer_id, config=config, weights=weights)
for layer_id in range(config.num_hidden_layers)
]
)
# Final Layer Norm
self.ln_f = LayerNorm.load(
prefix="ln_f", weights=weights, eps=config.layer_norm_epsilon
)
def _prepare_attn_mask(
self,
attention_mask: torch.Tensor,
input_shape: Tuple[int, int],
past_key_values_length: int,
) -> torch.BoolTensor:
# create causal mask
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape
if src_length > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
device=device,
past_key_values_length=past_key_values_length,
)
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask | combined_attention_mask
)
return combined_attention_mask
def set_input_embeddings(self, new_embeddings: torch.Tensor):
self.word_embeddings = new_embeddings
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# Compute alibi tensor: check build_alibi_tensor documentation
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[-1]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), device=hidden_states.device
)
else:
attention_mask = attention_mask.to(hidden_states.device)
alibi = build_alibi_tensor(attention_mask, self.num_heads)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
if hasattr(self, "tp_rank"):
assert self.num_heads % self.tp_world_size == 0
block_size = self.num_heads // self.tp_world_size
alibi = alibi[
:, self.tp_rank * block_size : (self.tp_rank + 1) * block_size
]
alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past)
causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
else:
alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past)
causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0)
alibi = alibi.to(hidden_states.dtype)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (
outputs[2 if use_cache else 1],
)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, config, weights):
super().__init__(config)
self.transformer = BloomModel(config, weights)
self.lm_head = TensorParallelHead.load(
config,
prefix="word_embeddings",
weights=weights,
)
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
# only last token for input_ids if past is not None
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
# the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)

View File

@ -30,21 +30,23 @@ import flash_attn_cuda
import dropout_layer_norm
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
TensorParallelHead,
)
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
def __init__(self, prefix, weights, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
weight = weights.get_tensor(f"{prefix}.weight")
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps
def forward(self, hidden_states, residual=None):
@ -91,35 +93,35 @@ class LlamaRMSNorm(nn.Module):
class FlashLlamaAttention(torch.nn.Module):
def __init__(
self,
num_heads,
hidden_size,
process_group=None,
prefix: str,
config,
weights,
):
super().__init__()
self.num_heads = num_heads
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
)
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
self.softmax_scale = self.head_size ** (-0.5)
if process_group is None:
self.query_key_value = FastLinear(hidden_size, 3 * hidden_size, bias=False)
self.o_proj = FastLinear(hidden_size, hidden_size, bias=False)
else:
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
3 * hidden_size,
bias=False,
process_group=process_group,
)
self.o_proj = TensorParallelRowLinear(
hidden_size,
hidden_size,
bias=False,
process_group=process_group,
)
self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
def forward(
self,
@ -195,8 +197,9 @@ class FlashLlamaAttention(torch.nn.Module):
class LlamaMLP(nn.Module):
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
@ -207,32 +210,23 @@ class LlamaMLP(nn.Module):
else "none",
)
)
if process_group is None:
# Fuse gate and up proj
self.gate_up_proj = FastLinear(
hidden_size, 2 * intermediate_size, bias=False
)
self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False)
self.intermediate_size = intermediate_size
else:
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear(
hidden_size,
2 * intermediate_size,
bias=False,
process_group=process_group,
)
self.down_proj = TensorParallelRowLinear(
intermediate_size,
hidden_size,
bias=False,
process_group=process_group,
reduce=True,
)
self.intermediate_size = self.down_proj.in_features
self.process_group = process_group
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states)
@ -241,22 +235,22 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
def __init__(
self,
num_heads,
act,
hidden_size,
intermediate_size,
rms_norm_eps,
process_group=None,
):
def __init__(self, layer_id, config, weights):
super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = FlashLlamaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.self_attn = FlashLlamaAttention(num_heads, hidden_size, process_group)
self.mlp = LlamaMLP(act, hidden_size, intermediate_size, process_group)
self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
self.input_layernorm = LlamaRMSNorm(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = LlamaRMSNorm(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
@ -295,54 +289,35 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module):
def __init__(self, config, process_group=None):
super(FlashLlamaModel, self).__init__()
def __init__(self, config, weights):
super().__init__()
self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.embed_tokens = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group
)
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
FlashLlamaLayer(
config.num_attention_heads,
config.hidden_act,
config.hidden_size,
config.intermediate_size,
config.rms_norm_eps,
process_group,
layer_id,
config,
weights,
)
for _ in range(config.num_hidden_layers)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = LlamaRMSNorm(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)
self.gradient_checkpointing = False
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.embed_tokens, TensorParallelEmbedding):
self.embed_tokens.add_null_idx()
for layer in self.layers:
layer: FlashLlamaLayer
layer.self_attn.query_key_value.prepare_weights(quantize)
layer.self_attn.o_proj.prepare_weights(quantize)
layer.mlp.gate_up_proj.prepare_weights(quantize)
layer.mlp.down_proj.prepare_weights(quantize)
def forward(
self,
input_ids,
@ -410,29 +385,15 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__()
self.process_group = process_group
if self.process_group is not None:
self.world_size = self.process_group.size()
else:
self.world_size = 1
self.model = FlashLlamaModel(config, process_group)
if self.model.tp_embeddings:
self.lm_head = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
)
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self, quantize: Optional[str] = None):
self.model.post_load_weights(quantize)
self.lm_head.prepare_weights()
self.model = FlashLlamaModel(config, weights)
self.lm_head = TensorParallelHead.load(
config,
prefix="lm_head",
weights=weights,
)
def forward(
self,
@ -457,12 +418,4 @@ class FlashLlamaForCausalLM(torch.nn.Module):
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
if self.model.tp_embeddings:
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present

View File

@ -31,61 +31,81 @@ from typing import Optional
import flash_attn_cuda
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelHead,
FastLayerNorm,
PositionRotaryEmbedding,
get_linear,
)
def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
if config.use_parallel_residual:
return linear
else:
return TensorParallelRowLinear(linear, process_group=weights.process_group)
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
weight = (
weight.view(
num_heads,
3,
head_size,
hidden_size,
)
.permute(1, 0, 2, 3)
.reshape(-1, hidden_size)
)
bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
linear = get_linear(weight, bias, config.quantize)
if config.use_parallel_residual:
return linear
else:
return TensorParallelColumnLinear(linear)
class FlashNeoxAttention(torch.nn.Module):
def __init__(
self,
num_heads,
hidden_size,
rotary_pct,
rotary_emb_base,
process_group=None,
reduce=True,
):
def __init__(self, config, prefix, weights):
super().__init__()
num_heads = config.num_attention_heads
hidden_size = config.hidden_size
self.num_heads = num_heads
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.num_heads = self.num_heads // weights.process_group.size()
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
)
rotary_ndims = int(self.head_size * rotary_pct)
self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
self.softmax_scale = self.head_size ** (-0.5)
if process_group is None:
self.query_key_value = FastLinear(hidden_size, 3 * hidden_size)
self.dense = FastLinear(hidden_size, hidden_size)
else:
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
3 * hidden_size,
process_group=process_group,
)
self.dense = TensorParallelRowLinear(
hidden_size, hidden_size, process_group=process_group, reduce=reduce
)
def shuffle_qkv_dims(self):
"""Swap dims to avoid an additional permute"""
self.query_key_value.weight = torch.nn.Parameter(
self.query_key_value.weight.view(
self.num_heads, 3, self.head_size, self.hidden_size
)
.permute(1, 0, 2, 3)
.reshape(-1, self.hidden_size)
self.query_key_value = load_qkv(
config,
prefix=f"{prefix}.query_key_value",
weights=weights,
num_heads=self.num_heads,
head_size=self.head_size,
hidden_size=self.hidden_size,
)
self.query_key_value.bias = torch.nn.Parameter(
self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
.permute(1, 0, 2)
.reshape(-1)
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
def forward(
@ -162,10 +182,9 @@ class FlashNeoxAttention(torch.nn.Module):
class FlashMLP(nn.Module):
def __init__(
self, act, hidden_size, intermediate_size, process_group=None, reduce=True
):
def __init__(self, config, prefix, weights):
super().__init__()
act = config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
@ -177,22 +196,12 @@ class FlashMLP(nn.Module):
)
)
if process_group is None:
self.dense_h_to_4h = FastLinear(hidden_size, intermediate_size)
self.dense_4h_to_h = FastLinear(intermediate_size, hidden_size)
else:
self.dense_h_to_4h = TensorParallelColumnLinear(
hidden_size,
intermediate_size,
process_group=process_group,
)
self.dense_4h_to_h = TensorParallelRowLinear(
intermediate_size,
hidden_size,
process_group=process_group,
reduce=reduce,
)
self.process_group = process_group
self.dense_h_to_4h = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
)
self.dense_4h_to_h = load_row(
config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
)
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
@ -202,38 +211,28 @@ class FlashMLP(nn.Module):
class FlashNeoXLayer(nn.Module):
def __init__(
self,
num_heads,
act,
hidden_size,
intermediate_size,
rotary_pct,
rotary_emb_base,
layer_norm_eps,
use_parallel_residual,
process_group=None,
):
def __init__(self, layer_id, config, weights):
super().__init__()
self.use_parallel_residual = use_parallel_residual
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
layer_norm_eps = config.layer_norm_eps
prefix = f"gpt_neox.layers.{layer_id}"
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=layer_norm_eps
)
self.post_attention_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=layer_norm_eps,
)
self.attention = FlashNeoxAttention(
num_heads,
hidden_size,
rotary_pct,
rotary_emb_base,
process_group,
reduce=not use_parallel_residual,
config, prefix=f"{prefix}.attention", weights=weights
)
self.mlp = FlashMLP(
act,
hidden_size,
intermediate_size,
process_group,
reduce=not use_parallel_residual,
)
self.process_group = process_group
self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
self.process_group = weights.process_group
def forward(
self,
@ -266,9 +265,7 @@ class FlashNeoXLayer(nn.Module):
mlp_output = self.mlp(ln2_hidden_states)
intermediate = mlp_output + attn_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group)
torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate + hidden_states, None
else:
@ -302,42 +299,24 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__(config)
self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.embed_in = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group
)
else:
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
self.embed_in = TensorParallelEmbedding(
prefix="gpt_neox.embed_in", weights=weights
)
self.layers = nn.ModuleList(
[
FlashNeoXLayer(
config.num_attention_heads,
config.hidden_act,
config.hidden_size,
config.intermediate_size,
config.rotary_pct,
config.rotary_emb_base,
config.layer_norm_eps,
config.use_parallel_residual,
process_group,
)
for _ in range(config.num_hidden_layers)
FlashNeoXLayer(layer_id, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.final_layer_norm = FastLayerNorm(
config.hidden_size, eps=config.layer_norm_eps
self.final_layer_norm = FastLayerNorm.load(
prefix="gpt_neox.final_layer_norm",
weights=weights,
eps=config.layer_norm_eps,
)
self.gradient_checkpointing = False
@ -345,29 +324,6 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
self.head_size = self.layers[0].attention.head_size
self.num_heads = self.layers[0].attention.num_heads
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.embed_in, TensorParallelEmbedding):
self.embed_in.add_null_idx()
for layer in self.layers:
layer: FlashNeoXLayer
layer.attention.shuffle_qkv_dims()
layer.attention.query_key_value.prepare_weights(quantize)
layer.attention.dense.prepare_weights(quantize)
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashGPTNeoXModel, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward(
self,
input_ids,
@ -435,42 +391,13 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__(config)
self.gpt_neox = FlashGPTNeoXModel(config, weights)
self.process_group = process_group
if self.process_group is not None:
self.world_size = self.process_group.size()
else:
self.world_size = 1
self.gpt_neox = FlashGPTNeoXModel(config, process_group)
if self.gpt_neox.tp_embeddings:
self.embed_out = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
)
else:
self.embed_out = FastLinear(
config.hidden_size, config.vocab_size, bias=False
)
def post_load_weights(self, quantize: Optional[str] = None):
self.gpt_neox.post_load_weights(quantize)
self.embed_out.prepare_weights()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
self.embed_out = TensorParallelHead.load(
config, prefix="embed_out", weights=weights
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward(
self,
@ -495,12 +422,4 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.embed_out(hidden_states)
if self.gpt_neox.tp_embeddings:
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present

View File

@ -1,5 +1,3 @@
import os
import torch
import torch.distributed
@ -12,15 +10,31 @@ from typing import Optional
import flash_attn_cuda
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelHead,
FastLayerNorm,
PositionRotaryEmbedding,
get_linear,
)
def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
if config.parallel_attn:
return linear
else:
return TensorParallelRowLinear(linear, process_group=weights.process_group)
class RWConfig(PretrainedConfig):
attribute_map = {
"num_hidden_layers": "n_layer",
@ -85,44 +99,31 @@ class RWConfig(PretrainedConfig):
class FlashRWAttention(torch.nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=None,
reduce=True,
config,
prefix,
weights,
):
super().__init__()
self.num_heads = num_heads
self.num_heads_kv = num_heads_kv
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.num_heads = config.n_head
self.num_heads_kv = config.n_head_kv
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
self.rotary_emb = PositionRotaryEmbedding.static(
dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5)
self.num_heads = self.num_heads // weights.process_group.size()
if process_group is None:
self.query_key_value = FastLinear(
hidden_size,
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
)
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
else:
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
process_group=process_group,
)
self.dense = TensorParallelRowLinear(
hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
)
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.query_key_value",
weights=weights,
bias=config.bias,
)
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
def forward(
self,
@ -212,57 +213,48 @@ class FlashRWAttention(torch.nn.Module):
class FlashRWLargeAttention(torch.nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=None,
reduce=True,
config,
prefix,
weights,
):
super().__init__()
hidden_size = config.hidden_size
num_heads = config.n_head
num_heads_kv = config.n_head_kv
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
self.rotary_emb = PositionRotaryEmbedding.static(
self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5)
self.num_groups = num_heads // (num_heads_kv * 2)
self.num_heads = num_heads // self.num_groups
self.num_heads_kv = num_heads_kv // self.num_groups
process_group = weights.process_group
if process_group is None:
self.query_key_value = FastLinear(
hidden_size,
self.num_groups
* self.head_size
* (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
if process_group.size() > self.num_groups:
raise NotImplementedError(
f"Tensor Parallelism is not implemented for world_size > n groups"
)
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
else:
if process_group.size() > self.num_groups:
raise NotImplementedError(
f"Tensor Parallelism is not implemented for world_size > n groups"
)
if self.num_groups % process_group.size() != 0:
raise NotImplementedError(
f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
)
self.num_groups = self.num_groups // process_group.size()
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
self.num_groups
* self.head_size
* (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
process_group=process_group,
)
self.dense = TensorParallelRowLinear(
hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
)
self.num_groups = self.num_groups // process_group.size()
self.query_key_value = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.query_key_value",
weights=weights,
bias=config.bias,
)
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
def forward(
self,
@ -359,28 +351,16 @@ class FlashRWLargeAttention(torch.nn.Module):
class FlashMLP(nn.Module):
def __init__(self, hidden_size, bias, process_group=None, reduce=True):
def __init__(self, config, prefix, weights):
super().__init__()
self.act = torch.nn.functional.gelu
if process_group is None:
self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias)
self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias)
else:
self.dense_h_to_4h = TensorParallelColumnLinear(
hidden_size,
4 * hidden_size,
bias=bias,
process_group=process_group,
)
self.dense_4h_to_h = TensorParallelRowLinear(
4 * hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
)
self.process_group = process_group
self.dense_h_to_4h = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias
)
self.dense_4h_to_h = load_row(
config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias
)
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
@ -392,38 +372,44 @@ class FlashMLP(nn.Module):
class FlashRWLayer(nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
layer_norm_eps,
parallel_attn,
process_group=None,
layer_id,
config,
weights,
):
super().__init__()
parallel_attn = config.parallel_attn
self.parallel_attn = parallel_attn
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
prefix = f"transformer.h.{layer_id}"
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.self_attention = FlashRWAttention(
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=process_group,
reduce=False,
config,
prefix=f"{prefix}.self_attention",
weights=weights,
)
self.post_attention_layernorm = (
FastLayerNorm(hidden_size, eps=layer_norm_eps)
FastLayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
if not parallel_attn
else None
)
self.mlp = FlashMLP(
hidden_size, bias, process_group=process_group, reduce=False
config,
prefix=f"{prefix}.mlp",
weights=weights,
)
self.process_group = process_group
self.process_group = weights.process_group
def forward(
self,
@ -454,9 +440,7 @@ class FlashRWLayer(nn.Module):
mlp_output = self.mlp(ln_hidden_states)
intermediate = mlp_output + attn_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group)
torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate, residual
else:
@ -483,33 +467,30 @@ class FlashRWLayer(nn.Module):
class FlashRWLargeLayer(nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
layer_norm_eps,
process_group=None,
):
def __init__(self, layer_id, config, weights):
super().__init__()
self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
prefix = f"transformer.h.{layer_id}"
self.ln_attn = FastLayerNorm.load(
prefix=f"{prefix}.ln_attn",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.ln_mlp = FastLayerNorm.load(
prefix=f"{prefix}.ln_mlp",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.self_attention = FlashRWLargeAttention(
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=process_group,
reduce=False,
config,
prefix=f"{prefix}.self_attention",
weights=weights,
)
assert config.parallel_attn, "This version doesn't support non parallel_attn"
self.mlp = FlashMLP(
hidden_size, bias, process_group=process_group, reduce=False
)
self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
self.process_group = process_group
self.process_group = weights.process_group
def forward(
self,
@ -543,9 +524,7 @@ class FlashRWLargeLayer(nn.Module):
intermediate = attn_output + mlp_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group)
torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate, residual
@ -555,37 +534,18 @@ class FlashRWPreTrainedModel(PreTrainedModel):
class FlashRWModel(FlashRWPreTrainedModel):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__(config)
self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.word_embeddings = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group
)
else:
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.word_embeddings = TensorParallelEmbedding(
prefix="transformer.word_embeddings", weights=weights
)
if config.model_type == "RefinedWebModel":
self.h = nn.ModuleList(
[
FlashRWLayer(
config.n_head,
config.n_head_kv,
config.hidden_size,
config.bias,
config.layer_norm_epsilon,
config.parallel_attn,
process_group,
)
for _ in range(config.num_hidden_layers)
FlashRWLayer(layer_id, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = (
@ -596,15 +556,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
elif config.model_type == "RefinedWeb":
self.h = nn.ModuleList(
[
FlashRWLargeLayer(
config.n_head,
config.n_head_kv,
config.hidden_size,
config.bias,
config.layer_norm_epsilon,
process_group,
)
for _ in range(config.num_hidden_layers)
FlashRWLargeLayer(layer_id, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.cache_size = (
@ -617,31 +570,13 @@ class FlashRWModel(FlashRWPreTrainedModel):
f"model_type {config.model_type} is not supported."
)
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.head_size = self.h[0].self_attention.head_size
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.word_embeddings, TensorParallelEmbedding):
self.word_embeddings.add_null_idx()
for layer in self.h:
layer: FlashRWLayer
layer.self_attention.query_key_value.prepare_weights(quantize)
layer.self_attention.dense.prepare_weights(quantize)
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashRWModel, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
self.ln_f = FastLayerNorm.load(
prefix="transformer.ln_f",
weights=weights,
eps=config.layer_norm_epsilon,
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
self.head_size = self.h[0].self_attention.head_size
def forward(
self,
@ -708,40 +643,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
class FlashRWForCausalLM(FlashRWPreTrainedModel):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__(config)
self.process_group = process_group
if self.process_group is not None:
self.world_size = self.process_group.size()
else:
self.world_size = 1
self.transformer = FlashRWModel(config, weights)
self.transformer = FlashRWModel(config, process_group)
if self.transformer.tp_embeddings:
self.lm_head = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
)
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self, quantize: Optional[str] = None):
self.transformer.post_load_weights(quantize)
self.lm_head.prepare_weights()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashRWForCausalLM, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
self.lm_head = TensorParallelHead.load(
config, prefix="lm_head", weights=weights
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward(
self,
@ -766,12 +675,4 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
if self.transformer.tp_embeddings:
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present

View File

@ -8,39 +8,142 @@ from typing import Optional
# Flash attention imports
import flash_attn_cuda
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelHead,
TensorParallelEmbedding,
FastLayerNorm,
get_linear,
)
class FlashMQAttention(torch.nn.Module):
def __init__(
self,
num_heads,
def load_multi_mqa(
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
if any("c_attn" in k for k in weights.routing.keys()):
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
shape = slice_.get_shape()
world_size = weights.process_group.size()
rank = weights.process_group.rank()
if config.transpose:
block_size = (shape[1] - 2 * head_size) // world_size
start = rank * block_size
stop = (rank + 1) * block_size
assert (shape[1] - 2 * head_size) % world_size == 0
q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size :]
weight = torch.cat([q_tensor, kv_tensor], dim=1).T
else:
block_size = (shape[0] - 2 * head_size) // world_size
start = rank * block_size
stop = (rank + 1) * block_size
assert (shape[0] - 2 * head_size) % world_size == 0
q_tensor = slice_[start:stop]
kv_tensor = slice_[-2 * head_size :]
weight = torch.cat([q_tensor, kv_tensor], dim=0)
if bias:
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
shape = slice_.get_shape()
block_size = (shape[0] - 2 * head_size) // world_size
assert (shape[0] - 2 * head_size) % world_size == 0
q_tensor = slice_[start:stop]
start = rank * block_size
stop = (rank + 1) * block_size
q_tensor = slice_[start:stop]
kv_tensor = slice_[-2 * head_size :]
bias = torch.cat([q_tensor, kv_tensor], dim=0)
else:
if config.transpose:
w = [
weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T,
weights.get_tensor(f"{prefix}.kv_attn.weight").T,
]
weight = torch.cat(w, dim=0)
else:
w = [
weights.get_sharded(f"{prefix}.q_attn.weight", dim=0),
weights.get_tensor(f"{prefix}.kv_attn.weight"),
]
weight = torch.cat(w, dim=1)
if bias:
b = [
weights.get_sharded(f"{prefix}.q_attn.bias", dim=0),
weights.get_tensor(f"{prefix}.kv_attn.bias"),
]
bias = torch.cat(b, dim=0)
else:
bias = None
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
assert list(weight.shape) == [
(num_heads + 2) * head_size,
hidden_size,
process_group=None,
):
], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
if bias is not None:
bias = bias.to(dtype=weights.dtype).to(device=weights.device)
assert list(bias.shape) == [
(num_heads + 2) * head_size
], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
def load_col(config, prefix: str, weights, bias: bool):
if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
def load_row(config, prefix: str, weights, bias: bool):
if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return TensorParallelRowLinear(
get_linear(weight, bias, config.quantize), process_group=weights.process_group
)
class FlashMQAttention(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
num_heads = config.num_attention_heads
hidden_size = config.hidden_size
self.num_heads = num_heads
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
assert self.num_heads % weights.process_group.size() == 0
self.num_heads = self.num_heads // weights.process_group.size()
self.softmax_scale = self.head_size ** (-0.5)
if process_group is None:
self.c_attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
self.c_proj = FastLinear(hidden_size, hidden_size)
else:
self.num_heads = self.num_heads // process_group.size()
self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
self.c_proj = TensorParallelRowLinear(
hidden_size,
hidden_size,
process_group=process_group,
)
self.c_attn = load_multi_mqa(
config,
prefix=prefix,
weights=weights,
bias=True,
head_size=self.head_size,
hidden_size=hidden_size,
num_heads=self.num_heads,
)
self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
)
def forward(
self,
@ -121,8 +224,9 @@ class FlashMQAttention(torch.nn.Module):
class MLP(nn.Module):
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.activation_function
self.act = (
ACT2FN[act]
if "gelu" not in act
@ -134,20 +238,12 @@ class MLP(nn.Module):
)
)
if process_group is None:
self.c_fc = FastLinear(hidden_size, intermediate_size)
self.c_proj = FastLinear(intermediate_size, hidden_size)
else:
self.c_fc = TensorParallelColumnLinear(
hidden_size,
intermediate_size,
process_group=process_group,
)
self.c_proj = TensorParallelRowLinear(
intermediate_size,
hidden_size,
process_group=process_group,
)
self.c_fc = load_col(
config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
)
self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
)
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
@ -157,28 +253,24 @@ class MLP(nn.Module):
class Block(nn.Module):
def __init__(
self,
num_heads,
act,
hidden_size,
intermediate_size,
layer_norm_eps,
process_group=None,
):
def __init__(self, layer_id, config, weights):
super().__init__()
self.ln_1 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.ln_2 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
prefix = f"transformer.h.{layer_id}"
self.ln_1 = FastLayerNorm.load(
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
)
self.ln_2 = FastLayerNorm.load(
prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
)
self.attn = FlashMQAttention(
num_heads,
hidden_size,
process_group,
prefix=f"{prefix}.attn",
config=config,
weights=weights,
)
self.mlp = MLP(
act,
hidden_size,
intermediate_size,
process_group,
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
)
def forward(
@ -210,66 +302,39 @@ class Block(nn.Module):
class FlashSantacoderModel(nn.Module):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__()
self.config = config
self.process_group = process_group
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.wte = TensorParallelEmbedding(
config.vocab_size,
config.hidden_size,
reduce=False,
process_group=process_group,
)
self.wpe = TensorParallelEmbedding(
config.max_position_embeddings,
config.hidden_size,
reduce=False,
process_group=process_group,
)
else:
self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.process_group = weights.process_group
self.wte = TensorParallelEmbedding(
prefix="transformer.wte",
weights=weights,
reduce=False,
)
self.wpe = TensorParallelEmbedding(
prefix="transformer.wpe",
weights=weights,
reduce=False,
)
self.h = nn.ModuleList(
[
Block(
config.num_attention_heads,
config.activation_function,
config.hidden_size,
config.n_inner
if config.n_inner is not None
else 4 * config.hidden_size,
config.layer_norm_epsilon,
process_group,
layer_id,
config,
weights,
)
for _ in range(config.num_hidden_layers)
for layer_id in range(config.num_hidden_layers)
]
)
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.ln_f = FastLayerNorm.load(
prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
)
self.head_size = self.h[0].attn.head_size
self.num_heads = self.h[0].attn.num_heads
def post_load_weights(self, quantize: Optional[str] = None):
if self.tp_embeddings:
self.wte.add_null_idx()
self.wpe.add_null_idx()
for layer in self.h:
layer: Block
layer.attn.c_attn.prepare_weights(quantize)
layer.attn.c_proj.prepare_weights(quantize)
layer.mlp.c_fc.prepare_weights(quantize)
layer.mlp.c_proj.prepare_weights(quantize)
def forward(
self,
input_ids,
@ -281,8 +346,7 @@ class FlashSantacoderModel(nn.Module):
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
if self.tp_embeddings:
torch.distributed.all_reduce(hidden_states, group=self.process_group)
torch.distributed.all_reduce(hidden_states, group=self.process_group)
# Prefill
if past_key_values is None:
@ -331,23 +395,12 @@ class FlashSantacoderModel(nn.Module):
class FlashSantacoderForCausalLM(nn.Module):
def __init__(self, config, process_group=None):
def __init__(self, config, weights):
super().__init__()
self.transformer = FlashSantacoderModel(config, process_group)
if self.transformer.tp_embeddings:
self.lm_head = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
)
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self, quantize: Optional[str] = None):
self.transformer.post_load_weights(quantize)
self.lm_head.prepare_weights()
self.transformer = FlashSantacoderModel(config, weights)
self.lm_head = TensorParallelHead.load(
config, prefix="transformer.wte", weights=weights
)
def forward(
self,
@ -372,29 +425,4 @@ class FlashSantacoderForCausalLM(nn.Module):
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
if self.transformer.tp_embeddings:
# Logits are sharded, so we need to gather them
if logits.shape[0] == 1:
# Fast path when batch size is 1
world_logits = logits.new_empty(
(logits.shape[1] * self.transformer.tp_world_size)
)
torch.distributed.all_gather_into_tensor(
world_logits, logits.view(-1), group=self.transformer.process_group
)
world_logits = world_logits.view(1, -1)
else:
# We cannot use all_gather_into_tensor as it only support concatenating on the first dim
world_logits = [
torch.empty_like(logits)
for _ in range(self.transformer.tp_world_size)
]
torch.distributed.all_gather(
world_logits, logits, group=self.transformer.process_group
)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present

View File

@ -0,0 +1,794 @@
# coding=utf-8
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch GPTNeoX model."""
from typing import Optional, Tuple, Union
import os
import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers import GPTNeoXConfig
from loguru import logger
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelHead,
)
CUSTOM_KERNELS_ENABLED = False
if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
try:
from custom_kernels import fused_attention_cuda
CUSTOM_KERNELS_ENABLED = True
except ImportError:
pass
if not CUSTOM_KERNELS_ENABLED:
logger.warning("We're not using custom kernels.")
def make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
"""
Make causal mask used for self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.ones(
(target_length, target_length + past_key_values_length),
dtype=torch.bool,
device=device,
)
mask = mask.triu(1 + past_key_values_length)
expanded_mask = mask.unsqueeze(0).expand(
batch_size, target_length, target_length + past_key_values_length
)
return expanded_mask
def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
"""
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
"""
batch_size, src_length = mask.shape
tgt_length = tgt_length if tgt_length is not None else src_length
expanded_mask = ~(mask[:, None, :].to(torch.bool))
return expanded_mask.expand(batch_size, tgt_length, src_length)
def prepare_attn_mask(
attention_mask: torch.Tensor,
input_shape: Tuple[int, int],
past_key_values_length: int,
) -> torch.BoolTensor:
# create causal mask
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape
if src_length > 1:
combined_attention_mask = make_causal_mask(
input_shape, device=device, past_key_values_length=past_key_values_length
)
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask | combined_attention_mask
)
return combined_attention_mask
class GPTNeoXPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
class GPTNeoXAttention(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct)
max_positions = config.max_position_embeddings
# ??? TODO
# self.register_buffer(
# "bias",
# torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
# 1, 1, max_positions, max_positions
# ),
# )
# self.register_buffer("masked_bias", torch.tensor(-1e9))
self.rotary_emb = RotaryEmbedding(
self.rotary_ndims,
config.max_position_embeddings,
base=config.rotary_emb_base,
)
self.rotary_emb.inv_freq = nn.Parameter(
weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
)
self.inv_norm_factor = 1.0 / torch.sqrt(
torch.tensor(self.head_size, dtype=torch.float32)
).to(torch.get_default_dtype())
assert self.num_attention_heads % weights.process_group.size() == 0
self.num_attention_heads = (
self.num_attention_heads // weights.process_group.size()
)
self.query_key_value = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True
)
self.dense = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
def forward(
self,
hidden_states,
position_ids,
attention_mask,
head_mask=None,
layer_past=None,
use_cache=False,
output_attentions=False,
):
has_layer_past = layer_past is not None
# Compute QKV
# Attention heads [batch, seq_len, hidden_size]
# --> [batch, seq_len, (np * 3 * head_size)]
qkv = self.query_key_value(hidden_states)
# [batch, seq_len, (num_heads * 3 * head_size)]
# --> [batch, seq_len, num_heads, 3 * head_size]
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
qkv = qkv.view(*new_qkv_shape).permute(0, 2, 1, 3)
# [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
query, key, value = qkv.split(self.head_size, -1)
# Compute token offset for rotary embeddings (when decoding)
seq_len = key.shape[-2]
if has_layer_past:
seq_len += layer_past[0].shape[-2]
# Compute rotary embeddings on rotary_ndims
query_rot = query[..., : self.rotary_ndims]
key_rot = key[..., : self.rotary_ndims]
query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len)
query[..., : self.rotary_ndims] = query_rot
key[..., : self.rotary_ndims] = key_rot
if CUSTOM_KERNELS_ENABLED:
attn_output, present, attn_weights = fused_attention_cuda.forward(
query,
key,
value,
layer_past,
attention_mask,
head_mask,
self.inv_norm_factor,
self.num_attention_heads,
use_cache,
)
else:
# Cache QKV values
if has_layer_past:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = (key, value) if use_cache else None
# Compute attention
attn_output, attn_weights = self._attn(
query, key, value, attention_mask, head_mask
)
# Reshape outputs
attn_output = self._merge_heads(
attn_output, self.num_attention_heads, self.head_size
)
attn_output = self.dense(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs
@classmethod
def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
# tensor: [bs, seq_len, hidden_size]
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(new_shape)
# -> [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3)
return tensor
@classmethod
def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
# tensor [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3).contiguous()
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(
tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
)
# -> [bs, seq_len, hidden_size]
return tensor
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
# compute causal mask from causal mask buffer
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)
query = query.view(
batch_size * num_attention_heads, query_length, attn_head_size
)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
attn_scores = torch.zeros(
1,
dtype=query.dtype,
device=key.device,
).expand(batch_size * num_attention_heads, query_length, key_length)
attn_scores = torch.baddbmm(
attn_scores,
query,
key.transpose(1, 2),
beta=1.0,
alpha=self.inv_norm_factor,
)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attn_scores.dtype
if input_dtype in [torch.float16, torch.bfloat16]:
attn_scores = attn_scores.to(torch.float)
attn_scores = torch.where(
attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores
)
attn_scores = attn_scores.view(
batch_size, num_attention_heads, query_length, key_length
)
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
attn_weights = attn_weights.to(value.dtype)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
super().__init__()
self.true_inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2).float().to(device) / dim)
)
self.register_buffer("inv_freq", self.true_inv_freq)
# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
self.cos_cached = None
self.sin_cached = None
@staticmethod
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
@staticmethod
def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):
t = torch.arange(
max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype
)
freqs = torch.einsum("i,j->ij", t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype)
def forward(self, q, k, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if (
seq_len > self.max_seq_len_cached
or self.cos_cached is None
or self.sin_cached is None
):
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
self.cos_cached, self.sin_cached = self._create_cos_sin(
self.true_inv_freq, self.max_seq_len_cached, q.dtype, q.device
)
return rotary_forward(q, k, self.cos_cached, self.sin_cached, position_ids)
@torch.jit.script
def rotary_forward(q, k, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
chunk_size = q.shape[-1] // 2
q1, q2 = q.split(chunk_size, -1)
q_rotated = torch.cat((-q2, q1), dim=-1)
k1, k2 = k.split(chunk_size, -1)
k_rotated = torch.cat((-k2, k1), dim=-1)
q_embed = (q * cos) + (q_rotated * sin)
k_embed = (k * cos) + (k_rotated * sin)
return q_embed, k_embed
class GPTNeoXMLP(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.act = (
ACT2FN[config.hidden_act]
if "gelu_fast" not in config.hidden_act
else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
)
self.dense_h_to_4h = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
)
self.dense_4h_to_h = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
)
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dense_4h_to_h(hidden_states)
return hidden_states
class GPTNeoXLayer(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.input_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.attention = GPTNeoXAttention(
config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights
)
self.mlp = GPTNeoXMLP(
config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights
)
def forward(
self,
hidden_states,
position_ids,
attention_mask=None,
head_mask=None,
use_cache=False,
layer_past=None,
output_attentions=False,
):
attention_layer_outputs = self.attention(
self.input_layernorm(hidden_states),
attention_mask=attention_mask,
position_ids=position_ids,
layer_past=layer_past,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attention_layer_outputs[
0
] # output_attn: attn_output, present, (attn_weights)
outputs = attention_layer_outputs[1:]
if self.use_parallel_residual:
# pseudocode:
# x = x + attn(ln1(x)) + mlp(ln2(x))
mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
hidden_states = mlp_output + attn_output + hidden_states
else:
# pseudocode:
# x = x + attn(ln1(x))
# x = x + mlp(ln2(x))
attn_output = attn_output + hidden_states
mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
hidden_states = mlp_output + attn_output
if use_cache:
outputs = (
hidden_states,
) + outputs # hidden_states, present, (attn_weights)
else:
outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights)
return outputs
class GPTNeoXModel(GPTNeoXPreTrainedModel):
def __init__(self, config, weights):
super().__init__(config)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.embed_in = TensorParallelEmbedding(
prefix="gpt_neox.embed_in", weights=weights
)
self.layers = nn.ModuleList(
[
GPTNeoXLayer(layer_id, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.final_layer_norm = nn.LayerNorm.load(
prefix="gpt_neox.final_layer_norm",
weights=weights,
eps=config.layer_norm_eps,
)
self.tp_world_size = weights.process_group.size()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids=None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * self.config.num_hidden_layers)
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_length, seq_length + past_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_in(input_ids)
hidden_states = inputs_embeds
# Attention mask.
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[-1]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), device=hidden_states.device
)
else:
attention_mask = attention_mask.to(hidden_states.device)
causal_mask = prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
assert self.num_attention_heads % self.tp_world_size == 0
block_size = self.num_attention_heads // self.tp_world_size
causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
presents = () if use_cache else None
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = layer(
hidden_states,
position_ids=position_ids,
attention_mask=causal_mask,
head_mask=head_mask[i],
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
hidden_states = self.final_layer_norm(hidden_states)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [hidden_states, presents, all_hidden_states, all_attentions]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config, weights):
super().__init__(config)
self.gpt_neox = GPTNeoXModel(config, weights)
self.embed_out = TensorParallelHead.load(
config, prefix="embed_out", weights=weights
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
only required when the model is used as a decoder in a Sequence to Sequence model.
Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see
`past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
>>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b")
>>> config.is_decoder = True
>>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
```"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.gpt_neox(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
lm_logits = self.embed_out(hidden_states)
lm_loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# we are doing next-token prediction; shift prediction scores and input ids by one
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutputWithPast(
loss=lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs,
):
input_shape = input_ids.shape
# cut decoder_input_ids if past is used
if past_key_values and past_key_values[0] is not None:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"position_ids": position_ids,
}
)
return model_inputs
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx)
for past_state in layer_past[:2]
)
+ layer_past[2:],
)
return reordered_past

View File

@ -0,0 +1,837 @@
# coding=utf-8
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch OPT model."""
import random
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers import OPTConfig
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelHead,
)
EPS = 1e-5
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full(
(tgt_len, tgt_len),
torch.tensor(torch.finfo(dtype).min, device=device),
device=device,
)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
class OPTLearnedPositionalEmbedding(nn.Module):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, weights):
super().__init__()
self.offset = 2
self.weight = nn.Parameter(
weights.get_tensor("model.decoder.embed_positions.weight")
)
def forward(
self, attention_mask: torch.LongTensor, past_key_values_length: int = 0
):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()
# create positions depending on attention_mask
positions = (
torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
).long() - 1
# cut positions if `past_key_values_length` is > 0
positions = positions[:, past_key_values_length:]
return torch.nn.functional.embedding(positions + self.offset, self.weight)
class OPTAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config,
prefix,
weights,
is_decoder: bool = False,
bias: bool = True,
process_group=None,
):
super().__init__()
embed_dim = config.embed_dim
num_heads = config.num_attention_heads
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = config.dropout
self.head_dim = embed_dim // num_heads
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
process_group = weights.process_group
assert self.num_heads % process_group.size() == 0
self.num_heads = self.num_heads // process_group.size()
self.embed_dim = self.embed_dim // process_group.size()
self.q_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias
)
self.k_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.k_proj", weights=weights, bias=bias
)
self.v_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.v_proj", weights=weights, bias=bias
)
self.out_proj = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.out_proj", weights=weights, bias=bias
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = (
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attention_mask
)
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
if attn_weights.dtype == torch.float16:
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(torch.float16)
else:
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights_reshaped.view(
bsz * self.num_heads, tgt_len, src_len
)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
class OPTDecoderLayer(nn.Module):
def __init__(self, layer_id: int, config: OPTConfig, weights):
super().__init__()
self.process_group = weights.process_group
self.embed_dim = config.hidden_size
prefix = f"model.decoder.layers.{layer_id}"
self.self_attn = OPTAttention(
config,
prefix=f"{prefix}.self_attn",
weights=weights,
is_decoder=True,
bias=config.enable_bias,
)
self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.self_attn_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.self_attn_layer_norm", weights=weights, eps=EPS
)
self.fc1 = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.fc1", weights=weights, bias=config.enable_bias
)
self.fc2 = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.fc2", weights=weights, bias=config.enable_bias
)
self.final_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.final_layer_norm", weights=weights, eps=EPS
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = (residual + hidden_states).view(hidden_states_shape)
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class OPTPreTrainedModel(PreTrainedModel):
config_class = OPTConfig
class OPTDecoder(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
self.embed_tokens = TensorParallelEmbedding(
prefix="model.decoder.embed_tokens", weights=weights
)
self.embed_positions = OPTLearnedPositionalEmbedding(weights)
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = FastLinear.load(
config, prefix="model.decoder.project_out", bias=False
)
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = FastLinear.load(
config, prefix="model.decoder.project_in", bias=False
)
else:
self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm.load(
prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS
)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList(
[
OPTDecoderLayer(layer_id, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
past_key_values_length = (
past_key_values[0][0].shape[2] if past_key_values is not None else 0
)
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
batch_size, mask_seq_length, device=inputs_embeds.device
)
causal_attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
if self.project_in is not None:
inputs_embeds = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class OPTModel(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights):
super().__init__(config)
self.decoder = OPTDecoder(config, weights)
# Initialize weights and apply final processing
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return decoder_outputs
return BaseModelOutputWithPast(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
hidden_states=decoder_outputs.hidden_states,
attentions=decoder_outputs.attentions,
)
class OPTForCausalLM(OPTPreTrainedModel):
def __init__(self, config, weights):
super().__init__(config)
self.model = OPTModel(config, weights)
self.lm_head = TensorParallelHead.load(
config, prefix="model.decoder.embed_tokens", weights=weights
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0]).contiguous()
loss = None
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs,
):
if past_key_values:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx) for past_state in layer_past
),
)
return reordered_past

File diff suppressed because it is too large Load Diff

View File

@ -1,154 +1,25 @@
import torch
import torch.distributed
from accelerate import init_empty_weights
from opentelemetry import trace
from pathlib import Path
from safetensors import safe_open
from transformers import AutoConfig
from transformers.models.llama import LlamaTokenizer
from typing import Optional, List
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
weight_hub_files,
LocalEntryNotFoundError,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashLlama(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
raise NotImplementedError("FlashLlama is only available on GPU")
tokenizer = LlamaTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
# We do not use from_pretrained as we modified the model internal module layout
try:
filenames = weight_files(model_id, revision, ".bin")
# Local files not found
except LocalEntryNotFoundError:
hub_files = weight_hub_files(model_id, revision, ".bin")
filenames = download_weights(hub_files, model_id, revision)
with init_empty_weights():
model = FlashLlamaForCausalLM(config)
self.load_weights(model, filenames, quantize, device, dtype)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
@staticmethod
def load_weights(
model,
filenames: List[Path],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
):
for filename in filenames:
state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items():
value = value.to(device if quantize is None else "cpu").to(dtype)
layer_name = ".".join(key.split(".")[:4])
# Fused qkv
if "q_proj" in key or "k_proj" in key or "v_proj" in key:
final_key = layer_name + ".query_key_value.weight"
# Fused gate and up projs
elif "gate_proj" in key or "up_proj" in key:
final_key = layer_name + ".gate_up_proj.weight"
else:
final_key = key
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "query_key_value" in final_key:
module._parameters[param_name] = value.new_empty(
(value.shape[0] * 3, value.shape[1])
)
# Init gate and up proj
elif "gate_up_proj" in final_key:
module._parameters[param_name] = value.new_empty(
(value.shape[0] * 2, value.shape[1])
)
# Copy to correct slice
if "q_proj" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "k_proj" in key:
module._parameters[param_name][
value.shape[0] : value.shape[0] * 2
] = value
elif "v_proj" in key:
module._parameters[param_name][value.shape[0] * 2 :] = value
elif "gate_proj" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "up_proj" in key:
module._parameters[param_name][value.shape[0] :] = value
else:
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
else:
module._buffers[param_name] = value
del value
torch.cuda.empty_cache()
model.post_load_weights(quantize)
class FlashLlamaSharded(FlashLlama):
def __init__(
self,
model_id: str,
@ -176,24 +47,16 @@ class FlashLlamaSharded(FlashLlama):
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
with init_empty_weights():
model = FlashLlamaForCausalLM(config, process_group=self.process_group)
config.quantize = quantize
model = FlashLlamaForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
@ -201,114 +64,3 @@ class FlashLlamaSharded(FlashLlama):
rank=rank,
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
slice_ = f.get_slice(name)
layer_name = ".".join(name.split(".")[:4])
# Fused qkv
if "q_proj" in name or "k_proj" in name or "v_proj" in name:
final_name = layer_name + ".query_key_value.weight"
# Fused gate and up projs
elif "gate_proj" in name or "up_proj" in name:
final_name = layer_name + ".gate_up_proj.weight"
else:
final_name = name
module_name, param_name = final_name.rsplit(".", 1)
module = model.get_submodule(module_name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "lm_head.weight" and model.model.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
tensor = tensor.contiguous().to(dtype)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "query_key_value" in final_name:
module._parameters[param_name] = tensor.new_empty(
(tensor.shape[0] * 3, tensor.shape[1])
)
# Init gate and up proj
elif "gate_up_proj" in final_name:
module._parameters[param_name] = tensor.new_empty(
(tensor.shape[0] * 2, tensor.shape[1])
)
# Init gate and up proj
if "q_proj" in name:
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "k_proj" in name:
module._parameters[param_name][
tensor.shape[0] : tensor.shape[0] * 2
] = tensor
elif "v_proj" in name:
module._parameters[param_name][
tensor.shape[0] * 2 :
] = tensor
elif "gate_proj" in name:
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "up_proj" in name:
module._parameters[param_name][tensor.shape[0] :] = tensor
else:
if current_parameter_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
torch.cuda.empty_cache()
model.post_load_weights(quantize)

View File

@ -1,45 +1,24 @@
import torch
import torch.distributed
from accelerate import init_empty_weights
from opentelemetry import trace
from safetensors import safe_open
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashNeoX(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
super(FlashNeoX, self).__init__(
FlashGPTNeoXForCausalLM,
model_id,
revision,
quantize,
trust_remote_code=trust_remote_code,
)
class FlashNeoXSharded(FlashNeoX):
class FlashNeoXSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
@ -65,23 +44,16 @@ class FlashNeoXSharded(FlashNeoX):
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = FlashGPTNeoXForCausalLM(config, self.process_group)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
model = FlashGPTNeoXForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
@ -92,79 +64,3 @@ class FlashNeoXSharded(FlashNeoX):
rank=rank,
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
model.post_load_weights(quantize)

View File

@ -1,119 +1,25 @@
import torch
import torch.distributed
from pathlib import Path
from accelerate import init_empty_weights
from opentelemetry import trace
from safetensors import safe_open
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List
from transformers import AutoTokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
weight_hub_files,
LocalEntryNotFoundError,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashRW(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
raise NotImplementedError("RW is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = RWConfig.from_pretrained(
model_id,
revision=revision,
)
# We do not use from_pretrained as it is too slow
try:
filenames = weight_files(model_id, revision, ".bin")
# Local files not found
except LocalEntryNotFoundError:
hub_files = weight_hub_files(model_id, revision, ".bin")
filenames = download_weights(hub_files, model_id, revision)
with init_empty_weights():
model = FlashRWForCausalLM(config)
self.load_weights(
model,
filenames,
quantize,
device,
dtype,
)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
@staticmethod
def load_weights(
model: FlashRWForCausalLM,
filenames: List[Path],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
):
for filename in filenames:
state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items():
value = value.to(device if quantize is None else "cpu").to(dtype)
module_name, param_name = key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
except KeyError:
module._buffers[param_name] = value
del value
torch.cuda.empty_cache()
model.post_load_weights(quantize)
class FlashRWSharded(FlashRW):
class FlashRWSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
@ -142,20 +48,12 @@ class FlashRWSharded(FlashRW):
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
with init_empty_weights():
model = FlashRWForCausalLM(config, self.process_group)
config.quantize = quantize
model = FlashRWForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
@ -166,79 +64,3 @@ class FlashRWSharded(FlashRW):
rank=rank,
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "lm_head.weight" and model.transformer.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
model.post_load_weights(quantize)

View File

@ -1,197 +1,24 @@
import torch
import torch.distributed
from accelerate import init_empty_weights
from opentelemetry import trace
from safetensors import safe_open
from pathlib import Path
from transformers import AutoTokenizer, GPT2Config
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
weight_hub_files,
LocalEntryNotFoundError,
Weights,
)
tracer = trace.get_tracer(__name__)
class FlashSantacoder(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
raise NotImplementedError("FlashSantacoder is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = GPT2Config.from_pretrained(
model_id,
revision=revision,
)
# We do not use from_pretrained as we modified the model internal module layout
filenames = weight_files(model_id, revision, ".safetensors")
with init_empty_weights():
model = FlashSantacoderForCausalLM(config)
self.load_weights(
model,
filenames,
quantize,
device,
dtype,
config.architectures[0].startswith("GPT2"),
)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
@staticmethod
def load_weights(
model: FlashSantacoderForCausalLM,
filenames: List[Path],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
transpose: bool,
):
for filename in filenames:
with safe_open(
filename,
framework="pt",
device=str(device) if quantize is None else "cpu",
) as f:
for key in f.keys():
value = f.get_tensor(key)
value = value.to(device if quantize is None else "cpu").to(dtype)
layer_name = ".".join(key.split(".")[:4])
# Fused qkv
if "q_attn.weight" in key or "kv_attn.weight" in key:
final_key = layer_name + ".c_attn.weight"
elif "q_attn.bias" in key or "kv_attn.bias" in key:
final_key = layer_name + ".c_attn.bias"
else:
final_key = key
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if transpose and (
"c_fc.weight" in key
or "c_proj.weight" in key
or "q_attn.weight" in key
or "kv_attn.weight" in key
or "c_attn.weight" in key
):
# Tranpose as we use nn.Linear instead of Conv1D
value = value.T
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "c_attn.weight" in final_key:
module._parameters[param_name] = value.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2),
value.shape[1],
)
)
elif "c_attn.bias" in final_key:
module._parameters[param_name] = value.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2)
)
)
# Copy to correct slice
if "q_attn.weight" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "q_attn.bias" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "kv_attn.weight" in key:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = value
elif "kv_attn.bias" in key:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = value
else:
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
else:
module._buffers[param_name] = value
del value
if model.lm_head.weight.device == torch.device("meta"):
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
torch.cuda.empty_cache()
model.post_load_weights(quantize)
uninitialized_parameters = []
for n, p in model.named_parameters():
if p.data.device == torch.device("meta"):
uninitialized_parameters.append(n)
if uninitialized_parameters:
raise RuntimeError(
f"found uninitialized parameters in model : {uninitialized_parameters}"
)
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
class FlashSantacoderSharded(FlashSantacoder):
class FlashSantacoderSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
@ -214,28 +41,22 @@ class FlashSantacoderSharded(FlashSantacoder):
trust_remote_code=trust_remote_code,
)
config = GPT2Config.from_pretrained(
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=True,
)
config.quantize = quantize
config.transpose = config.architectures[0].startswith("GPT2")
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = FlashSantacoderForCausalLM(config, self.process_group)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
transpose=config.architectures[0].startswith("GPT2"),
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
model = FlashSantacoderForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
@ -247,164 +68,8 @@ class FlashSantacoderSharded(FlashSantacoder):
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
transpose: bool,
):
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for key in f.keys():
slice_ = f.get_slice(key)
layer_name = ".".join(key.split(".")[:4])
# Fused qkv
if "q_attn.weight" in key or "kv_attn.weight" in key:
final_key = layer_name + ".c_attn.weight"
elif "q_attn.bias" in key or "kv_attn.bias" in key:
final_key = layer_name + ".c_attn.bias"
else:
final_key = key
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
if isinstance(module, TensorParallelColumnLinear):
dim = 1 if transpose and "weight" in param_name else 0
size = slice_.get_shape()[dim]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = (
slice_[start:stop] if dim == 0 else slice_[:, start:stop]
)
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
dim = 0 if transpose else 1
size = slice_.get_shape()[dim]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = (
slice_[start:stop]
if dim == 0
else slice_[:, start:stop]
)
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif key == "lm_head.weight" and model.transformer.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(key)
tensor = tensor.contiguous().to(dtype)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if transpose and (
"c_fc.weight" in key
or "c_proj.weight" in key
or "q_attn.weight" in key
or "kv_attn.weight" in key
or "c_attn.weight" in key
):
# Tranpose as we use nn.Linear instead of Conv1D
tensor = tensor.T
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "c_attn.weight" in final_key:
module._parameters[param_name] = tensor.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2),
tensor.shape[1],
)
)
elif "c_attn.bias" in final_key:
module._parameters[param_name] = tensor.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2)
)
)
# Copy to correct slice
if "q_attn" in key:
size = tensor.shape[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = tensor[start:stop]
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "kv_attn.weight" in key:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = tensor
elif "kv_attn.bias" in key:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = tensor
elif "c_attn" in key:
# Slice q_tensor by shard
q_tensor = tensor[: -2 * model.transformer.head_size]
block_size = q_tensor.shape[0] // world_size
start = rank * block_size
stop = (rank + 1) * block_size
q_tensor = q_tensor[start:stop]
module._parameters[param_name][
: q_tensor.shape[0]
] = q_tensor
# Kv tensor is copied for every shard
kv_tensor = tensor[-2 * model.transformer.head_size :]
module._parameters[param_name][
q_tensor.shape[0] :
] = kv_tensor
else:
if current_parameter_tensor.shape != tensor.shape:
raise ValueError(
f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
if model.lm_head.weight.device == torch.device("meta"):
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
torch.cuda.empty_cache()
model.post_load_weights(quantize)
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)

View File

@ -2,41 +2,25 @@ import re
import torch
import torch.distributed
from typing import List, Optional, Type, Tuple
from typing import List, Optional, Type
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoConfig,
PreTrainedTokenizerBase,
)
from transformers.models.opt.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2
from text_generation_server.models.opt import OPT
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
initialize_torch_distributed,
weight_files,
Weights,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
# we split individual characters inside special tokens like [START_DNA]
@ -168,33 +152,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
)
class Galactica(OPT):
@property
def batch_type(self) -> Type[CausalLMBatch]:
return GalacticaCausalLMBatch
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""Overwrite forward to ignore position_ids"""
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, outputs.past_key_values
class GalacticaSharded(Galactica):
class GalacticaSharded(CausalLM):
def __init__(
self,
model_id: str,
@ -224,26 +182,17 @@ class GalacticaSharded(Galactica):
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
model = OPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
@ -255,127 +204,15 @@ class GalacticaSharded(Galactica):
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
if name == "lm_head.weight":
continue
@property
def batch_type(self) -> Type[CausalLMBatch]:
return GalacticaCausalLMBatch
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_tensor = parameters[name]
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
tensor = slice_[:]
if current_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "model.decoder.embed_tokens.weight":
model.lm_head._parameters["weight"] = tensor
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
@ -386,10 +223,4 @@ class GalacticaSharded(Galactica):
past_key_values=past_key_values,
use_cache=True,
)
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
return logits, outputs.past_key_values
return outputs.logits, outputs.past_key_values

View File

@ -1,34 +1,22 @@
import torch
import torch.distributed
from typing import List, Optional
from typing import Optional
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoConfig,
)
from transformers.models.gpt_neox.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.neox_modeling import (
GPTNeoxForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
class GPTNeoxSharded(CausalLM):
def __init__(
@ -58,28 +46,18 @@ class GPTNeoxSharded(CausalLM):
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
model = GPTNeoxForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
@ -91,161 +69,16 @@ class GPTNeoxSharded(CausalLM):
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
if self.model.gpt_neox.tp_embeddings:
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
)
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
)
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(
logits, outputs.logits, group=self.process_group
)
logits = torch.cat(logits, dim=2)
return logits, outputs.past_key_values
# While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
else:
return super(GPTNeoxSharded, self).forward(
input_ids, attention_mask, position_ids, past_key_values
)
logits = outputs.logits
return logits, outputs.past_key_values

View File

@ -1,52 +1,22 @@
import torch
import torch.distributed
from typing import List, Optional, Tuple
from typing import Optional
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoConfig,
)
from transformers.models.opt.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.models import CausalLM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
class OPT(CausalLM):
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""Overwrite forward to ignore position_ids"""
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, outputs.past_key_values
class OPTSharded(OPT):
class OPTSharded(CausalLM):
def __init__(
self,
model_id: str,
@ -73,29 +43,19 @@ class OPTSharded(OPT):
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
model = OPTForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
@ -107,128 +67,6 @@ class OPTSharded(OPT):
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
if name == "lm_head.weight":
continue
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_tensor = parameters[name]
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
tensor = slice_[:]
if current_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "model.decoder.embed_tokens.weight":
model.lm_head._parameters["weight"] = tensor
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
@ -239,9 +77,4 @@ class OPTSharded(OPT):
use_cache=True,
)
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
return logits, outputs.past_key_values
return outputs.logits, outputs.past_key_values

View File

@ -3,31 +3,20 @@ import torch.distributed
from typing import List, Optional, Tuple
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
AutoConfig,
)
from text_generation_server.models import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
T5ForConditionalGeneration,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
from transformers.models.t5.parallel_layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except ImportError as e:
HAS_BITS_AND_BYTES = False
class T5Sharded(Seq2SeqLM):
@ -46,6 +35,13 @@ class T5Sharded(Seq2SeqLM):
device = torch.device("cpu")
dtype = torch.float32
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
@ -53,33 +49,16 @@ class T5Sharded(Seq2SeqLM):
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
tokenizer.bos_token_id = config.decoder_start_token_id
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = AutoModelForSeq2SeqLM.from_config(
config, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group
)
model = T5ForConditionalGeneration(config, weights)
torch.distributed.barrier(group=self.process_group)
super(Seq2SeqLM, self).__init__(
model=model,
@ -91,151 +70,6 @@ class T5Sharded(Seq2SeqLM):
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "lm_head.weight":
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif "relative_attention_bias.weight" in name:
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous()
# See: https://github.com/huggingface/transformers/blob/1fe1e3caa44617047f149bcc0c0b566343b714a7/src/transformers/models/t5/modeling_t5.py#LL316C15-L316C71
if module_name.endswith("wo"):
tensor = tensor.to(torch.float32)
else:
tensor = tensor.to(dtype)
if quantize == "bitsandbytes" and not module_name.endswith("wo"):
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq" and not module_name.endswith("wo"):
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None or module_name.endswith("wo"):
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
def forward(
self,
input_ids,
@ -260,13 +94,8 @@ class T5Sharded(Seq2SeqLM):
use_cache=True,
)
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
return (
logits,
outputs.logits,
outputs.encoder_last_hidden_state,
outputs.past_key_values,
)

View File

@ -1,5 +1,6 @@
from text_generation_server.utils.convert import convert_file, convert_files
from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation_server.utils.weights import Weights
from text_generation_server.utils.hub import (
weight_files,
weight_hub_files,
@ -35,4 +36,5 @@ __all__ = [
"StoppingCriteria",
"StopSequenceCriteria",
"FinishReason",
"Weights",
]

View File

@ -4,6 +4,37 @@ import torch
from datetime import timedelta
class FakeBarrier:
def wait(self):
pass
class FakeGroup:
def __init__(self, rank, size):
self._rank = rank
self._size = size
def allreduce(self, *args, **kwargs):
return FakeBarrier()
def allgather(self, inputs, local_tensor, **kwargs):
assert (
len(inputs[0]) == len(local_tensor) == 1
), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
for input_ in inputs:
input_[0].data = local_tensor[0].data
return FakeBarrier()
def barrier(self, *args, **kwargs):
return FakeBarrier()
def size(self):
return self._size
def rank(self):
return self._rank
def initialize_torch_distributed():
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
@ -23,13 +54,18 @@ def initialize_torch_distributed():
backend = "gloo"
options = None
# Call the init process.
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),
pg_options=options,
)
if world_size == 1:
return FakeGroup(rank, world_size), rank, world_size
else:
if os.getenv("DEBUG", None) == "1":
return FakeGroup(rank, world_size), rank, world_size
# Call the init process.
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
timeout=timedelta(seconds=60),
pg_options=options,
)
return torch.distributed.group.WORLD, rank, world_size
return torch.distributed.group.WORLD, rank, world_size

View File

@ -1,176 +1,240 @@
import torch
import torch.distributed
from torch import nn
from torch.nn import functional as F
from typing import Optional
from typing import List
HAS_BITS_AND_BYTES = True
try:
from bitsandbytes.nn import Linear8bitLt
except ImportError as e:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except ImportError:
HAS_BITS_AND_BYTES = False
from accelerate import init_empty_weights
class FastLinear(nn.Linear):
# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
weight = weights.get_tensor(f"{prefix}.weight")
bias = weights.get_tensor(f"{prefix}.bias")
with init_empty_weights():
ln = cls(weight.shape, eps=eps)
ln.weight = nn.Parameter(weight)
ln.bias = nn.Parameter(bias)
return ln
torch.nn.LayerNorm.load = load_layer_norm
class FastLinear(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
weight,
bias,
) -> None:
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
self.quantized = False
self.bnb_linear = None
def prepare_weights(self, quantize: Optional[str] = None):
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
self.quantized = True
self.bnb_linear = Linear8bitLt(
self.in_features,
self.out_features,
has_fp16_weights=False,
threshold=6.0,
bias=False,
)
# Copy data to bnb_linear
self.bnb_linear.weight.data = self.weight.data
if self.bias is not None:
self.bnb_linear.bias = nn.Parameter(self.bias)
# Delete reference to data
self.weight = None
super().__init__()
self.weight = nn.Parameter(weight)
if bias is not None:
self.bias = nn.Parameter(bias)
else:
self.bias = None
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
self.weight = nn.Parameter(self.weight.T)
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_tensor(f"{prefix}.weight")
if bias:
bias = weights.get_tensor(f"{prefix}.bias")
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
bias = None
return cls(weight, bias)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.quantized:
return self.bnb_linear(input)
else:
if self.bias is not None:
return torch.addmm(self.bias, input, self.weight)
return torch.matmul(input, self.weight)
return F.linear(input, self.weight, self.bias)
class TensorParallelColumnLinear(FastLinear):
class Linear8bitLt(nn.Module):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
bias=True,
device=None,
dtype=None,
weight,
bias,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
assert out_features % self.tp_world_size == 0
out_features = out_features // self.tp_world_size
super().__init__()
assert (
not memory_efficient_backward
), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
self.state = bnb.MatmulLtState()
self.index = index
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
# Necessary for stacked layers
self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
self.state.memory_efficient_backward = memory_efficient_backward
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
self.weight = Int8Params(
weight.data,
has_fp16_weights=has_fp16_weights,
requires_grad=has_fp16_weights,
)
self.weight.cuda(weight.device)
self.bias = bias
def init_8bit_state(self):
self.state.CB = self.weight.CB
self.state.SCB = self.weight.SCB
self.weight.CB = None
self.weight.SCB = None
class TensorParallelRowLinear(FastLinear):
def __init__(
self,
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
reduce=True,
bias=True,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
self.reduce = reduce
assert in_features % self.tp_world_size == 0
in_features = in_features // self.tp_world_size
def forward(self, x: torch.Tensor):
self.state.is_training = self.training
if self.weight.CB is not None:
self.init_8bit_state()
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = super(TensorParallelRowLinear, self).forward(input)
if self.reduce:
torch.distributed.all_reduce(out, group=self.process_group)
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
if not self.state.has_fp16_weights:
if self.state.CB is not None and self.state.CxB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight
del self.state.CB
self.weight.data = self.state.CxB
return out
class TensorParallelEmbedding(nn.Embedding):
def __init__(
self,
num_embeddings,
embedding_dim,
process_group: torch.distributed.ProcessGroup,
reduce=True,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
device=None,
dtype=None,
):
self.reduce = reduce
def get_linear(weight, bias, quantize):
if quantize is None:
linear = FastLinear(weight, bias)
elif quantize == "bitsandbytes":
linear = Linear8bitLt(
weight,
bias,
has_fp16_weights=False,
threshold=6.0,
)
if bias is not None:
linear.bias = nn.Parameter(bias)
elif quantize == "gptq":
raise NotImplementedError("Soon")
else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear
class SuperLayer(nn.Module):
def __init__(self, linear):
super().__init__()
self.linear = linear
def forward(self, x):
return self.linear.forward(x)
class TensorParallelHead(SuperLayer):
def __init__(self, linear, process_group):
super().__init__(linear)
self.process_group = process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.original_num_embeddings = num_embeddings
assert num_embeddings % self.tp_world_size == 0
block_size = num_embeddings // self.tp_world_size
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings
self.min_id = self.tp_rank * block_size
self.max_id = (self.tp_rank + 1) * block_size
# Additional entry that will map to zero
# Used for masking
self.null_idx = block_size
super().__init__(
block_size,
embedding_dim,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
_weight=_weight,
device=device,
dtype=dtype,
@staticmethod
def load(config, prefix: str, weights):
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
return TensorParallelHead(
get_linear(weight, bias=None, quantize=config.quantize),
process_group=weights.process_group,
)
def add_null_idx(self):
def forward(self, input: torch.Tensor) -> torch.Tensor:
output = super().forward(input)
# Logits are sharded, so we need to gather them
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
return world_output
class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
return cls(get_linear(weight, bias, config.quantize))
@classmethod
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
bias = torch.cat(b, dim=0)
else:
bias = None
return cls(get_linear(weight, bias, config.quantize))
class TensorParallelRowLinear(SuperLayer):
def __init__(self, linear, process_group):
super().__init__(linear)
self.process_group = process_group
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return cls(
get_linear(weight, bias, config.quantize),
process_group=weights.process_group,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = super().forward(input)
torch.distributed.all_reduce(out, group=self.process_group)
return out
class TensorParallelEmbedding(nn.Module):
def __init__(self, prefix: str, weights, reduce=True):
super().__init__()
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
process_group = weights.process_group
world_size = process_group.size()
rank = process_group.rank()
block_size = num_embeddings // world_size
self.min_id = rank * block_size
self.max_id = min(num_embeddings, (rank + 1) * block_size)
self.null_idx = block_size
self.process_group = weights.process_group
self.reduce = reduce
"""Additional 0 entry used for masking"""
self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))
def forward(self, input: torch.Tensor) -> torch.Tensor:
# default all out of bounds values to `self.null_idx` that will then be mapped to 0
@ -180,7 +244,7 @@ class TensorParallelEmbedding(nn.Embedding):
self.null_idx,
input - self.min_id,
)
out = super().forward(input)
out = torch.nn.functional.embedding(input, self.weight)
if self.reduce:
torch.distributed.all_reduce(out, group=self.process_group)
return out
@ -232,7 +296,34 @@ try:
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
class PositionRotaryEmbedding(RotaryEmbedding):
class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq):
super().__init__()
self.register_buffer("inv_freq", inv_freq)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
@classmethod
def static(cls, dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return cls(inv_freq)
@classmethod
def load(cls, prefix, weights):
# XXX: Always load this in float32 !
dtype = weights.dtype
weights.dtype = torch.float32
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
weights.dtype = dtype
return cls(inv_freq)
def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)

View File

@ -0,0 +1,77 @@
from pathlib import Path
from typing import List
from safetensors import safe_open
class Weights:
def __init__(self, filenames: List[Path], device, dtype, process_group):
routing = {}
for filename in filenames:
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
self.routing = routing
self.device = device
self.dtype = dtype
self.process_group = process_group
self._handles = {}
def _get_handle(self, filename):
if filename not in self._handles:
f = safe_open(filename, framework="pytorch")
self._handles[filename] = f
return self._handles[filename]
def get_filename(self, tensor_name: str) -> str:
filename = self.routing.get(tensor_name, None)
if filename is None:
raise RuntimeError(f"weight {tensor_name} does not exist")
return str(filename)
def _get_slice(self, tensor_name: str):
filename = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
return slice_
def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape()
def get_tensor(self, tensor_name: str):
filename = self.get_filename(tensor_name)
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
def get_sharded(self, tensor_name: str, dim: int):
filename = self.get_filename(tensor_name)
world_size = self.process_group.size()
rank = self.process_group.rank()
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
size = slice_.get_shape()[dim]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
assert (
size % world_size == 0
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
if dim == 0:
tensor = slice_[start:stop]
elif dim == 1:
tensor = slice_[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor