feat(server): Rework model loading (#344)

# What does this PR do?

Reworked the loading logic. Idea is to use cleaner loading code:

- Remove need for `no_init_weights`
- Remove all weird `bnb_linear` and `load_weights` and
`post_load_weights`.

New code layout:

- New class `Weights` in charge of handling loading the weights from
multiple files into appropiate tensors (potentially sharded)
- TP layers now are "shells", they contain the code to know what kind of
sharding we need + eventual `all_reduce`. They do not inherit from
linear, but they contain some kind of Linear instead
- the contained linear can be either FastLinear, BnbLinear or GPTq
Linear next.
- All modeling code is explictly made for sharding, process group is
just no-ops for non sharded code (removes a lot of test cases)

![Screenshot from 2023-05-19
23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f)

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.taildb5d.ts.net>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
This commit is contained in:
Nicolas Patry 2023-06-08 14:51:52 +02:00 committed by GitHub
parent 19c41824cb
commit abd58ff82c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
43 changed files with 6806 additions and 2793 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.idea .idea
target target
router/tokenizer.json router/tokenizer.json
*__pycache__*

View File

@ -2,6 +2,8 @@
FROM lukemathwalker/cargo-chef:latest-rust-1.69 AS chef FROM lukemathwalker/cargo-chef:latest-rust-1.69 AS chef
WORKDIR /usr/src WORKDIR /usr/src
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
FROM chef as planner FROM chef as planner
COPY Cargo.toml Cargo.toml COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml COPY rust-toolchain.toml rust-toolchain.toml
@ -98,14 +100,14 @@ COPY server/Makefile-flash-att Makefile
RUN make build-flash-attention RUN make build-flash-attention
# Build Transformers CUDA kernels # Build Transformers CUDA kernels
FROM kernel-builder as transformers-builder FROM kernel-builder as custom-kernels-builder
WORKDIR /usr/src WORKDIR /usr/src
COPY server/Makefile-transformers Makefile COPY server/custom_kernels/ .
# Build specific version of transformers # Build specific version of transformers
RUN BUILD_EXTENSIONS="True" make build-transformers RUN python setup.py build
# Text Generation Inference base image # Text Generation Inference base image
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base
@ -136,11 +138,10 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
# Copy build artifacts from transformers builder # Copy build artifacts from transformers builder
COPY --from=transformers-builder /usr/src/transformers /usr/src/transformers COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
COPY --from=transformers-builder /usr/src/transformers/build/lib.linux-x86_64-cpython-39/transformers /usr/src/transformers/src/transformers
# Install transformers dependencies # Install flash-attention dependencies
RUN cd /usr/src/transformers && pip install -e . --no-cache-dir && pip install einops --no-cache-dir RUN pip install einops --no-cache-dir
# Install server # Install server
COPY proto proto COPY proto proto

View File

@ -1,6 +1,9 @@
install-server: install-server:
cd server && make install cd server && make install
install-custom-kernels:
if [ "$$BUILD_EXTENSIONS" == "True" ]; then cd server/custom_kernels && python setup.py install; else echo "Custom kernels are disabled, you need set to BUILD_EXTENSION environment variable to 'True' in order to build them. (Please read the docs, kernels might not work on all hardware)"; fi
install-integration-tests: install-integration-tests:
cd integration-tests && pip install -r requirements.txt cd integration-tests && pip install -r requirements.txt
cd clients/python && pip install . cd clients/python && pip install .
@ -14,7 +17,7 @@ install-launcher:
install-benchmark: install-benchmark:
cd benchmark && cargo install --path . cd benchmark && cargo install --path .
install: install-server install-router install-launcher install: install-server install-router install-launcher install-custom-kernels
server-dev: server-dev:
cd server && make run-dev cd server && make run-dev

View File

@ -209,6 +209,7 @@ def launcher(event_loop):
num_shard: Optional[int] = None, num_shard: Optional[int] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
use_flash_attention: bool = True,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000) master_port = random.randint(10_000, 20_000)
@ -240,6 +241,9 @@ def launcher(event_loop):
env = os.environ env = os.environ
env["LOG_LEVEL"] = "info,text_generation_router=debug" env["LOG_LEVEL"] = "info,text_generation_router=debug"
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
with subprocess.Popen( with subprocess.Popen(
args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
) as process: ) as process:
@ -254,12 +258,16 @@ def launcher(event_loop):
process.stdout.close() process.stdout.close()
process.stderr.close() process.stderr.close()
if not use_flash_attention:
del env["USE_FLASH_ATTENTION"]
@contextlib.contextmanager @contextlib.contextmanager
def docker_launcher( def docker_launcher(
model_id: str, model_id: str,
num_shard: Optional[int] = None, num_shard: Optional[int] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
use_flash_attention: bool = True,
): ):
port = random.randint(8000, 10_000) port = random.randint(8000, 10_000)
@ -287,6 +295,9 @@ def launcher(event_loop):
gpu_count = num_shard if num_shard is not None else 1 gpu_count = num_shard if num_shard is not None else 1
env = {"LOG_LEVEL": "info,text_generation_router=debug"} env = {"LOG_LEVEL": "info,text_generation_router=debug"}
if not use_flash_attention:
env["USE_FLASH_ATTENTION"] = "false"
if HUGGING_FACE_HUB_TOKEN is not None: if HUGGING_FACE_HUB_TOKEN is not None:
env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN
@ -310,6 +321,9 @@ def launcher(event_loop):
yield ContainerLauncherHandle(client, container.name, port) yield ContainerLauncherHandle(client, container.name, port)
if not use_flash_attention:
del env["USE_FLASH_ATTENTION"]
try: try:
container.stop() container.stop()
container.wait() container.wait()

View File

@ -11,17 +11,17 @@
}, },
{ {
"id": 1459, "id": 1459,
"logprob": -5.6289062, "logprob": -5.6328125,
"text": " print" "text": " print"
}, },
{ {
"id": 81, "id": 81,
"logprob": -1.6005859, "logprob": -1.6035156,
"text": "_" "text": "_"
}, },
{ {
"id": 7656, "id": 7656,
"logprob": -5.9921875, "logprob": -5.9882812,
"text": "hello" "text": "hello"
} }
], ],
@ -59,19 +59,19 @@
}, },
{ {
"id": 10896, "id": 10896,
"logprob": -0.3659668, "logprob": -0.38549805,
"special": false, "special": false,
"text": " World" "text": " World"
}, },
{ {
"id": 657, "id": 657,
"logprob": -0.49804688, "logprob": -0.5229492,
"special": false, "special": false,
"text": "\")" "text": "\")"
}, },
{ {
"id": 203, "id": 203,
"logprob": -0.11279297, "logprob": -0.10632324,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
@ -113,7 +113,7 @@
}, },
{ {
"id": 426, "id": 426,
"logprob": -0.051635742, "logprob": 0.0,
"special": false, "special": false,
"text": "name" "text": "name"
}, },

View File

@ -0,0 +1,113 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1992188,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8984375,
"text": " mood"
},
{
"id": 3063,
"logprob": -4.0976562,
"text": " today"
},
{
"id": 32,
"logprob": -0.14562988,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26733398,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.86279297,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.94921875,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1835938,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.074035645,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.86376953,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.2070312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4365234,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.109375,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -0.93408203,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.8808594,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
}

View File

@ -0,0 +1,454 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|USER|>"
},
{
"id": 1276,
"logprob": -4.5546875,
"text": "What"
},
{
"id": 434,
"logprob": -4.1953125,
"text": "'s"
},
{
"id": 634,
"logprob": -5.125,
"text": " your"
},
{
"id": 12315,
"logprob": -9.8828125,
"text": " mood"
},
{
"id": 3063,
"logprob": -3.9980469,
"text": " today"
},
{
"id": 32,
"logprob": -0.14672852,
"text": "?"
},
{
"id": 50279,
"logprob": -0.26489258,
"text": "<|ASSISTANT|>"
}
],
"seed": null,
"tokens": [
{
"id": 42,
"logprob": -0.8618164,
"special": false,
"text": "I"
},
{
"id": 1353,
"logprob": -0.9506836,
"special": false,
"text": "'m"
},
{
"id": 7016,
"logprob": -2.1738281,
"special": false,
"text": " sorry"
},
{
"id": 13,
"logprob": -0.0758667,
"special": false,
"text": ","
},
{
"id": 1394,
"logprob": -0.9135742,
"special": false,
"text": "You"
},
{
"id": 452,
"logprob": -1.1445312,
"special": false,
"text": " have"
},
{
"id": 247,
"logprob": -1.4375,
"special": false,
"text": " a"
},
{
"id": 4327,
"logprob": -1.1103516,
"special": false,
"text": " choice"
},
{
"id": 273,
"logprob": -1.0058594,
"special": false,
"text": " of"
},
{
"id": 752,
"logprob": -1.921875,
"special": false,
"text": " what"
}
]
},
"generated_text": "I'm sorry,You have a choice of what"
}
]

View File

@ -0,0 +1,163 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.4179688,
"text": " is"
},
{
"id": 247,
"logprob": -2.1542969,
"text": " a"
},
{
"id": 1167,
"logprob": -5.359375,
"text": " mem"
},
{
"id": 70,
"logprob": -0.006038666,
"text": "e"
},
{
"id": 13,
"logprob": -7.328125,
"text": ","
},
{
"id": 285,
"logprob": -0.3173828,
"text": " and"
},
{
"id": 752,
"logprob": -2.0625,
"text": " what"
},
{
"id": 434,
"logprob": -5.7734375,
"text": "'s"
},
{
"id": 253,
"logprob": -0.74072266,
"text": " the"
},
{
"id": 2892,
"logprob": -6.5898438,
"text": " history"
},
{
"id": 3212,
"logprob": -2.2949219,
"text": " behind"
},
{
"id": 436,
"logprob": -11.40625,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1113281,
"text": " word"
},
{
"id": 32,
"logprob": -0.008056641,
"text": "?"
},
{
"id": 0,
"logprob": -2.3300781,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.28125,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.5878906,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.5449219,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.05038452,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.002292633,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -1.3828278e-05,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.0010242462,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.090270996,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12719727,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.016571045,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.43432617,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
}

View File

@ -0,0 +1,654 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.4179688,
"text": " is"
},
{
"id": 247,
"logprob": -2.1542969,
"text": " a"
},
{
"id": 1167,
"logprob": -5.359375,
"text": " mem"
},
{
"id": 70,
"logprob": -0.006038666,
"text": "e"
},
{
"id": 13,
"logprob": -7.328125,
"text": ","
},
{
"id": 285,
"logprob": -0.3173828,
"text": " and"
},
{
"id": 752,
"logprob": -2.0625,
"text": " what"
},
{
"id": 434,
"logprob": -5.7734375,
"text": "'s"
},
{
"id": 253,
"logprob": -0.74072266,
"text": " the"
},
{
"id": 2892,
"logprob": -6.5898438,
"text": " history"
},
{
"id": 3212,
"logprob": -2.2949219,
"text": " behind"
},
{
"id": 436,
"logprob": -11.40625,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1113281,
"text": " word"
},
{
"id": 32,
"logprob": -0.008056641,
"text": "?"
},
{
"id": 0,
"logprob": -2.3300781,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.28125,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.5878906,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.5498047,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.04815674,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.002313614,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -1.2636185e-05,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.0010147095,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.0859375,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12609863,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.016601562,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.38256836,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.421875,
"text": " is"
},
{
"id": 247,
"logprob": -2.1640625,
"text": " a"
},
{
"id": 1167,
"logprob": -5.40625,
"text": " mem"
},
{
"id": 70,
"logprob": -0.005420685,
"text": "e"
},
{
"id": 13,
"logprob": -7.2226562,
"text": ","
},
{
"id": 285,
"logprob": -0.26879883,
"text": " and"
},
{
"id": 752,
"logprob": -2.1992188,
"text": " what"
},
{
"id": 434,
"logprob": -5.46875,
"text": "'s"
},
{
"id": 253,
"logprob": -0.8017578,
"text": " the"
},
{
"id": 2892,
"logprob": -6.6796875,
"text": " history"
},
{
"id": 3212,
"logprob": -2.1972656,
"text": " behind"
},
{
"id": 436,
"logprob": -11.4453125,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1933594,
"text": " word"
},
{
"id": 32,
"logprob": -0.007858276,
"text": "?"
},
{
"id": 0,
"logprob": -2.328125,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.21875,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.6201172,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.546875,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.051879883,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.0020179749,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -9.059906e-06,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.00096797943,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.07940674,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12182617,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.017227173,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.44482422,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.421875,
"text": " is"
},
{
"id": 247,
"logprob": -2.1640625,
"text": " a"
},
{
"id": 1167,
"logprob": -5.40625,
"text": " mem"
},
{
"id": 70,
"logprob": -0.005420685,
"text": "e"
},
{
"id": 13,
"logprob": -7.2226562,
"text": ","
},
{
"id": 285,
"logprob": -0.26879883,
"text": " and"
},
{
"id": 752,
"logprob": -2.1992188,
"text": " what"
},
{
"id": 434,
"logprob": -5.46875,
"text": "'s"
},
{
"id": 253,
"logprob": -0.8017578,
"text": " the"
},
{
"id": 2892,
"logprob": -6.6796875,
"text": " history"
},
{
"id": 3212,
"logprob": -2.1972656,
"text": " behind"
},
{
"id": 436,
"logprob": -11.4453125,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1933594,
"text": " word"
},
{
"id": 32,
"logprob": -0.007858276,
"text": "?"
},
{
"id": 0,
"logprob": -2.328125,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.21875,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.6201172,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.546875,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.051879883,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.0020179749,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -9.059906e-06,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.00096797943,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.07940674,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12182617,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.017227173,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.44482422,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50278,
"logprob": null,
"text": "<|prompter|>"
},
{
"id": 1276,
"logprob": -8.0234375,
"text": "What"
},
{
"id": 310,
"logprob": -5.421875,
"text": " is"
},
{
"id": 247,
"logprob": -2.1640625,
"text": " a"
},
{
"id": 1167,
"logprob": -5.40625,
"text": " mem"
},
{
"id": 70,
"logprob": -0.005420685,
"text": "e"
},
{
"id": 13,
"logprob": -7.2226562,
"text": ","
},
{
"id": 285,
"logprob": -0.26879883,
"text": " and"
},
{
"id": 752,
"logprob": -2.1992188,
"text": " what"
},
{
"id": 434,
"logprob": -5.46875,
"text": "'s"
},
{
"id": 253,
"logprob": -0.8017578,
"text": " the"
},
{
"id": 2892,
"logprob": -6.6796875,
"text": " history"
},
{
"id": 3212,
"logprob": -2.1972656,
"text": " behind"
},
{
"id": 436,
"logprob": -11.4453125,
"text": " this"
},
{
"id": 3159,
"logprob": -2.1933594,
"text": " word"
},
{
"id": 32,
"logprob": -0.007858276,
"text": "?"
},
{
"id": 0,
"logprob": -2.328125,
"text": "<|endoftext|>"
},
{
"id": 50281,
"logprob": -18.21875,
"text": "<|assistant|>"
}
],
"seed": null,
"tokens": [
{
"id": 510,
"logprob": -0.6201172,
"special": false,
"text": "The"
},
{
"id": 3159,
"logprob": -0.546875,
"special": false,
"text": " word"
},
{
"id": 346,
"logprob": -0.051879883,
"special": false,
"text": " \""
},
{
"id": 6441,
"logprob": -0.0020179749,
"special": false,
"text": "mem"
},
{
"id": 70,
"logprob": -1.04904175e-05,
"special": false,
"text": "e"
},
{
"id": 3,
"logprob": -0.0009560585,
"special": false,
"text": "\""
},
{
"id": 369,
"logprob": -0.08557129,
"special": false,
"text": " was"
},
{
"id": 806,
"logprob": -0.12084961,
"special": false,
"text": " first"
},
{
"id": 908,
"logprob": -0.01737976,
"special": false,
"text": " used"
},
{
"id": 275,
"logprob": -0.4025879,
"special": false,
"text": " in"
}
]
},
"generated_text": "The word \"meme\" was first used in"
}
]

View File

@ -37,8 +37,8 @@ async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
generated_texts = [r.generated_text for r in responses] generated_texts = [r.generated_text for r in responses]
assert len(generated_texts) == 4 assert len(generated_texts) == 4
assert generated_texts, all( assert all(
[text == generated_texts[0] for text in generated_texts] [text == generated_texts[0] for text in generated_texts]
) ), generated_texts
assert responses == response_snapshot assert responses == response_snapshot

View File

@ -0,0 +1,48 @@
import pytest
@pytest.fixture(scope="module")
def neox_handle(launcher):
with launcher(
"stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False
) as handle:
yield handle
@pytest.fixture(scope="module")
async def neox(neox_handle):
await neox_handle.health(300)
return neox_handle.client
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox(neox, response_snapshot):
response = await neox.generate(
"<|USER|>What's your mood today?<|ASSISTANT|>",
max_new_tokens=10,
decoder_input_details=True,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox_load(neox, generate_load, response_snapshot):
responses = await generate_load(
neox,
"<|USER|>What's your mood today?<|ASSISTANT|>",
max_new_tokens=10,
n=4,
)
generated_texts = [r.generated_text for r in responses]
assert len(generated_texts) == 4
assert generated_texts, all(
[text == generated_texts[0] for text in generated_texts]
)
assert responses == response_snapshot

View File

@ -0,0 +1,44 @@
import pytest
@pytest.fixture(scope="module")
def neox_sharded_handle(launcher):
with launcher(
"OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False
) as handle:
yield handle
@pytest.fixture(scope="module")
async def neox_sharded(neox_sharded_handle):
await neox_sharded_handle.health(300)
return neox_sharded_handle.client
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox(neox_sharded, response_snapshot):
response = await neox_sharded.generate(
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
max_new_tokens=10,
decoder_input_details=True,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.skip
@pytest.mark.asyncio
async def test_neox_load(neox_sharded, generate_load, response_snapshot):
responses = await generate_load(
neox_sharded,
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
max_new_tokens=10,
n=4,
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@ -1,4 +1,5 @@
[pytest] [pytest]
addopts = --snapshot-warn-unused
asyncio_mode = auto asyncio_mode = auto
markers = markers =
private: marks tests as requiring an admin hf token (deselect with '-m "not private"') private: marks tests as requiring an admin hf token (deselect with '-m "not private"')

View File

@ -1,4 +1,3 @@
include Makefile-transformers
include Makefile-flash-att include Makefile-flash-att
unit-tests: unit-tests:
@ -17,7 +16,7 @@ install-torch:
# Install specific version of torch # Install specific version of torch
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
install: gen-server install-torch install-transformers install: gen-server install-torch
pip install pip --upgrade pip install pip --upgrade
pip install -r requirements.txt pip install -r requirements.txt
pip install -e ".[bnb, accelerate]" pip install -e ".[bnb, accelerate]"

View File

@ -1,13 +0,0 @@
transformers_commit := 69009822aa7897ffab97afb814e38126b83f639e
transformers:
# Clone fork of transformers with custom CUDA kernels and sharding logic
pip install --upgrade setuptools
git clone https://github.com/OlivierDehaene/transformers.git
build-transformers: transformers
cd transformers && git fetch && git checkout $(transformers_commit) && python setup.py build
install-transformers: build-transformers
pip uninstall transformers -y || true
cd transformers && python setup.py install

View File

@ -0,0 +1,250 @@
#include <ATen/Dispatch.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <torch/torch.h>
#include <vector>
#include <optional>
/**
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
**/
// Available in pytorch main
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
/*
* Forward passes
*/
/**
* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype
**/
template<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>
__global__ void forward_masked_softmax_kernel(
const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]
const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]
torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]
const int64_t effective_kv_length,
const dim3 blockDim,
const int64_t rows_per_block,
const int64_t kv_length,
const int64_t batch_size
) {
const auto row_id = threadIdx.x / effective_kv_length;
const auto effective_kv_length_id = threadIdx.x % effective_kv_length;
const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;
auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;
kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_;
const auto kv_length_end = kv_length_end_;
const auto batch_id = blockIdx.x * rows_per_block + row_id;
// We need 2 float storage for each row, one for max computation, the other for normalizing exponential
extern __shared__ float temp_storage[];
const auto row_id_mem_offset = row_id * 2;
if (effective_kv_length_id == 0) {
temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();
temp_storage[row_id_mem_offset + 1] = 0;
}
__syncthreads();
// Compute mask and max
if (batch_id < batch_size) {
float thread_max = -std::numeric_limits<float>::infinity();
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
const float candidate = attention_scores[batch_id][kv_length_id];
thread_max = (thread_max < candidate) ? candidate : thread_max;
}
}
if (thread_max != -std::numeric_limits<float>::infinity()) {
// TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);
}
}
__syncthreads();
// Compute exp(elt - max) masked
float exponential[min_kv_length_shard_size_per_thread];
if (batch_id < batch_size) {
float thread_add = 0;
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);
thread_add = thread_add + exponential[kv_length_id - kv_length_start];
} else {
exponential[kv_length_id - kv_length_start] = 0.;
}
}
if (thread_add > 0) {
// TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);
}
}
__syncthreads();
// Compute softmax
if (batch_id < batch_size) {
// If sum of all exponential is 0, we set the softmax values to 0
if (temp_storage[row_id_mem_offset + 1] == 0.) {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = 0.;
}
} else {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);
}
}
}
}
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(
const at::Tensor query,
const at::Tensor key,
const at::Tensor value,
const std::optional<std::vector<at::Tensor>> layer_past,
const at::Tensor attention_mask,
const std::optional<at::Tensor> head_mask,
const float inv_norm_factor,
const int num_heads,
const bool use_cache
) {
auto query_layer = query;
auto key_layer = key;
auto value_layer = value;
if (layer_past) {
const auto past_key = (*layer_past).at(0);
const auto past_value = (*layer_past).at(1);
key_layer = at::cat({past_key, key_layer}, 2);
value_layer = at::cat({past_value, value_layer}, 2);
}
std::optional<std::vector<at::Tensor>> present;
if (use_cache) {
present = {key_layer, value_layer};
} else {
present = {};
}
const auto batch_size = query_layer.size(0);
const auto q_length = query_layer.size(2);
const auto attn_head_size = query_layer.size(3);
const auto batch_size_times_num_heads = batch_size * num_heads;
const auto kv_length = key_layer.size(2);
const auto query_view = query_layer.reshape({batch_size_times_num_heads, q_length, attn_head_size});
auto key_view = key_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}).transpose(1, 2);
auto value_view = value_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size});
auto query_scaled = query_view * inv_norm_factor;
auto attention_scores = at::bmm(query_scaled, key_view);
// Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype`
at::Tensor attention_probs;
if (true) {
// TODO @thomasw21: it's easier to think of attention_scores as 2D tensors
const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length});
const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length});
// Custom kernel
attention_probs = at::empty_like(attention_scores_2d);
// Check that inputs and contiguous + cuda tensors
CHECK_INPUT(attention_scores_2d);
CHECK_INPUT(attention_mask_2d);
// TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out
// DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] {
/*
* Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
* A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
* - SMs: 108
* - TPCs: 56 (What's that?)
* - Memory size: 40 GB
* - L2 Cache size: 40960 KB (shared across all SMs)
* - L1/Shared memory size: 192 KB (shared across all threads within a SM)
* - Max Threads / SM: 2048
* - Max Thread Blocks / SM: 32
*/
/*
* We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block
* with multiple threads as we need to `sync_threads` to run exponential sum.
* We maximise the usage of threads within a single block
*/
// TODO @thomasw21 figure out everything warp related:
// - why do they have to be power of 2
// TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
const auto MAX_THREADS_PER_SM = 1024;
// TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`
const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;
// `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`
const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;
const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;
const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;
const dim3 gridDim(num_blocks); // Number of blocks that run
const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block
const int shared_mem_forward = rows_per_block * 2 * sizeof(float);
// 192 * 2 ** 10
// const auto MAX_L1_MEMORY = 196608;
// const auto MAX_SMs = 108;
// TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation.");
// TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger.");
// TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher.");
forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(
attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
effective_kv_length,
blockDim,
rows_per_block,
kv_length,
batch_size_times_num_heads * q_length
);
});
attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});
} else {
// Pytorch C++ API
auto input_dtype = attention_scores.scalar_type();
if (input_dtype == at::ScalarType::Float) {
attention_scores = attention_scores.to(at::ScalarType::Float);
};
// TODO @thomasw21 Figure out how to get minimum value
auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34);
attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);
}
auto context_layer = attention_probs.bmm(value_view);
// `_merge_heads`
context_layer = context_layer.view({batch_size, num_heads, q_length, attn_head_size});
context_layer = context_layer.permute({0, 2, 1, 3});
context_layer = context_layer.reshape({batch_size, q_length, attn_head_size * num_heads});
return std::make_tuple(context_layer, present, attention_probs);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"forward",
&forward,
"GPT-Neox attention mechanism forward (CUDA)"
);
}

View File

@ -0,0 +1,250 @@
#include <ATen/Dispatch.h>
#include <THC/THCAtomics.cuh>
#include <ATen/ATen.h>
#include <torch/torch.h>
#include <vector>
#include <optional>
/**
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
**/
// Available in pytorch main
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
/*
* Forward passes
*/
/**
* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype
**/
template<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>
__global__ void forward_masked_softmax_kernel(
const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]
const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]
torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]
const int64_t effective_kv_length,
const dim3 blockDim,
const int64_t rows_per_block,
const int64_t kv_length,
const int64_t batch_size
) {
const auto row_id = threadIdx.x / effective_kv_length;
const auto effective_kv_length_id = threadIdx.x % effective_kv_length;
const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;
auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;
kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_;
const auto kv_length_end = kv_length_end_;
const auto batch_id = blockIdx.x * rows_per_block + row_id;
// We need 2 float storage for each row, one for max computation, the other for normalizing exponential
extern __shared__ float temp_storage[];
const auto row_id_mem_offset = row_id * 2;
if (effective_kv_length_id == 0) {
temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();
temp_storage[row_id_mem_offset + 1] = 0;
}
__syncthreads();
// Compute mask and max
if (batch_id < batch_size) {
float thread_max = -std::numeric_limits<float>::infinity();
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
const float candidate = attention_scores[batch_id][kv_length_id];
thread_max = (thread_max < candidate) ? candidate : thread_max;
}
}
if (thread_max != -std::numeric_limits<float>::infinity()) {
// TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);
}
}
__syncthreads();
// Compute exp(elt - max) masked
float exponential[min_kv_length_shard_size_per_thread];
if (batch_id < batch_size) {
float thread_add = 0;
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
if (mask[batch_id][kv_length_id] == 0) {
exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);
thread_add = thread_add + exponential[kv_length_id - kv_length_start];
} else {
exponential[kv_length_id - kv_length_start] = 0.;
}
}
if (thread_add > 0) {
// TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot
gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);
}
}
__syncthreads();
// Compute softmax
if (batch_id < batch_size) {
// If sum of all exponential is 0, we set the softmax values to 0
if (temp_storage[row_id_mem_offset + 1] == 0.) {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = 0.;
}
} else {
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);
}
}
}
}
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(
const at::Tensor fused_qkv,
const std::optional<std::vector<at::Tensor>> layer_past,
const at::Tensor alibi,
const at::Tensor attention_mask,
const std::optional<at::Tensor> head_mask,
const float beta,
const float inv_norm_factor,
const int num_heads,
const bool use_cache
) {
const auto batch_size = fused_qkv.size(0);
const auto q_length = fused_qkv.size(1);
const auto three_times_hidden_size = fused_qkv.size(2);
const auto head_dim = three_times_hidden_size / (3 * num_heads);
const auto batch_size_times_num_heads = batch_size * num_heads;
// `split_heads`
const auto fused_qkv_view = fused_qkv.view({batch_size, q_length, num_heads, 3 * head_dim});
const auto tensor_list = fused_qkv_view.split(head_dim, -1);
const auto query_layer = tensor_list[0].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});
auto key_layer = tensor_list[1].permute({0, 2, 3, 1}).reshape({batch_size_times_num_heads, head_dim, q_length});
auto value_layer = tensor_list[2].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});
if (layer_past) {
const auto past_key = (*layer_past).at(0);
const auto past_value = (*layer_past).at(1);
key_layer = at::cat({past_key, key_layer}, 2);
value_layer = at::cat({past_value, value_layer}, 1);
}
std::optional<std::vector<at::Tensor>> present;
if (use_cache) {
present = {key_layer, value_layer};
} else {
present = {};
}
auto attention_scores = alibi.baddbmm(query_layer, key_layer, beta, inv_norm_factor);
// Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype`
at::Tensor attention_probs;
if (true) {
const auto kv_length = key_layer.size(2);
// TODO @thomasw21: it's easier to think of attention_scores as 2D tensors
const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length});
const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length});
// Custom kernel
attention_probs = at::empty_like(attention_scores_2d);
// Check that inputs and contiguous + cuda tensors
CHECK_INPUT(attention_scores_2d);
CHECK_INPUT(attention_mask_2d);
// TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out
// DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] {
/*
* Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
* A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
* - SMs: 108
* - TPCs: 56 (What's that?)
* - Memory size: 40 GB
* - L2 Cache size: 40960 KB (shared across all SMs)
* - L1/Shared memory size: 192 KB (shared across all threads within a SM)
* - Max Threads / SM: 2048
* - Max Thread Blocks / SM: 32
*/
/*
* We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block
* with multiple threads as we need to `sync_threads` to run exponential sum.
* We maximise the usage of threads within a single block
*/
// TODO @thomasw21 figure out everything warp related:
// - why do they have to be power of 2
// TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
const auto MAX_THREADS_PER_SM = 1024;
// TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`
const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;
// `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`
const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;
const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;
const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;
const dim3 gridDim(num_blocks); // Number of blocks that run
const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block
const int shared_mem_forward = rows_per_block * 2 * sizeof(float);
// 192 * 2 ** 10
// const auto MAX_L1_MEMORY = 196608;
// const auto MAX_SMs = 108;
// TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation.");
// TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger.");
// TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher.");
forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(
attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
effective_kv_length,
blockDim,
rows_per_block,
kv_length,
batch_size_times_num_heads * q_length
);
});
attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});
} else {
// Pytorch C++ API
auto input_dtype = attention_scores.scalar_type();
if (input_dtype == at::ScalarType::Float) {
attention_scores = attention_scores.to(at::ScalarType::Float);
};
// TODO @thomasw21 Figure out how to get minimum value
auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34);
attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);
}
auto context_layer = attention_probs.bmm(value_layer);
// `_merge_heads`
context_layer = context_layer.view({batch_size, num_heads, q_length, head_dim});
context_layer = context_layer.permute({0, 2, 1, 3});
context_layer = context_layer.reshape({batch_size, q_length, three_times_hidden_size / 3});
return std::make_tuple(context_layer, present, attention_probs);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"forward",
&forward,
"Bloom attention mechanism forward (CUDA)"
);
}

View File

@ -0,0 +1,19 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name="custom_kernels",
ext_modules=[
CUDAExtension(
name="custom_kernels.fused_bloom_attention_cuda",
sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
extra_compile_args=["-arch=compute_80", "-std=c++17"],
),
CUDAExtension(
name="custom_kernels.fused_attention_cuda",
sources=["custom_kernels/fused_attention_cuda.cu"],
extra_compile_args=["-arch=compute_80", "-std=c++17"],
),
],
cmdclass={"build_ext": BuildExtension},
)

View File

@ -25,7 +25,8 @@ opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.2" hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97" sentencepiece = "^0.1.97"
tokenizers = "0.13.3" tokenizers = "0.13.3"
huggingface-hub = "0.14.0" huggingface-hub = "^0.14.1"
transformers = "^4.29.2"
[tool.poetry.extras] [tool.poetry.extras]
accelerate = ["accelerate"] accelerate = ["accelerate"]

View File

@ -13,8 +13,8 @@ grpcio-reflection==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
grpcio-status==1.55.0 ; python_version >= "3.9" and python_version < "4.0" grpcio-status==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
grpcio==1.55.0 ; python_version >= "3.9" and python_version < "4.0" grpcio==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0" hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
huggingface-hub==0.14.0 ; python_version >= "3.9" and python_version < "4.0" huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
idna==3.4 ; python_version >= "3.9" and python_version < "4.0" idna==3.4 ; python_version >= "3.9" and python_version < "4"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0" loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
@ -33,6 +33,7 @@ safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0" setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0" tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0" tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0" typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
typing-extensions==4.6.0 ; python_version >= "3.9" and python_version < "4.0" typing-extensions==4.6.0 ; python_version >= "3.9" and python_version < "4.0"

View File

@ -6,12 +6,17 @@ from transformers import AutoTokenizer
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM from text_generation_server.utils import weight_hub_files, download_weights
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def default_bloom(): def default_bloom():
return BLOOM("bigscience/bloom-560m") model_id = "bigscience/bloom-560m"
revision = "main"
filenames = weight_hub_files(model_id, revision, ".safetensors")
download_weights(filenames, model_id, revision)
return BLOOMSharded(model_id)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")

View File

@ -1,3 +1,4 @@
import os
import torch import torch
from loguru import logger from loguru import logger
@ -8,17 +9,20 @@ from typing import Optional
from text_generation_server.models.model import Model from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPT, OPTSharded from text_generation_server.models.opt import OPTSharded
from text_generation_server.models.galactica import Galactica, GalacticaSharded from text_generation_server.models.galactica import GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder from text_generation_server.models.santacoder import SantaCoder
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.t5 import T5Sharded from text_generation_server.models.t5 import T5Sharded
from text_generation_server.models.gpt_neox import GPTNeoxSharded
try: try:
if torch.cuda.is_available(): if (
torch.cuda.is_available()
and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
):
major, minor = torch.cuda.get_device_capability() major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5 is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0 is_sm8x = major == 8 and minor >= 0
@ -30,14 +34,12 @@ try:
f"GPU with CUDA capability {major} {minor} is not supported" f"GPU with CUDA capability {major} {minor} is not supported"
) )
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded from text_generation_server.models.flash_neox import FlashNeoXSharded
from text_generation_server.models.flash_llama import ( from text_generation_server.models.flash_llama import (
FlashLlama, FlashLlama,
FlashLlamaSharded,
) )
from text_generation_server.models.flash_santacoder import ( from text_generation_server.models.flash_santacoder import (
FlashSantacoder,
FlashSantacoderSharded, FlashSantacoderSharded,
) )
@ -52,30 +54,22 @@ except ImportError:
__all__ = [ __all__ = [
"Model", "Model",
"BLOOM",
"BLOOMSharded", "BLOOMSharded",
"CausalLM", "CausalLM",
"FlashCausalLM", "FlashCausalLM",
"Galactica",
"GalacticaSharded", "GalacticaSharded",
"GPTNeoxSharded",
"Seq2SeqLM", "Seq2SeqLM",
"SantaCoder", "SantaCoder",
"OPT",
"OPTSharded", "OPTSharded",
"T5Sharded", "T5Sharded",
"get_model", "get_model",
] ]
if FLASH_ATTENTION: if FLASH_ATTENTION:
__all__.append(FlashNeoX)
__all__.append(FlashNeoXSharded) __all__.append(FlashNeoXSharded)
__all__.append(FlashRW)
__all__.append(FlashRWSharded) __all__.append(FlashRWSharded)
__all__.append(FlashSantacoder)
__all__.append(FlashSantacoderSharded) __all__.append(FlashSantacoderSharded)
__all__.append(FlashLlama) __all__.append(FlashLlama)
__all__.append(FlashLlamaSharded)
FLASH_ATT_ERROR_MESSAGE = ( FLASH_ATT_ERROR_MESSAGE = (
"{} requires Flash Attention CUDA kernels to be installed.\n" "{} requires Flash Attention CUDA kernels to be installed.\n"
@ -102,36 +96,24 @@ def get_model(
trust_remote_code: bool, trust_remote_code: bool,
) -> Model: ) -> Model:
if "facebook/galactica" in model_id: if "facebook/galactica" in model_id:
if sharded:
return GalacticaSharded( return GalacticaSharded(
model_id, model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return Galactica(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
) )
if model_id.startswith("bigcode/"): if model_id.startswith("bigcode/"):
if sharded: if FLASH_ATTENTION:
if not FLASH_ATTENTION:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
)
return FlashSantacoderSharded( return FlashSantacoderSharded(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else: else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder return SantaCoder(
return santacoder_cls(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -144,20 +126,19 @@ def get_model(
model_type = config_dict["model_type"] model_type = config_dict["model_type"]
if model_type == "gpt_bigcode": if model_type == "gpt_bigcode":
if sharded: if FLASH_ATTENTION:
if not FLASH_ATTENTION:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
)
return FlashSantacoderSharded( return FlashSantacoderSharded(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else: else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder return SantaCoder(
return santacoder_cls(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -165,33 +146,45 @@ def get_model(
) )
if model_type == "bloom": if model_type == "bloom":
if sharded:
return BLOOMSharded( return BLOOMSharded(
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
)
elif model_type == "gpt_neox":
if FLASH_ATTENTION:
return FlashNeoXSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
elif sharded:
return GPTNeoxSharded(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
else: else:
return BLOOM( return CausalLM(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if model_type == "gpt_neox": elif model_type == "llama":
if sharded: if FLASH_ATTENTION:
neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded return FlashLlama(
return neox_cls(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
else: else:
neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM return CausalLM(
return neox_cls(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -217,7 +210,7 @@ def get_model(
) )
else: else:
if FLASH_ATTENTION and not config_dict.get("alibi", False): if FLASH_ATTENTION and not config_dict.get("alibi", False):
return FlashRW( return FlashRWSharded(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
@ -231,42 +224,12 @@ def get_model(
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
if model_type == "llama": elif model_type == "opt":
if sharded:
if FLASH_ATTENTION:
return FlashLlamaSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama"))
else:
llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
return llama_cls(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "opt":
if sharded:
return OPTSharded( return OPTSharded(
model_id, model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return OPT(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
) )
if model_type == "t5": elif model_type == "t5":
if sharded: if sharded:
return T5Sharded( return T5Sharded(
model_id, model_id,

View File

@ -1,37 +1,26 @@
import torch import torch
import torch.distributed import torch.distributed
from typing import List, Optional, Type from typing import Optional, Type
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
AutoModelForCausalLM,
AutoConfig, AutoConfig,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from transformers.models.bloom.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.models.custom_modeling.bloom_modeling import (
BloomForCausalLM,
)
from text_generation_server.models import CausalLM from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights,
) )
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
class BloomCausalLMBatch(CausalLMBatch): class BloomCausalLMBatch(CausalLMBatch):
@classmethod @classmethod
@ -42,34 +31,12 @@ class BloomCausalLMBatch(CausalLMBatch):
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "CausalLMBatch": ) -> "CausalLMBatch":
batch = super(BloomCausalLMBatch, cls).from_pb( batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
pb=pb, tokenizer=tokenizer, dtype=dtype, device=device
)
batch.keys_head_dim_last = False batch.keys_head_dim_last = False
return batch return batch
class BLOOM(CausalLM): class BLOOMSharded(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
super(BLOOM, self).__init__(
model_id=model_id,
revision=revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch
class BLOOMSharded(BLOOM):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
@ -101,25 +68,16 @@ class BLOOMSharded(BLOOM):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
config.pad_token_id = 3 config.pad_token_id = 3
config.quantize = quantize
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
with init_empty_weights(): filenames, device=device, dtype=dtype, process_group=self.process_group
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
) )
torch.distributed.barrier(group=self.process_group) model = BloomForCausalLM(config, weights)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
model=model, model=model,
@ -131,132 +89,9 @@ class BLOOMSharded(BLOOM):
world_size=world_size, world_size=world_size,
) )
@staticmethod @property
def load_weights( def batch_type(self) -> Type[CausalLMBatch]:
model, return BloomCausalLMBatch
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
if name.startswith("transformer.") or name.startswith("lm_head."):
full_name = name
else:
full_name = f"transformer.{name}"
module_name, param_name = full_name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_tensor = parameters[full_name]
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif (
isinstance(module, TensorParallelEmbedding)
or name == "lm_head.weight"
):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
tensor = slice_[:]
if current_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "word_embeddings.weight":
model.lm_head._parameters["weight"] = tensor
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
@ -269,9 +104,5 @@ class BLOOMSharded(BLOOM):
use_cache=True, use_cache=True,
) )
# Logits are sharded, so we need to gather them logits = outputs.logits
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
return logits, outputs.past_key_values return logits, outputs.past_key_values

View File

@ -0,0 +1,912 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BLOOM model."""
import math
import os
import warnings
from typing import Optional, Tuple, Union
import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import LayerNorm
from torch.nn import functional as F
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
)
from transformers import BloomConfig, PreTrainedModel
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelHead,
)
CUSTOM_KERNELS_ENABLED = False
if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
try:
from custom_kernels import fused_bloom_attention_cuda
CUSTOM_KERNELS_ENABLED = True
except ImportError:
pass
_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
_CONFIG_FOR_DOC = "BloomConfig"
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
"bigscience/bigscience-small-testing",
"bigscience/bloom-560m",
"bigscience/bloom-1b1",
"bigscience/bloom-1b7",
"bigscience/bloom-3b",
"bigscience/bloom-7b1",
"bigscience/bloom",
]
def _make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
"""
Make causal mask used for self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.ones(
(target_length, target_length + past_key_values_length),
dtype=torch.bool,
device=device,
)
mask = mask.triu(1 + past_key_values_length)
expanded_mask = mask.unsqueeze(0).expand(
batch_size, target_length, target_length + past_key_values_length
)
return expanded_mask
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
"""
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
"""
batch_size, src_length = mask.shape
tgt_length = tgt_length if tgt_length is not None else src_length
expanded_mask = ~(mask[:, None, :].to(torch.bool))
return expanded_mask.expand(batch_size, tgt_length, src_length)
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
`softmax(l+a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
Args:
Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
attention_mask (`torch.Tensor`):
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
num_heads (`int`, *required*):
number of heads
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype of the output tensor
"""
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32,
)
powers = torch.arange(
1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32
)
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
device=attention_mask.device,
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(
1,
1 + 2 * num_remaining_heads,
2,
device=attention_mask.device,
dtype=torch.int32,
)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
# Note: alibi will added to the attention bias that will be applied to the query, key product of attention
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
return alibi
# @torch.jit.script
def dropout_add(
x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool
) -> torch.Tensor:
"""
Dropout add function
Args:
x (`torch.tensor`, *required*):
input tensor
residual (`torch.tensor`, *required*):
esidual tensor
prob (`float`, *required*):
dropout probability
training (`bool`, *required*):
training mode
"""
out = F.dropout(x, p=prob, training=training)
out = residual + out
return out
# @torch.jit.script # this is shit for unknow reasons.
def _split_heads(
fused_qkv: torch.Tensor, num_heads: int, head_dim: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
storage as `fused_qkv`
Args:
fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
Returns:
query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
value: [batch_size, seq_length, num_heads, head_dim]
"""
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1)
query_layer = query_layer.transpose(1, 2).reshape(
batch_size * num_heads, seq_length, head_dim
)
key_layer = key_layer.permute(0, 2, 3, 1).reshape(
batch_size * num_heads, head_dim, seq_length
)
value_layer = value_layer.transpose(1, 2).reshape(
batch_size * num_heads, seq_length, head_dim
)
return query_layer, key_layer, value_layer
# @torch.jit.script
def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
"""
Merge heads together over the last dimenstion
Args:
x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
Returns:
torch.tensor: [batch_size, seq_length, num_heads * head_dim]
"""
# What we want to achieve is:
# batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
batch_size_and_num_heads, seq_length, _ = x.shape
batch_size = batch_size_and_num_heads // num_heads
# First view to decompose the batch size
# batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
x = x.view(batch_size, num_heads, seq_length, head_dim)
# batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
x = x.permute(0, 2, 1, 3)
# batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
return x.reshape(batch_size, seq_length, num_heads * head_dim)
class BloomAttention(nn.Module):
def __init__(self, prefix, config: BloomConfig, weights):
super().__init__()
self.pretraining_tp = config.pretraining_tp
self.slow_but_exact = config.slow_but_exact
self.process_group = weights.process_group
self.hidden_size = config.hidden_size
self.num_heads = config.n_head
self.head_dim = self.hidden_size // self.num_heads
self.split_size = self.hidden_size
self.hidden_dropout = config.hidden_dropout
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
f" {self.num_heads})."
)
# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.beta = 1.0
process_group = weights.process_group
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear.load(
config=config,
prefix=f"{prefix}.query_key_value",
weights=weights,
bias=True,
)
self.dense = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
self.attention_dropout = nn.Dropout(config.attention_dropout)
@staticmethod
def compute_attention(
fused_qkv: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]],
alibi: torch.Tensor,
attention_mask: torch.Tensor,
head_mask: Optional[torch.Tensor],
beta: float,
inv_norm_factor: float,
num_heads: int,
use_cache: bool,
):
batch_size, q_length, three_times_hidden_size = fused_qkv.shape
head_dim = three_times_hidden_size // (3 * num_heads)
batch_size * num_heads
### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that?
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = _split_heads(
fused_qkv, num_heads=num_heads, head_dim=head_dim
)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, head_dim, kv_length]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
past_key = past_key.view(-1, *past_key.shape[-2:])
key_layer = torch.cat((past_key, key_layer), dim=2)
past_value = past_value.view(-1, *past_value.shape[-2:])
value_layer = torch.cat((past_value, value_layer), dim=1)
_, _, kv_length = key_layer.shape
if use_cache is True:
present = (key_layer, value_layer)
else:
present = None
###
# [batch_size * num_heads, q_length, kv_length]
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
attention_scores = alibi.baddbmm(
batch1=query_layer,
batch2=key_layer,
beta=beta,
alpha=inv_norm_factor,
)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16:
attention_scores = attention_scores.to(torch.float)
# torch.finfo not supported by torch.jit, we temporarily remplace with `-1e34`
attn_weights = attention_scores.masked_fill_(
attention_mask, torch.finfo(attention_scores.dtype).min
)
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
input_dtype
)
# # [batch_size, num_heads, q_length, kv_length]
# attention_probs = self.attention_dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(attention_probs, value_layer, out=query_layer)
# change view [batch_size, num_heads, q_length, head_dim]
context_layer = _merge_heads(
context_layer, num_heads=num_heads, head_dim=head_dim
)
return context_layer, present, attention_probs
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(
hidden_states
) # [batch_size, seq_length, 3 x hidden_size]
batch_size, q_length, _ = fused_qkv.shape
if layer_past is not None:
past_key, past_value = layer_past
layer_past = (
past_key.view(-1, *past_key.shape[-2:]),
past_value.view(-1, *past_value.shape[-2:]),
)
if CUSTOM_KERNELS_ENABLED:
assert self.training is False, "Only foward pass was implemented"
assert (
attention_mask.shape[-1] < 4096
), "Custom kernel support only up to 4096 tokens"
(
context_layer,
present,
attention_probs,
) = fused_bloom_attention_cuda.forward(
fused_qkv,
layer_past,
alibi,
attention_mask,
head_mask,
self.beta,
self.inv_norm_factor,
self.num_heads,
use_cache,
)
else:
context_layer, present, attention_probs = self.compute_attention(
fused_qkv=fused_qkv,
layer_past=layer_past,
alibi=alibi,
attention_mask=attention_mask,
head_mask=head_mask,
beta=self.beta,
inv_norm_factor=self.inv_norm_factor,
num_heads=self.num_heads,
use_cache=use_cache,
)
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
# output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
output_tensor += residual
outputs = (output_tensor, present)
if output_attentions:
outputs += (attention_probs,)
return outputs
class BloomMLP(nn.Module):
def __init__(self, prefix, config: BloomConfig, weights):
super().__init__()
self.pretraining_tp = config.pretraining_tp
self.slow_but_exact = config.slow_but_exact
self.dense_h_to_4h = TensorParallelColumnLinear.load(
config=config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
)
self.dense_4h_to_h = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
)
self.gelu_impl = torch.nn.GELU(approximate="tanh")
self.hidden_dropout = config.hidden_dropout
def forward(
self, hidden_states: torch.Tensor, residual: torch.Tensor
) -> torch.Tensor:
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
if self.pretraining_tp > 1 and self.slow_but_exact:
intermediate_output = torch.zeros_like(residual)
slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
for i in range(self.pretraining_tp):
intermediate_output = intermediate_output + F.linear(
hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense_4h_to_h.weight[
:, int(i * slices) : int((i + 1) * slices)
],
)
else:
intermediate_output = self.dense_4h_to_h(hidden_states)
# output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
intermediate_output += residual
return intermediate_output
class BloomBlock(nn.Module):
def __init__(self, layer_id: int, config: BloomConfig, weights):
super().__init__()
prefix = f"h.{layer_id}"
self.input_layernorm = LayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.num_heads = config.n_head
self.self_attention = BloomAttention(
prefix=f"{prefix}.self_attention", config=config, weights=weights
)
self.post_attention_layernorm = LayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.mlp = BloomMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm
)
self.hidden_dropout = config.hidden_dropout
def forward(
self,
hidden_states: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
# hidden_states: [batch_size, seq_length, hidden_size]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Layer norm post the self attention.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self attention.
attn_outputs = self.self_attention(
layernorm_output,
residual,
layer_past=layer_past,
attention_mask=attention_mask,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attention_output = attn_outputs[0]
outputs = attn_outputs[1:]
layernorm_output = self.post_attention_layernorm(attention_output)
# Get residual
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# MLP.
output = self.mlp(layernorm_output, residual)
if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions
class BloomPreTrainedModel(PreTrainedModel):
config_class = BloomConfig
base_model_prefix = "transformer"
_no_split_modules = ["BloomBlock"]
@staticmethod
def _convert_to_standard_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
num_heads, ...]))
"""
batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
num_heads = batch_size_times_num_heads // batch_size
# key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)
@staticmethod
def _convert_to_bloom_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
"""
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
batch_size_times_num_heads = batch_size * num_heads
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)
class BloomModel(BloomPreTrainedModel):
def __init__(self, config: BloomConfig, weights):
super().__init__(config)
self.embed_dim = config.hidden_size
self.num_heads = config.n_head
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.word_embeddings = TensorParallelEmbedding(
prefix="word_embeddings", weights=weights
)
self.word_embeddings_layernorm = LayerNorm.load(
prefix="word_embeddings_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
# Transformer blocks
self.h = nn.ModuleList(
[
BloomBlock(layer_id=layer_id, config=config, weights=weights)
for layer_id in range(config.num_hidden_layers)
]
)
# Final Layer Norm
self.ln_f = LayerNorm.load(
prefix="ln_f", weights=weights, eps=config.layer_norm_epsilon
)
def _prepare_attn_mask(
self,
attention_mask: torch.Tensor,
input_shape: Tuple[int, int],
past_key_values_length: int,
) -> torch.BoolTensor:
# create causal mask
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape
if src_length > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
device=device,
past_key_values_length=past_key_values_length,
)
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask | combined_attention_mask
)
return combined_attention_mask
def set_input_embeddings(self, new_embeddings: torch.Tensor):
self.word_embeddings = new_embeddings
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# Compute alibi tensor: check build_alibi_tensor documentation
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[-1]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), device=hidden_states.device
)
else:
attention_mask = attention_mask.to(hidden_states.device)
alibi = build_alibi_tensor(attention_mask, self.num_heads)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
if hasattr(self, "tp_rank"):
assert self.num_heads % self.tp_world_size == 0
block_size = self.num_heads // self.tp_world_size
alibi = alibi[
:, self.tp_rank * block_size : (self.tp_rank + 1) * block_size
]
alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past)
causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
else:
alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past)
causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0)
alibi = alibi.to(hidden_states.dtype)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (
outputs[2 if use_cache else 1],
)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class BloomForCausalLM(BloomPreTrainedModel):
def __init__(self, config, weights):
super().__init__(config)
self.transformer = BloomModel(config, weights)
self.lm_head = TensorParallelHead.load(
config,
prefix="word_embeddings",
weights=weights,
)
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
# only last token for input_ids if past is not None
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
# the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)

View File

@ -30,21 +30,23 @@ import flash_attn_cuda
import dropout_layer_norm import dropout_layer_norm
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead,
) )
class LlamaRMSNorm(nn.Module): class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, prefix, weights, eps=1e-6):
""" """
LlamaRMSNorm is equivalent to T5LayerNorm LlamaRMSNorm is equivalent to T5LayerNorm
""" """
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
weight = weights.get_tensor(f"{prefix}.weight")
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps self.variance_epsilon = eps
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
@ -91,34 +93,34 @@ class LlamaRMSNorm(nn.Module):
class FlashLlamaAttention(torch.nn.Module): class FlashLlamaAttention(torch.nn.Module):
def __init__( def __init__(
self, self,
num_heads, prefix: str,
hidden_size, config,
process_group=None, weights,
): ):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = config.num_attention_heads
self.hidden_size = hidden_size self.hidden_size = config.hidden_size
self.head_size = hidden_size // num_heads self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
)
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
self.softmax_scale = self.head_size ** (-0.5) self.softmax_scale = self.head_size ** (-0.5)
if process_group is None: self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = FastLinear(hidden_size, 3 * hidden_size, bias=False) self.query_key_value = TensorParallelColumnLinear.load_multi(
self.o_proj = FastLinear(hidden_size, hidden_size, bias=False) config,
else: prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
self.num_heads = self.num_heads // process_group.size() dim=0,
self.query_key_value = TensorParallelColumnLinear( weights=weights,
hidden_size,
3 * hidden_size,
bias=False, bias=False,
process_group=process_group,
) )
self.o_proj = TensorParallelRowLinear( self.o_proj = TensorParallelRowLinear.load(
hidden_size, config,
hidden_size, prefix=f"{prefix}.o_proj",
weights=weights,
bias=False, bias=False,
process_group=process_group,
) )
def forward( def forward(
@ -195,8 +197,9 @@ class FlashLlamaAttention(torch.nn.Module):
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
def __init__(self, act, hidden_size, intermediate_size, process_group=None): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
act = config.hidden_act
self.act = ( self.act = (
ACT2FN[act] ACT2FN[act]
if "gelu" not in act if "gelu" not in act
@ -207,32 +210,23 @@ class LlamaMLP(nn.Module):
else "none", else "none",
) )
) )
if process_group is None:
# Fuse gate and up proj # Fuse gate and up proj
self.gate_up_proj = FastLinear( self.gate_up_proj = TensorParallelColumnLinear.load_multi(
hidden_size, 2 * intermediate_size, bias=False config,
) prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False) weights=weights,
self.intermediate_size = intermediate_size dim=0,
else:
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear(
hidden_size,
2 * intermediate_size,
bias=False, bias=False,
process_group=process_group,
) )
self.down_proj = TensorParallelRowLinear( self.down_proj = TensorParallelRowLinear.load(
intermediate_size, config,
hidden_size, prefix=f"{prefix}.down_proj",
weights=weights,
bias=False, bias=False,
process_group=process_group,
reduce=True,
) )
self.intermediate_size = self.down_proj.in_features self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
self.process_group = process_group )
def forward(self, hidden_states): def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states) gate_up_states = self.gate_up_proj(hidden_states)
@ -241,22 +235,22 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module): class FlashLlamaLayer(nn.Module):
def __init__( def __init__(self, layer_id, config, weights):
self,
num_heads,
act,
hidden_size,
intermediate_size,
rms_norm_eps,
process_group=None,
):
super().__init__() super().__init__()
prefix = f"model.layers.{layer_id}"
self.self_attn = FlashLlamaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.self_attn = FlashLlamaAttention(num_heads, hidden_size, process_group) self.input_layernorm = LlamaRMSNorm(
self.mlp = LlamaMLP(act, hidden_size, intermediate_size, process_group) prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(
self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps) prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward( def forward(
self, self,
@ -295,54 +289,35 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module): class FlashLlamaModel(torch.nn.Module):
def __init__(self, config, process_group=None): def __init__(self, config, weights):
super(FlashLlamaModel, self).__init__() super().__init__()
self.config = config self.config = config
self.tp_embeddings = False process_group = weights.process_group
if process_group is not None:
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group prefix="model.embed_tokens", weights=weights
) )
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
FlashLlamaLayer( FlashLlamaLayer(
config.num_attention_heads, layer_id,
config.hidden_act, config,
config.hidden_size, weights,
config.intermediate_size,
config.rms_norm_eps,
process_group,
) )
for _ in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = LlamaRMSNorm(
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
)
self.gradient_checkpointing = False self.gradient_checkpointing = False
self.head_size = self.layers[0].self_attn.head_size self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads self.num_heads = self.layers[0].self_attn.num_heads
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.embed_tokens, TensorParallelEmbedding):
self.embed_tokens.add_null_idx()
for layer in self.layers:
layer: FlashLlamaLayer
layer.self_attn.query_key_value.prepare_weights(quantize)
layer.self_attn.o_proj.prepare_weights(quantize)
layer.mlp.gate_up_proj.prepare_weights(quantize)
layer.mlp.down_proj.prepare_weights(quantize)
def forward( def forward(
self, self,
input_ids, input_ids,
@ -410,29 +385,15 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, config, process_group=None): def __init__(self, config, weights):
super().__init__() super().__init__()
self.process_group = process_group self.model = FlashLlamaModel(config, weights)
if self.process_group is not None: self.lm_head = TensorParallelHead.load(
self.world_size = self.process_group.size() config,
else: prefix="lm_head",
self.world_size = 1 weights=weights,
self.model = FlashLlamaModel(config, process_group)
if self.model.tp_embeddings:
self.lm_head = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
) )
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self, quantize: Optional[str] = None):
self.model.post_load_weights(quantize)
self.lm_head.prepare_weights()
def forward( def forward(
self, self,
@ -457,12 +418,4 @@ class FlashLlamaForCausalLM(torch.nn.Module):
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
if self.model.tp_embeddings:
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present return logits, present

View File

@ -31,61 +31,81 @@ from typing import Optional
import flash_attn_cuda import flash_attn_cuda
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelHead,
FastLayerNorm, FastLayerNorm,
PositionRotaryEmbedding, PositionRotaryEmbedding,
get_linear,
) )
class FlashNeoxAttention(torch.nn.Module): def load_row(config, prefix: str, weights, bias: bool):
def __init__( weight = weights.get_sharded(f"{prefix}.weight", dim=1)
self, if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
if config.use_parallel_residual:
return linear
else:
return TensorParallelRowLinear(linear, process_group=weights.process_group)
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
weight = (
weight.view(
num_heads, num_heads,
3,
head_size,
hidden_size, hidden_size,
rotary_pct, )
rotary_emb_base, .permute(1, 0, 2, 3)
process_group=None, .reshape(-1, hidden_size)
reduce=True, )
): bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
linear = get_linear(weight, bias, config.quantize)
if config.use_parallel_residual:
return linear
else:
return TensorParallelColumnLinear(linear)
class FlashNeoxAttention(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__() super().__init__()
num_heads = config.num_attention_heads
hidden_size = config.hidden_size
self.num_heads = num_heads self.num_heads = num_heads
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads self.head_size = hidden_size // num_heads
self.num_heads = self.num_heads // weights.process_group.size()
self.rotary_emb = PositionRotaryEmbedding.load(
prefix=f"{prefix}.rotary_emb", weights=weights
)
rotary_ndims = int(self.head_size * rotary_pct)
self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
self.softmax_scale = self.head_size ** (-0.5) self.softmax_scale = self.head_size ** (-0.5)
if process_group is None: self.query_key_value = load_qkv(
self.query_key_value = FastLinear(hidden_size, 3 * hidden_size) config,
self.dense = FastLinear(hidden_size, hidden_size) prefix=f"{prefix}.query_key_value",
else: weights=weights,
self.num_heads = self.num_heads // process_group.size() num_heads=self.num_heads,
self.query_key_value = TensorParallelColumnLinear( head_size=self.head_size,
hidden_size, hidden_size=self.hidden_size,
3 * hidden_size,
process_group=process_group,
) )
self.dense = TensorParallelRowLinear( self.dense = load_row(
hidden_size, hidden_size, process_group=process_group, reduce=reduce config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
def shuffle_qkv_dims(self):
"""Swap dims to avoid an additional permute"""
self.query_key_value.weight = torch.nn.Parameter(
self.query_key_value.weight.view(
self.num_heads, 3, self.head_size, self.hidden_size
)
.permute(1, 0, 2, 3)
.reshape(-1, self.hidden_size)
)
self.query_key_value.bias = torch.nn.Parameter(
self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
.permute(1, 0, 2)
.reshape(-1)
) )
def forward( def forward(
@ -162,10 +182,9 @@ class FlashNeoxAttention(torch.nn.Module):
class FlashMLP(nn.Module): class FlashMLP(nn.Module):
def __init__( def __init__(self, config, prefix, weights):
self, act, hidden_size, intermediate_size, process_group=None, reduce=True
):
super().__init__() super().__init__()
act = config.hidden_act
self.act = ( self.act = (
ACT2FN[act] ACT2FN[act]
if "gelu" not in act if "gelu" not in act
@ -177,22 +196,12 @@ class FlashMLP(nn.Module):
) )
) )
if process_group is None: self.dense_h_to_4h = TensorParallelColumnLinear.load(
self.dense_h_to_4h = FastLinear(hidden_size, intermediate_size) config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
self.dense_4h_to_h = FastLinear(intermediate_size, hidden_size)
else:
self.dense_h_to_4h = TensorParallelColumnLinear(
hidden_size,
intermediate_size,
process_group=process_group,
) )
self.dense_4h_to_h = TensorParallelRowLinear( self.dense_4h_to_h = load_row(
intermediate_size, config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
hidden_size,
process_group=process_group,
reduce=reduce,
) )
self.process_group = process_group
def forward(self, hidden_states): def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states) hidden_states = self.dense_h_to_4h(hidden_states)
@ -202,38 +211,28 @@ class FlashMLP(nn.Module):
class FlashNeoXLayer(nn.Module): class FlashNeoXLayer(nn.Module):
def __init__( def __init__(self, layer_id, config, weights):
self,
num_heads,
act,
hidden_size,
intermediate_size,
rotary_pct,
rotary_emb_base,
layer_norm_eps,
use_parallel_residual,
process_group=None,
):
super().__init__() super().__init__()
self.use_parallel_residual = use_parallel_residual
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) layer_norm_eps = config.layer_norm_eps
self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
prefix = f"gpt_neox.layers.{layer_id}"
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=layer_norm_eps
)
self.post_attention_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=layer_norm_eps,
)
self.attention = FlashNeoxAttention( self.attention = FlashNeoxAttention(
num_heads, config, prefix=f"{prefix}.attention", weights=weights
hidden_size,
rotary_pct,
rotary_emb_base,
process_group,
reduce=not use_parallel_residual,
) )
self.mlp = FlashMLP(
act, self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
hidden_size, self.process_group = weights.process_group
intermediate_size,
process_group,
reduce=not use_parallel_residual,
)
self.process_group = process_group
def forward( def forward(
self, self,
@ -266,8 +265,6 @@ class FlashNeoXLayer(nn.Module):
mlp_output = self.mlp(ln2_hidden_states) mlp_output = self.mlp(ln2_hidden_states)
intermediate = mlp_output + attn_output intermediate = mlp_output + attn_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group) torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate + hidden_states, None return intermediate + hidden_states, None
@ -302,42 +299,24 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, process_group=None): def __init__(self, config, weights):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.embed_in = TensorParallelEmbedding( self.embed_in = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group prefix="gpt_neox.embed_in", weights=weights
) )
else:
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
FlashNeoXLayer( FlashNeoXLayer(layer_id, config, weights)
config.num_attention_heads, for layer_id in range(config.num_hidden_layers)
config.hidden_act,
config.hidden_size,
config.intermediate_size,
config.rotary_pct,
config.rotary_emb_base,
config.layer_norm_eps,
config.use_parallel_residual,
process_group,
)
for _ in range(config.num_hidden_layers)
] ]
) )
self.final_layer_norm = FastLayerNorm( self.final_layer_norm = FastLayerNorm.load(
config.hidden_size, eps=config.layer_norm_eps prefix="gpt_neox.final_layer_norm",
weights=weights,
eps=config.layer_norm_eps,
) )
self.gradient_checkpointing = False self.gradient_checkpointing = False
@ -345,29 +324,6 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
self.head_size = self.layers[0].attention.head_size self.head_size = self.layers[0].attention.head_size
self.num_heads = self.layers[0].attention.num_heads self.num_heads = self.layers[0].attention.num_heads
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.embed_in, TensorParallelEmbedding):
self.embed_in.add_null_idx()
for layer in self.layers:
layer: FlashNeoXLayer
layer.attention.shuffle_qkv_dims()
layer.attention.query_key_value.prepare_weights(quantize)
layer.attention.dense.prepare_weights(quantize)
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashGPTNeoXModel, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward( def forward(
self, self,
input_ids, input_ids,
@ -435,42 +391,13 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
def __init__(self, config, process_group=None): def __init__(self, config, weights):
super().__init__(config) super().__init__(config)
self.gpt_neox = FlashGPTNeoXModel(config, weights)
self.process_group = process_group self.embed_out = TensorParallelHead.load(
if self.process_group is not None: config, prefix="embed_out", weights=weights
self.world_size = self.process_group.size()
else:
self.world_size = 1
self.gpt_neox = FlashGPTNeoXModel(config, process_group)
if self.gpt_neox.tp_embeddings:
self.embed_out = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
) )
else:
self.embed_out = FastLinear(
config.hidden_size, config.vocab_size, bias=False
)
def post_load_weights(self, quantize: Optional[str] = None):
self.gpt_neox.post_load_weights(quantize)
self.embed_out.prepare_weights()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward( def forward(
self, self,
@ -495,12 +422,4 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.embed_out(hidden_states) logits = self.embed_out(hidden_states)
if self.gpt_neox.tp_embeddings:
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present return logits, present

View File

@ -1,5 +1,3 @@
import os
import torch import torch
import torch.distributed import torch.distributed
@ -12,15 +10,31 @@ from typing import Optional
import flash_attn_cuda import flash_attn_cuda
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelHead,
FastLayerNorm, FastLayerNorm,
PositionRotaryEmbedding, PositionRotaryEmbedding,
get_linear,
) )
def load_row(config, prefix: str, weights, bias: bool):
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
linear = get_linear(weight, bias, config.quantize)
if config.parallel_attn:
return linear
else:
return TensorParallelRowLinear(linear, process_group=weights.process_group)
class RWConfig(PretrainedConfig): class RWConfig(PretrainedConfig):
attribute_map = { attribute_map = {
"num_hidden_layers": "n_layer", "num_hidden_layers": "n_layer",
@ -85,44 +99,31 @@ class RWConfig(PretrainedConfig):
class FlashRWAttention(torch.nn.Module): class FlashRWAttention(torch.nn.Module):
def __init__( def __init__(
self, self,
num_heads, config,
num_heads_kv, prefix,
hidden_size, weights,
bias,
process_group=None,
reduce=True,
): ):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = config.n_head
self.num_heads_kv = num_heads_kv self.num_heads_kv = config.n_head_kv
self.hidden_size = hidden_size self.hidden_size = config.hidden_size
self.head_size = hidden_size // num_heads self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) self.rotary_emb = PositionRotaryEmbedding.static(
dim=self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5) self.softmax_scale = self.head_size ** (-0.5)
self.num_heads = self.num_heads // weights.process_group.size()
if process_group is None: self.query_key_value = TensorParallelColumnLinear.load(
self.query_key_value = FastLinear( config,
hidden_size, prefix=f"{prefix}.query_key_value",
self.head_size * (self.num_heads + 2 * self.num_heads_kv), weights=weights,
bias=bias, bias=config.bias,
) )
self.dense = FastLinear(hidden_size, hidden_size, bias=bias) self.dense = load_row(
else: config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
process_group=process_group,
) )
self.dense = TensorParallelRowLinear(
hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
)
self.num_heads = self.num_heads // process_group.size()
def forward( def forward(
self, self,
@ -212,58 +213,49 @@ class FlashRWAttention(torch.nn.Module):
class FlashRWLargeAttention(torch.nn.Module): class FlashRWLargeAttention(torch.nn.Module):
def __init__( def __init__(
self, self,
num_heads, config,
num_heads_kv, prefix,
hidden_size, weights,
bias,
process_group=None,
reduce=True,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size
num_heads = config.n_head
num_heads_kv = config.n_head_kv
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads self.head_size = hidden_size // num_heads
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) self.rotary_emb = PositionRotaryEmbedding.static(
self.head_size, base=10000.0, device=weights.device
)
self.softmax_scale = self.head_size ** (-0.5) self.softmax_scale = self.head_size ** (-0.5)
self.num_groups = num_heads // (num_heads_kv * 2) self.num_groups = num_heads // (num_heads_kv * 2)
self.num_heads = num_heads // self.num_groups self.num_heads = num_heads // self.num_groups
self.num_heads_kv = num_heads_kv // self.num_groups self.num_heads_kv = num_heads_kv // self.num_groups
process_group = weights.process_group
if process_group is None:
self.query_key_value = FastLinear(
hidden_size,
self.num_groups
* self.head_size
* (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
)
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
else:
if process_group.size() > self.num_groups: if process_group.size() > self.num_groups:
raise NotImplementedError( raise NotImplementedError(
f"Tensor Parallelism is not implemented for world_size > n groups" f"Tensor Parallelism is not implemented for world_size > n groups"
) )
if self.num_groups % process_group.size() != 0:
self.query_key_value = TensorParallelColumnLinear( raise NotImplementedError(
hidden_size, f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
self.num_groups
* self.head_size
* (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
process_group=process_group,
) )
self.dense = TensorParallelRowLinear(
hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
)
self.num_groups = self.num_groups // process_group.size() self.num_groups = self.num_groups // process_group.size()
self.query_key_value = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.query_key_value",
weights=weights,
bias=config.bias,
)
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
)
def forward( def forward(
self, self,
hidden_states, hidden_states,
@ -359,28 +351,16 @@ class FlashRWLargeAttention(torch.nn.Module):
class FlashMLP(nn.Module): class FlashMLP(nn.Module):
def __init__(self, hidden_size, bias, process_group=None, reduce=True): def __init__(self, config, prefix, weights):
super().__init__() super().__init__()
self.act = torch.nn.functional.gelu self.act = torch.nn.functional.gelu
if process_group is None: self.dense_h_to_4h = TensorParallelColumnLinear.load(
self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias) config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias
self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias)
else:
self.dense_h_to_4h = TensorParallelColumnLinear(
hidden_size,
4 * hidden_size,
bias=bias,
process_group=process_group,
) )
self.dense_4h_to_h = TensorParallelRowLinear( self.dense_4h_to_h = load_row(
4 * hidden_size, config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
) )
self.process_group = process_group
def forward(self, hidden_states): def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states) hidden_states = self.dense_h_to_4h(hidden_states)
@ -392,38 +372,44 @@ class FlashMLP(nn.Module):
class FlashRWLayer(nn.Module): class FlashRWLayer(nn.Module):
def __init__( def __init__(
self, self,
num_heads, layer_id,
num_heads_kv, config,
hidden_size, weights,
bias,
layer_norm_eps,
parallel_attn,
process_group=None,
): ):
super().__init__() super().__init__()
parallel_attn = config.parallel_attn
self.parallel_attn = parallel_attn self.parallel_attn = parallel_attn
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) prefix = f"transformer.h.{layer_id}"
self.input_layernorm = FastLayerNorm.load(
prefix=f"{prefix}.input_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.self_attention = FlashRWAttention( self.self_attention = FlashRWAttention(
num_heads, config,
num_heads_kv, prefix=f"{prefix}.self_attention",
hidden_size, weights=weights,
bias,
process_group=process_group,
reduce=False,
) )
self.post_attention_layernorm = ( self.post_attention_layernorm = (
FastLayerNorm(hidden_size, eps=layer_norm_eps) FastLayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_epsilon,
)
if not parallel_attn if not parallel_attn
else None else None
) )
self.mlp = FlashMLP( self.mlp = FlashMLP(
hidden_size, bias, process_group=process_group, reduce=False config,
prefix=f"{prefix}.mlp",
weights=weights,
) )
self.process_group = process_group self.process_group = weights.process_group
def forward( def forward(
self, self,
@ -454,8 +440,6 @@ class FlashRWLayer(nn.Module):
mlp_output = self.mlp(ln_hidden_states) mlp_output = self.mlp(ln_hidden_states)
intermediate = mlp_output + attn_output intermediate = mlp_output + attn_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group) torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate, residual return intermediate, residual
@ -483,33 +467,30 @@ class FlashRWLayer(nn.Module):
class FlashRWLargeLayer(nn.Module): class FlashRWLargeLayer(nn.Module):
def __init__( def __init__(self, layer_id, config, weights):
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
layer_norm_eps,
process_group=None,
):
super().__init__() super().__init__()
self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps) prefix = f"transformer.h.{layer_id}"
self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps) self.ln_attn = FastLayerNorm.load(
prefix=f"{prefix}.ln_attn",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.ln_mlp = FastLayerNorm.load(
prefix=f"{prefix}.ln_mlp",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.self_attention = FlashRWLargeAttention( self.self_attention = FlashRWLargeAttention(
num_heads, config,
num_heads_kv, prefix=f"{prefix}.self_attention",
hidden_size, weights=weights,
bias,
process_group=process_group,
reduce=False,
) )
assert config.parallel_attn, "This version doesn't support non parallel_attn"
self.mlp = FlashMLP( self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
hidden_size, bias, process_group=process_group, reduce=False
)
self.process_group = process_group self.process_group = weights.process_group
def forward( def forward(
self, self,
@ -543,8 +524,6 @@ class FlashRWLargeLayer(nn.Module):
intermediate = attn_output + mlp_output intermediate = attn_output + mlp_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group) torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate, residual return intermediate, residual
@ -555,37 +534,18 @@ class FlashRWPreTrainedModel(PreTrainedModel):
class FlashRWModel(FlashRWPreTrainedModel): class FlashRWModel(FlashRWPreTrainedModel):
def __init__(self, config, process_group=None): def __init__(self, config, weights):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.word_embeddings = TensorParallelEmbedding( self.word_embeddings = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group prefix="transformer.word_embeddings", weights=weights
) )
else:
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
if config.model_type == "RefinedWebModel": if config.model_type == "RefinedWebModel":
self.h = nn.ModuleList( self.h = nn.ModuleList(
[ [
FlashRWLayer( FlashRWLayer(layer_id, config, weights)
config.n_head, for layer_id in range(config.num_hidden_layers)
config.n_head_kv,
config.hidden_size,
config.bias,
config.layer_norm_epsilon,
config.parallel_attn,
process_group,
)
for _ in range(config.num_hidden_layers)
] ]
) )
self.cache_size = ( self.cache_size = (
@ -596,15 +556,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
elif config.model_type == "RefinedWeb": elif config.model_type == "RefinedWeb":
self.h = nn.ModuleList( self.h = nn.ModuleList(
[ [
FlashRWLargeLayer( FlashRWLargeLayer(layer_id, config, weights)
config.n_head, for layer_id in range(config.num_hidden_layers)
config.n_head_kv,
config.hidden_size,
config.bias,
config.layer_norm_epsilon,
process_group,
)
for _ in range(config.num_hidden_layers)
] ]
) )
self.cache_size = ( self.cache_size = (
@ -617,31 +570,13 @@ class FlashRWModel(FlashRWPreTrainedModel):
f"model_type {config.model_type} is not supported." f"model_type {config.model_type} is not supported."
) )
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_f = FastLayerNorm.load(
prefix="transformer.ln_f",
self.head_size = self.h[0].self_attention.head_size weights=weights,
eps=config.layer_norm_epsilon,
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.word_embeddings, TensorParallelEmbedding):
self.word_embeddings.add_null_idx()
for layer in self.h:
layer: FlashRWLayer
layer.self_attention.query_key_value.prepare_weights(quantize)
layer.self_attention.dense.prepare_weights(quantize)
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashRWModel, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
) )
model.post_load_weights("bitsandbytes" if load_in_8bit else None) self.head_size = self.h[0].self_attention.head_size
return model
def forward( def forward(
self, self,
@ -708,40 +643,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
class FlashRWForCausalLM(FlashRWPreTrainedModel): class FlashRWForCausalLM(FlashRWPreTrainedModel):
def __init__(self, config, process_group=None): def __init__(self, config, weights):
super().__init__(config) super().__init__(config)
self.process_group = process_group self.transformer = FlashRWModel(config, weights)
if self.process_group is not None:
self.world_size = self.process_group.size()
else:
self.world_size = 1
self.transformer = FlashRWModel(config, process_group) self.lm_head = TensorParallelHead.load(
config, prefix="lm_head", weights=weights
if self.transformer.tp_embeddings:
self.lm_head = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
) )
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self, quantize: Optional[str] = None):
self.transformer.post_load_weights(quantize)
self.lm_head.prepare_weights()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashRWForCausalLM, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward( def forward(
self, self,
@ -766,12 +675,4 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
if self.transformer.tp_embeddings:
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present return logits, present

View File

@ -8,38 +8,141 @@ from typing import Optional
# Flash attention imports # Flash attention imports
import flash_attn_cuda import flash_attn_cuda
from text_generation_server.utils.layers import ( from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelHead,
TensorParallelEmbedding, TensorParallelEmbedding,
FastLayerNorm, FastLayerNorm,
get_linear,
) )
class FlashMQAttention(torch.nn.Module): def load_multi_mqa(
def __init__( config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
self, ):
num_heads, if any("c_attn" in k for k in weights.routing.keys()):
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
shape = slice_.get_shape()
world_size = weights.process_group.size()
rank = weights.process_group.rank()
if config.transpose:
block_size = (shape[1] - 2 * head_size) // world_size
start = rank * block_size
stop = (rank + 1) * block_size
assert (shape[1] - 2 * head_size) % world_size == 0
q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size :]
weight = torch.cat([q_tensor, kv_tensor], dim=1).T
else:
block_size = (shape[0] - 2 * head_size) // world_size
start = rank * block_size
stop = (rank + 1) * block_size
assert (shape[0] - 2 * head_size) % world_size == 0
q_tensor = slice_[start:stop]
kv_tensor = slice_[-2 * head_size :]
weight = torch.cat([q_tensor, kv_tensor], dim=0)
if bias:
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
shape = slice_.get_shape()
block_size = (shape[0] - 2 * head_size) // world_size
assert (shape[0] - 2 * head_size) % world_size == 0
q_tensor = slice_[start:stop]
start = rank * block_size
stop = (rank + 1) * block_size
q_tensor = slice_[start:stop]
kv_tensor = slice_[-2 * head_size :]
bias = torch.cat([q_tensor, kv_tensor], dim=0)
else:
if config.transpose:
w = [
weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T,
weights.get_tensor(f"{prefix}.kv_attn.weight").T,
]
weight = torch.cat(w, dim=0)
else:
w = [
weights.get_sharded(f"{prefix}.q_attn.weight", dim=0),
weights.get_tensor(f"{prefix}.kv_attn.weight"),
]
weight = torch.cat(w, dim=1)
if bias:
b = [
weights.get_sharded(f"{prefix}.q_attn.bias", dim=0),
weights.get_tensor(f"{prefix}.kv_attn.bias"),
]
bias = torch.cat(b, dim=0)
else:
bias = None
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
assert list(weight.shape) == [
(num_heads + 2) * head_size,
hidden_size, hidden_size,
process_group=None, ], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
): if bias is not None:
bias = bias.to(dtype=weights.dtype).to(device=weights.device)
assert list(bias.shape) == [
(num_heads + 2) * head_size
], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
def load_col(config, prefix: str, weights, bias: bool):
if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
def load_row(config, prefix: str, weights, bias: bool):
if config.transpose:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return TensorParallelRowLinear(
get_linear(weight, bias, config.quantize), process_group=weights.process_group
)
class FlashMQAttention(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
num_heads = config.num_attention_heads
hidden_size = config.hidden_size
self.num_heads = num_heads self.num_heads = num_heads
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads self.head_size = hidden_size // num_heads
assert self.num_heads % weights.process_group.size() == 0
self.num_heads = self.num_heads // weights.process_group.size()
self.softmax_scale = self.head_size ** (-0.5) self.softmax_scale = self.head_size ** (-0.5)
if process_group is None: self.c_attn = load_multi_mqa(
self.c_attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size) config,
self.c_proj = FastLinear(hidden_size, hidden_size) prefix=prefix,
else: weights=weights,
self.num_heads = self.num_heads // process_group.size() bias=True,
self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2)) head_size=self.head_size,
self.c_proj = TensorParallelRowLinear( hidden_size=hidden_size,
hidden_size, num_heads=self.num_heads,
hidden_size, )
process_group=process_group, self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
) )
def forward( def forward(
@ -121,8 +224,9 @@ class FlashMQAttention(torch.nn.Module):
class MLP(nn.Module): class MLP(nn.Module):
def __init__(self, act, hidden_size, intermediate_size, process_group=None): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
act = config.activation_function
self.act = ( self.act = (
ACT2FN[act] ACT2FN[act]
if "gelu" not in act if "gelu" not in act
@ -134,19 +238,11 @@ class MLP(nn.Module):
) )
) )
if process_group is None: self.c_fc = load_col(
self.c_fc = FastLinear(hidden_size, intermediate_size) config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
self.c_proj = FastLinear(intermediate_size, hidden_size)
else:
self.c_fc = TensorParallelColumnLinear(
hidden_size,
intermediate_size,
process_group=process_group,
) )
self.c_proj = TensorParallelRowLinear( self.c_proj = load_row(
intermediate_size, config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
hidden_size,
process_group=process_group,
) )
def forward(self, hidden_states): def forward(self, hidden_states):
@ -157,28 +253,24 @@ class MLP(nn.Module):
class Block(nn.Module): class Block(nn.Module):
def __init__( def __init__(self, layer_id, config, weights):
self,
num_heads,
act,
hidden_size,
intermediate_size,
layer_norm_eps,
process_group=None,
):
super().__init__() super().__init__()
self.ln_1 = FastLayerNorm(hidden_size, eps=layer_norm_eps) prefix = f"transformer.h.{layer_id}"
self.ln_2 = FastLayerNorm(hidden_size, eps=layer_norm_eps) self.ln_1 = FastLayerNorm.load(
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
)
self.ln_2 = FastLayerNorm.load(
prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
)
self.attn = FlashMQAttention( self.attn = FlashMQAttention(
num_heads, prefix=f"{prefix}.attn",
hidden_size, config=config,
process_group, weights=weights,
) )
self.mlp = MLP( self.mlp = MLP(
act, prefix=f"{prefix}.mlp",
hidden_size, config=config,
intermediate_size, weights=weights,
process_group,
) )
def forward( def forward(
@ -210,66 +302,39 @@ class Block(nn.Module):
class FlashSantacoderModel(nn.Module): class FlashSantacoderModel(nn.Module):
def __init__(self, config, process_group=None): def __init__(self, config, weights):
super().__init__() super().__init__()
self.config = config self.config = config
self.process_group = process_group self.process_group = weights.process_group
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.wte = TensorParallelEmbedding( self.wte = TensorParallelEmbedding(
config.vocab_size, prefix="transformer.wte",
config.hidden_size, weights=weights,
reduce=False, reduce=False,
process_group=process_group,
) )
self.wpe = TensorParallelEmbedding( self.wpe = TensorParallelEmbedding(
config.max_position_embeddings, prefix="transformer.wpe",
config.hidden_size, weights=weights,
reduce=False, reduce=False,
process_group=process_group,
) )
else:
self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.h = nn.ModuleList( self.h = nn.ModuleList(
[ [
Block( Block(
config.num_attention_heads, layer_id,
config.activation_function, config,
config.hidden_size, weights,
config.n_inner
if config.n_inner is not None
else 4 * config.hidden_size,
config.layer_norm_epsilon,
process_group,
) )
for _ in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_f = FastLayerNorm.load(
prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
)
self.head_size = self.h[0].attn.head_size self.head_size = self.h[0].attn.head_size
self.num_heads = self.h[0].attn.num_heads self.num_heads = self.h[0].attn.num_heads
def post_load_weights(self, quantize: Optional[str] = None):
if self.tp_embeddings:
self.wte.add_null_idx()
self.wpe.add_null_idx()
for layer in self.h:
layer: Block
layer.attn.c_attn.prepare_weights(quantize)
layer.attn.c_proj.prepare_weights(quantize)
layer.mlp.c_fc.prepare_weights(quantize)
layer.mlp.c_proj.prepare_weights(quantize)
def forward( def forward(
self, self,
input_ids, input_ids,
@ -281,7 +346,6 @@ class FlashSantacoderModel(nn.Module):
pre_allocate_past_size: Optional[int] = None, pre_allocate_past_size: Optional[int] = None,
): ):
hidden_states = self.wte(input_ids) + self.wpe(position_ids) hidden_states = self.wte(input_ids) + self.wpe(position_ids)
if self.tp_embeddings:
torch.distributed.all_reduce(hidden_states, group=self.process_group) torch.distributed.all_reduce(hidden_states, group=self.process_group)
# Prefill # Prefill
@ -331,23 +395,12 @@ class FlashSantacoderModel(nn.Module):
class FlashSantacoderForCausalLM(nn.Module): class FlashSantacoderForCausalLM(nn.Module):
def __init__(self, config, process_group=None): def __init__(self, config, weights):
super().__init__() super().__init__()
self.transformer = FlashSantacoderModel(config, weights)
self.transformer = FlashSantacoderModel(config, process_group) self.lm_head = TensorParallelHead.load(
config, prefix="transformer.wte", weights=weights
if self.transformer.tp_embeddings:
self.lm_head = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
) )
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self, quantize: Optional[str] = None):
self.transformer.post_load_weights(quantize)
self.lm_head.prepare_weights()
def forward( def forward(
self, self,
@ -372,29 +425,4 @@ class FlashSantacoderForCausalLM(nn.Module):
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
if self.transformer.tp_embeddings:
# Logits are sharded, so we need to gather them
if logits.shape[0] == 1:
# Fast path when batch size is 1
world_logits = logits.new_empty(
(logits.shape[1] * self.transformer.tp_world_size)
)
torch.distributed.all_gather_into_tensor(
world_logits, logits.view(-1), group=self.transformer.process_group
)
world_logits = world_logits.view(1, -1)
else:
# We cannot use all_gather_into_tensor as it only support concatenating on the first dim
world_logits = [
torch.empty_like(logits)
for _ in range(self.transformer.tp_world_size)
]
torch.distributed.all_gather(
world_logits, logits, group=self.transformer.process_group
)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present return logits, present

View File

@ -0,0 +1,794 @@
# coding=utf-8
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch GPTNeoX model."""
from typing import Optional, Tuple, Union
import os
import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers import GPTNeoXConfig
from loguru import logger
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelHead,
)
CUSTOM_KERNELS_ENABLED = False
if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
try:
from custom_kernels import fused_attention_cuda
CUSTOM_KERNELS_ENABLED = True
except ImportError:
pass
if not CUSTOM_KERNELS_ENABLED:
logger.warning("We're not using custom kernels.")
def make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
"""
Make causal mask used for self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.ones(
(target_length, target_length + past_key_values_length),
dtype=torch.bool,
device=device,
)
mask = mask.triu(1 + past_key_values_length)
expanded_mask = mask.unsqueeze(0).expand(
batch_size, target_length, target_length + past_key_values_length
)
return expanded_mask
def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
"""
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
"""
batch_size, src_length = mask.shape
tgt_length = tgt_length if tgt_length is not None else src_length
expanded_mask = ~(mask[:, None, :].to(torch.bool))
return expanded_mask.expand(batch_size, tgt_length, src_length)
def prepare_attn_mask(
attention_mask: torch.Tensor,
input_shape: Tuple[int, int],
past_key_values_length: int,
) -> torch.BoolTensor:
# create causal mask
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape
if src_length > 1:
combined_attention_mask = make_causal_mask(
input_shape, device=device, past_key_values_length=past_key_values_length
)
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask | combined_attention_mask
)
return combined_attention_mask
class GPTNeoXPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
class GPTNeoXAttention(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct)
max_positions = config.max_position_embeddings
# ??? TODO
# self.register_buffer(
# "bias",
# torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
# 1, 1, max_positions, max_positions
# ),
# )
# self.register_buffer("masked_bias", torch.tensor(-1e9))
self.rotary_emb = RotaryEmbedding(
self.rotary_ndims,
config.max_position_embeddings,
base=config.rotary_emb_base,
)
self.rotary_emb.inv_freq = nn.Parameter(
weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
)
self.inv_norm_factor = 1.0 / torch.sqrt(
torch.tensor(self.head_size, dtype=torch.float32)
).to(torch.get_default_dtype())
assert self.num_attention_heads % weights.process_group.size() == 0
self.num_attention_heads = (
self.num_attention_heads // weights.process_group.size()
)
self.query_key_value = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True
)
self.dense = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
def forward(
self,
hidden_states,
position_ids,
attention_mask,
head_mask=None,
layer_past=None,
use_cache=False,
output_attentions=False,
):
has_layer_past = layer_past is not None
# Compute QKV
# Attention heads [batch, seq_len, hidden_size]
# --> [batch, seq_len, (np * 3 * head_size)]
qkv = self.query_key_value(hidden_states)
# [batch, seq_len, (num_heads * 3 * head_size)]
# --> [batch, seq_len, num_heads, 3 * head_size]
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
qkv = qkv.view(*new_qkv_shape).permute(0, 2, 1, 3)
# [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
query, key, value = qkv.split(self.head_size, -1)
# Compute token offset for rotary embeddings (when decoding)
seq_len = key.shape[-2]
if has_layer_past:
seq_len += layer_past[0].shape[-2]
# Compute rotary embeddings on rotary_ndims
query_rot = query[..., : self.rotary_ndims]
key_rot = key[..., : self.rotary_ndims]
query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len)
query[..., : self.rotary_ndims] = query_rot
key[..., : self.rotary_ndims] = key_rot
if CUSTOM_KERNELS_ENABLED:
attn_output, present, attn_weights = fused_attention_cuda.forward(
query,
key,
value,
layer_past,
attention_mask,
head_mask,
self.inv_norm_factor,
self.num_attention_heads,
use_cache,
)
else:
# Cache QKV values
if has_layer_past:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = (key, value) if use_cache else None
# Compute attention
attn_output, attn_weights = self._attn(
query, key, value, attention_mask, head_mask
)
# Reshape outputs
attn_output = self._merge_heads(
attn_output, self.num_attention_heads, self.head_size
)
attn_output = self.dense(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs
@classmethod
def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
# tensor: [bs, seq_len, hidden_size]
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(new_shape)
# -> [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3)
return tensor
@classmethod
def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
# tensor [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3).contiguous()
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(
tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
)
# -> [bs, seq_len, hidden_size]
return tensor
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
# compute causal mask from causal mask buffer
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)
query = query.view(
batch_size * num_attention_heads, query_length, attn_head_size
)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
attn_scores = torch.zeros(
1,
dtype=query.dtype,
device=key.device,
).expand(batch_size * num_attention_heads, query_length, key_length)
attn_scores = torch.baddbmm(
attn_scores,
query,
key.transpose(1, 2),
beta=1.0,
alpha=self.inv_norm_factor,
)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attn_scores.dtype
if input_dtype in [torch.float16, torch.bfloat16]:
attn_scores = attn_scores.to(torch.float)
attn_scores = torch.where(
attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores
)
attn_scores = attn_scores.view(
batch_size, num_attention_heads, query_length, key_length
)
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
attn_weights = attn_weights.to(value.dtype)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
super().__init__()
self.true_inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2).float().to(device) / dim)
)
self.register_buffer("inv_freq", self.true_inv_freq)
# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
self.cos_cached = None
self.sin_cached = None
@staticmethod
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
@staticmethod
def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):
t = torch.arange(
max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype
)
freqs = torch.einsum("i,j->ij", t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype)
def forward(self, q, k, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if (
seq_len > self.max_seq_len_cached
or self.cos_cached is None
or self.sin_cached is None
):
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
self.cos_cached, self.sin_cached = self._create_cos_sin(
self.true_inv_freq, self.max_seq_len_cached, q.dtype, q.device
)
return rotary_forward(q, k, self.cos_cached, self.sin_cached, position_ids)
@torch.jit.script
def rotary_forward(q, k, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
chunk_size = q.shape[-1] // 2
q1, q2 = q.split(chunk_size, -1)
q_rotated = torch.cat((-q2, q1), dim=-1)
k1, k2 = k.split(chunk_size, -1)
k_rotated = torch.cat((-k2, k1), dim=-1)
q_embed = (q * cos) + (q_rotated * sin)
k_embed = (k * cos) + (k_rotated * sin)
return q_embed, k_embed
class GPTNeoXMLP(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.act = (
ACT2FN[config.hidden_act]
if "gelu_fast" not in config.hidden_act
else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
)
self.dense_h_to_4h = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
)
self.dense_4h_to_h = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
)
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dense_4h_to_h(hidden_states)
return hidden_states
class GPTNeoXLayer(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.input_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.attention = GPTNeoXAttention(
config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights
)
self.mlp = GPTNeoXMLP(
config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights
)
def forward(
self,
hidden_states,
position_ids,
attention_mask=None,
head_mask=None,
use_cache=False,
layer_past=None,
output_attentions=False,
):
attention_layer_outputs = self.attention(
self.input_layernorm(hidden_states),
attention_mask=attention_mask,
position_ids=position_ids,
layer_past=layer_past,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attention_layer_outputs[
0
] # output_attn: attn_output, present, (attn_weights)
outputs = attention_layer_outputs[1:]
if self.use_parallel_residual:
# pseudocode:
# x = x + attn(ln1(x)) + mlp(ln2(x))
mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
hidden_states = mlp_output + attn_output + hidden_states
else:
# pseudocode:
# x = x + attn(ln1(x))
# x = x + mlp(ln2(x))
attn_output = attn_output + hidden_states
mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
hidden_states = mlp_output + attn_output
if use_cache:
outputs = (
hidden_states,
) + outputs # hidden_states, present, (attn_weights)
else:
outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights)
return outputs
class GPTNeoXModel(GPTNeoXPreTrainedModel):
def __init__(self, config, weights):
super().__init__(config)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.embed_in = TensorParallelEmbedding(
prefix="gpt_neox.embed_in", weights=weights
)
self.layers = nn.ModuleList(
[
GPTNeoXLayer(layer_id, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.final_layer_norm = nn.LayerNorm.load(
prefix="gpt_neox.final_layer_norm",
weights=weights,
eps=config.layer_norm_eps,
)
self.tp_world_size = weights.process_group.size()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids=None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * self.config.num_hidden_layers)
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_length, seq_length + past_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_in(input_ids)
hidden_states = inputs_embeds
# Attention mask.
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[-1]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), device=hidden_states.device
)
else:
attention_mask = attention_mask.to(hidden_states.device)
causal_mask = prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
assert self.num_attention_heads % self.tp_world_size == 0
block_size = self.num_attention_heads // self.tp_world_size
causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
presents = () if use_cache else None
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = layer(
hidden_states,
position_ids=position_ids,
attention_mask=causal_mask,
head_mask=head_mask[i],
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
hidden_states = self.final_layer_norm(hidden_states)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [hidden_states, presents, all_hidden_states, all_attentions]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config, weights):
super().__init__(config)
self.gpt_neox = GPTNeoXModel(config, weights)
self.embed_out = TensorParallelHead.load(
config, prefix="embed_out", weights=weights
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
only required when the model is used as a decoder in a Sequence to Sequence model.
Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see
`past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
>>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b")
>>> config.is_decoder = True
>>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
```"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.gpt_neox(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
lm_logits = self.embed_out(hidden_states)
lm_loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# we are doing next-token prediction; shift prediction scores and input ids by one
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutputWithPast(
loss=lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs,
):
input_shape = input_ids.shape
# cut decoder_input_ids if past is used
if past_key_values and past_key_values[0] is not None:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"position_ids": position_ids,
}
)
return model_inputs
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx)
for past_state in layer_past[:2]
)
+ layer_past[2:],
)
return reordered_past

View File

@ -0,0 +1,837 @@
# coding=utf-8
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch OPT model."""
import random
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers import OPTConfig
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelHead,
)
EPS = 1e-5
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full(
(tgt_len, tgt_len),
torch.tensor(torch.finfo(dtype).min, device=device),
device=device,
)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
class OPTLearnedPositionalEmbedding(nn.Module):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, weights):
super().__init__()
self.offset = 2
self.weight = nn.Parameter(
weights.get_tensor("model.decoder.embed_positions.weight")
)
def forward(
self, attention_mask: torch.LongTensor, past_key_values_length: int = 0
):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()
# create positions depending on attention_mask
positions = (
torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
).long() - 1
# cut positions if `past_key_values_length` is > 0
positions = positions[:, past_key_values_length:]
return torch.nn.functional.embedding(positions + self.offset, self.weight)
class OPTAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config,
prefix,
weights,
is_decoder: bool = False,
bias: bool = True,
process_group=None,
):
super().__init__()
embed_dim = config.embed_dim
num_heads = config.num_attention_heads
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = config.dropout
self.head_dim = embed_dim // num_heads
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
process_group = weights.process_group
assert self.num_heads % process_group.size() == 0
self.num_heads = self.num_heads // process_group.size()
self.embed_dim = self.embed_dim // process_group.size()
self.q_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias
)
self.k_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.k_proj", weights=weights, bias=bias
)
self.v_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.v_proj", weights=weights, bias=bias
)
self.out_proj = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.out_proj", weights=weights, bias=bias
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = (
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attention_mask
)
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
if attn_weights.dtype == torch.float16:
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(torch.float16)
else:
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights_reshaped.view(
bsz * self.num_heads, tgt_len, src_len
)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
class OPTDecoderLayer(nn.Module):
def __init__(self, layer_id: int, config: OPTConfig, weights):
super().__init__()
self.process_group = weights.process_group
self.embed_dim = config.hidden_size
prefix = f"model.decoder.layers.{layer_id}"
self.self_attn = OPTAttention(
config,
prefix=f"{prefix}.self_attn",
weights=weights,
is_decoder=True,
bias=config.enable_bias,
)
self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.self_attn_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.self_attn_layer_norm", weights=weights, eps=EPS
)
self.fc1 = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.fc1", weights=weights, bias=config.enable_bias
)
self.fc2 = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.fc2", weights=weights, bias=config.enable_bias
)
self.final_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.final_layer_norm", weights=weights, eps=EPS
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = (residual + hidden_states).view(hidden_states_shape)
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class OPTPreTrainedModel(PreTrainedModel):
config_class = OPTConfig
class OPTDecoder(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
self.embed_tokens = TensorParallelEmbedding(
prefix="model.decoder.embed_tokens", weights=weights
)
self.embed_positions = OPTLearnedPositionalEmbedding(weights)
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = FastLinear.load(
config, prefix="model.decoder.project_out", bias=False
)
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = FastLinear.load(
config, prefix="model.decoder.project_in", bias=False
)
else:
self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm.load(
prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS
)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList(
[
OPTDecoderLayer(layer_id, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
past_key_values_length = (
past_key_values[0][0].shape[2] if past_key_values is not None else 0
)
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
batch_size, mask_seq_length, device=inputs_embeds.device
)
causal_attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
if self.project_in is not None:
inputs_embeds = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class OPTModel(OPTPreTrainedModel):
def __init__(self, config: OPTConfig, weights):
super().__init__(config)
self.decoder = OPTDecoder(config, weights)
# Initialize weights and apply final processing
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return decoder_outputs
return BaseModelOutputWithPast(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
hidden_states=decoder_outputs.hidden_states,
attentions=decoder_outputs.attentions,
)
class OPTForCausalLM(OPTPreTrainedModel):
def __init__(self, config, weights):
super().__init__(config)
self.model = OPTModel(config, weights)
self.lm_head = TensorParallelHead.load(
config, prefix="model.decoder.embed_tokens", weights=weights
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0]).contiguous()
loss = None
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs,
):
if past_key_values:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx) for past_state in layer_past
),
)
return reordered_past

File diff suppressed because it is too large Load Diff

View File

@ -1,154 +1,25 @@
import torch import torch
import torch.distributed import torch.distributed
from accelerate import init_empty_weights
from opentelemetry import trace from opentelemetry import trace
from pathlib import Path
from safetensors import safe_open
from transformers import AutoConfig from transformers import AutoConfig
from transformers.models.llama import LlamaTokenizer from transformers.models.llama import LlamaTokenizer
from typing import Optional, List from typing import Optional
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import ( from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM, FlashLlamaForCausalLM,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
) )
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
download_weights, Weights,
weight_hub_files,
LocalEntryNotFoundError,
) )
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
class FlashLlama(FlashCausalLM): class FlashLlama(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
raise NotImplementedError("FlashLlama is only available on GPU")
tokenizer = LlamaTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
# We do not use from_pretrained as we modified the model internal module layout
try:
filenames = weight_files(model_id, revision, ".bin")
# Local files not found
except LocalEntryNotFoundError:
hub_files = weight_hub_files(model_id, revision, ".bin")
filenames = download_weights(hub_files, model_id, revision)
with init_empty_weights():
model = FlashLlamaForCausalLM(config)
self.load_weights(model, filenames, quantize, device, dtype)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
@staticmethod
def load_weights(
model,
filenames: List[Path],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
):
for filename in filenames:
state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items():
value = value.to(device if quantize is None else "cpu").to(dtype)
layer_name = ".".join(key.split(".")[:4])
# Fused qkv
if "q_proj" in key or "k_proj" in key or "v_proj" in key:
final_key = layer_name + ".query_key_value.weight"
# Fused gate and up projs
elif "gate_proj" in key or "up_proj" in key:
final_key = layer_name + ".gate_up_proj.weight"
else:
final_key = key
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "query_key_value" in final_key:
module._parameters[param_name] = value.new_empty(
(value.shape[0] * 3, value.shape[1])
)
# Init gate and up proj
elif "gate_up_proj" in final_key:
module._parameters[param_name] = value.new_empty(
(value.shape[0] * 2, value.shape[1])
)
# Copy to correct slice
if "q_proj" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "k_proj" in key:
module._parameters[param_name][
value.shape[0] : value.shape[0] * 2
] = value
elif "v_proj" in key:
module._parameters[param_name][value.shape[0] * 2 :] = value
elif "gate_proj" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "up_proj" in key:
module._parameters[param_name][value.shape[0] :] = value
else:
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
else:
module._buffers[param_name] = value
del value
torch.cuda.empty_cache()
model.post_load_weights(quantize)
class FlashLlamaSharded(FlashLlama):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
@ -176,24 +47,16 @@ class FlashLlamaSharded(FlashLlama):
) )
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
with init_empty_weights(): config.quantize = quantize
model = FlashLlamaForCausalLM(config, process_group=self.process_group) model = FlashLlamaForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
model=model.to(device), model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
requires_padding=False, requires_padding=False,
dtype=dtype, dtype=dtype,
@ -201,114 +64,3 @@ class FlashLlamaSharded(FlashLlama):
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
) )
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
slice_ = f.get_slice(name)
layer_name = ".".join(name.split(".")[:4])
# Fused qkv
if "q_proj" in name or "k_proj" in name or "v_proj" in name:
final_name = layer_name + ".query_key_value.weight"
# Fused gate and up projs
elif "gate_proj" in name or "up_proj" in name:
final_name = layer_name + ".gate_up_proj.weight"
else:
final_name = name
module_name, param_name = final_name.rsplit(".", 1)
module = model.get_submodule(module_name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "lm_head.weight" and model.model.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
tensor = tensor.contiguous().to(dtype)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "query_key_value" in final_name:
module._parameters[param_name] = tensor.new_empty(
(tensor.shape[0] * 3, tensor.shape[1])
)
# Init gate and up proj
elif "gate_up_proj" in final_name:
module._parameters[param_name] = tensor.new_empty(
(tensor.shape[0] * 2, tensor.shape[1])
)
# Init gate and up proj
if "q_proj" in name:
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "k_proj" in name:
module._parameters[param_name][
tensor.shape[0] : tensor.shape[0] * 2
] = tensor
elif "v_proj" in name:
module._parameters[param_name][
tensor.shape[0] * 2 :
] = tensor
elif "gate_proj" in name:
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "up_proj" in name:
module._parameters[param_name][tensor.shape[0] :] = tensor
else:
if current_parameter_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
torch.cuda.empty_cache()
model.post_load_weights(quantize)

View File

@ -1,45 +1,24 @@
import torch import torch
import torch.distributed import torch.distributed
from accelerate import init_empty_weights
from opentelemetry import trace from opentelemetry import trace
from safetensors import safe_open
from transformers import AutoTokenizer, AutoConfig from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List from typing import Optional
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_neox_modeling import ( from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM, FlashGPTNeoXForCausalLM,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
) )
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights,
) )
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
class FlashNeoX(FlashCausalLM): class FlashNeoXSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
super(FlashNeoX, self).__init__(
FlashGPTNeoXForCausalLM,
model_id,
revision,
quantize,
trust_remote_code=trust_remote_code,
)
class FlashNeoXSharded(FlashNeoX):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
@ -65,23 +44,16 @@ class FlashNeoXSharded(FlashNeoX):
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
with init_empty_weights(): filenames, device=device, dtype=dtype, process_group=self.process_group
model = FlashGPTNeoXForCausalLM(config, self.process_group)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
) )
model = FlashGPTNeoXForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
model=model.to(device), model=model.to(device),
@ -92,79 +64,3 @@ class FlashNeoXSharded(FlashNeoX):
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
) )
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
model.post_load_weights(quantize)

View File

@ -1,119 +1,25 @@
import torch import torch
import torch.distributed import torch.distributed
from pathlib import Path
from accelerate import init_empty_weights
from opentelemetry import trace from opentelemetry import trace
from safetensors import safe_open from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoConfig from typing import Optional
from typing import Optional, List
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_rw_modeling import ( from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig, RWConfig,
FlashRWForCausalLM, FlashRWForCausalLM,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
) )
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
download_weights, Weights,
weight_hub_files,
LocalEntryNotFoundError,
) )
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
class FlashRW(FlashCausalLM): class FlashRWSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
raise NotImplementedError("RW is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = RWConfig.from_pretrained(
model_id,
revision=revision,
)
# We do not use from_pretrained as it is too slow
try:
filenames = weight_files(model_id, revision, ".bin")
# Local files not found
except LocalEntryNotFoundError:
hub_files = weight_hub_files(model_id, revision, ".bin")
filenames = download_weights(hub_files, model_id, revision)
with init_empty_weights():
model = FlashRWForCausalLM(config)
self.load_weights(
model,
filenames,
quantize,
device,
dtype,
)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
@staticmethod
def load_weights(
model: FlashRWForCausalLM,
filenames: List[Path],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
):
for filename in filenames:
state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items():
value = value.to(device if quantize is None else "cpu").to(dtype)
module_name, param_name = key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
except KeyError:
module._buffers[param_name] = value
del value
torch.cuda.empty_cache()
model.post_load_weights(quantize)
class FlashRWSharded(FlashRW):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
@ -142,20 +48,12 @@ class FlashRWSharded(FlashRW):
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
with init_empty_weights(): config.quantize = quantize
model = FlashRWForCausalLM(config, self.process_group)
model = FlashRWForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
model=model.to(device), model=model.to(device),
@ -166,79 +64,3 @@ class FlashRWSharded(FlashRW):
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
) )
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "lm_head.weight" and model.transformer.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
model.post_load_weights(quantize)

View File

@ -1,197 +1,24 @@
import torch import torch
import torch.distributed import torch.distributed
from accelerate import init_empty_weights
from opentelemetry import trace from opentelemetry import trace
from safetensors import safe_open from transformers import AutoTokenizer, AutoConfig
from pathlib import Path
from transformers import AutoTokenizer, GPT2Config
from typing import Optional, List from typing import Optional, List
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM, FlashSantacoderForCausalLM,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
) )
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
download_weights, Weights,
weight_hub_files,
LocalEntryNotFoundError,
) )
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
class FlashSantacoder(FlashCausalLM): class FlashSantacoderSharded(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
raise NotImplementedError("FlashSantacoder is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = GPT2Config.from_pretrained(
model_id,
revision=revision,
)
# We do not use from_pretrained as we modified the model internal module layout
filenames = weight_files(model_id, revision, ".safetensors")
with init_empty_weights():
model = FlashSantacoderForCausalLM(config)
self.load_weights(
model,
filenames,
quantize,
device,
dtype,
config.architectures[0].startswith("GPT2"),
)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
@staticmethod
def load_weights(
model: FlashSantacoderForCausalLM,
filenames: List[Path],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
transpose: bool,
):
for filename in filenames:
with safe_open(
filename,
framework="pt",
device=str(device) if quantize is None else "cpu",
) as f:
for key in f.keys():
value = f.get_tensor(key)
value = value.to(device if quantize is None else "cpu").to(dtype)
layer_name = ".".join(key.split(".")[:4])
# Fused qkv
if "q_attn.weight" in key or "kv_attn.weight" in key:
final_key = layer_name + ".c_attn.weight"
elif "q_attn.bias" in key or "kv_attn.bias" in key:
final_key = layer_name + ".c_attn.bias"
else:
final_key = key
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if transpose and (
"c_fc.weight" in key
or "c_proj.weight" in key
or "q_attn.weight" in key
or "kv_attn.weight" in key
or "c_attn.weight" in key
):
# Tranpose as we use nn.Linear instead of Conv1D
value = value.T
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "c_attn.weight" in final_key:
module._parameters[param_name] = value.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2),
value.shape[1],
)
)
elif "c_attn.bias" in final_key:
module._parameters[param_name] = value.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2)
)
)
# Copy to correct slice
if "q_attn.weight" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "q_attn.bias" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "kv_attn.weight" in key:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = value
elif "kv_attn.bias" in key:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = value
else:
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
else:
module._buffers[param_name] = value
del value
if model.lm_head.weight.device == torch.device("meta"):
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
torch.cuda.empty_cache()
model.post_load_weights(quantize)
uninitialized_parameters = []
for n, p in model.named_parameters():
if p.data.device == torch.device("meta"):
uninitialized_parameters.append(n)
if uninitialized_parameters:
raise RuntimeError(
f"found uninitialized parameters in model : {uninitialized_parameters}"
)
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
class FlashSantacoderSharded(FlashSantacoder):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
@ -214,28 +41,22 @@ class FlashSantacoderSharded(FlashSantacoder):
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
config = GPT2Config.from_pretrained( config = AutoConfig.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
trust_remote_code=True,
) )
config.quantize = quantize
config.transpose = config.architectures[0].startswith("GPT2")
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
with init_empty_weights(): filenames, device=device, dtype=dtype, process_group=self.process_group
model = FlashSantacoderForCausalLM(config, self.process_group)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
transpose=config.architectures[0].startswith("GPT2"),
) )
model = FlashSantacoderForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
model=model.to(device), model=model.to(device),
@ -247,164 +68,8 @@ class FlashSantacoderSharded(FlashSantacoder):
world_size=world_size, world_size=world_size,
) )
@staticmethod def decode(self, generated_ids: List[int]) -> str:
def load_weights( # Do not skip special tokens as they are used for custom parsing rules of the generated text
model, return self.tokenizer.decode(
filenames: List[str], generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
transpose: bool,
):
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for key in f.keys():
slice_ = f.get_slice(key)
layer_name = ".".join(key.split(".")[:4])
# Fused qkv
if "q_attn.weight" in key or "kv_attn.weight" in key:
final_key = layer_name + ".c_attn.weight"
elif "q_attn.bias" in key or "kv_attn.bias" in key:
final_key = layer_name + ".c_attn.bias"
else:
final_key = key
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
if isinstance(module, TensorParallelColumnLinear):
dim = 1 if transpose and "weight" in param_name else 0
size = slice_.get_shape()[dim]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = (
slice_[start:stop] if dim == 0 else slice_[:, start:stop]
) )
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
dim = 0 if transpose else 1
size = slice_.get_shape()[dim]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = (
slice_[start:stop]
if dim == 0
else slice_[:, start:stop]
)
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif key == "lm_head.weight" and model.transformer.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(key)
tensor = tensor.contiguous().to(dtype)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if transpose and (
"c_fc.weight" in key
or "c_proj.weight" in key
or "q_attn.weight" in key
or "kv_attn.weight" in key
or "c_attn.weight" in key
):
# Tranpose as we use nn.Linear instead of Conv1D
tensor = tensor.T
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "c_attn.weight" in final_key:
module._parameters[param_name] = tensor.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2),
tensor.shape[1],
)
)
elif "c_attn.bias" in final_key:
module._parameters[param_name] = tensor.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2)
)
)
# Copy to correct slice
if "q_attn" in key:
size = tensor.shape[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = tensor[start:stop]
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "kv_attn.weight" in key:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = tensor
elif "kv_attn.bias" in key:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = tensor
elif "c_attn" in key:
# Slice q_tensor by shard
q_tensor = tensor[: -2 * model.transformer.head_size]
block_size = q_tensor.shape[0] // world_size
start = rank * block_size
stop = (rank + 1) * block_size
q_tensor = q_tensor[start:stop]
module._parameters[param_name][
: q_tensor.shape[0]
] = q_tensor
# Kv tensor is copied for every shard
kv_tensor = tensor[-2 * model.transformer.head_size :]
module._parameters[param_name][
q_tensor.shape[0] :
] = kv_tensor
else:
if current_parameter_tensor.shape != tensor.shape:
raise ValueError(
f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
if model.lm_head.weight.device == torch.device("meta"):
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
torch.cuda.empty_cache()
model.post_load_weights(quantize)

View File

@ -2,41 +2,25 @@ import re
import torch import torch
import torch.distributed import torch.distributed
from typing import List, Optional, Type, Tuple from typing import List, Optional, Type
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
AutoModelForCausalLM,
AutoConfig, AutoConfig,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from transformers.models.opt.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.models import CausalLM from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.opt import OPT from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
from text_generation_server.utils import ( from text_generation_server.utils import (
NextTokenChooser, NextTokenChooser,
StoppingCriteria, StoppingCriteria,
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights,
) )
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py # CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
# we split individual characters inside special tokens like [START_DNA] # we split individual characters inside special tokens like [START_DNA]
@ -168,33 +152,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
) )
class Galactica(OPT): class GalacticaSharded(CausalLM):
@property
def batch_type(self) -> Type[CausalLMBatch]:
return GalacticaCausalLMBatch
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""Overwrite forward to ignore position_ids"""
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, outputs.past_key_values
class GalacticaSharded(Galactica):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
@ -224,26 +182,17 @@ class GalacticaSharded(Galactica):
tp_parallel=True, tp_parallel=True,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
config.quantize = quantize
tokenizer.pad_token_id = config.pad_token_id tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
with init_empty_weights(): filenames, device=device, dtype=dtype, process_group=self.process_group
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
) )
torch.distributed.barrier(group=self.process_group) model = OPTForCausalLM(config, weights)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
model=model, model=model,
@ -255,128 +204,16 @@ class GalacticaSharded(Galactica):
world_size=world_size, world_size=world_size,
) )
@staticmethod @property
def load_weights( def batch_type(self) -> Type[CausalLMBatch]:
model, return GalacticaCausalLMBatch
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
if name == "lm_head.weight":
continue
module_name, param_name = name.rsplit(".", 1) def decode(self, generated_ids: List[int]) -> str:
module = model.get_submodule(module_name) # Do not skip special tokens as they are used for custom parsing rules of the generated text
current_tensor = parameters[name] return self.tokenizer.decode(
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
tensor = slice_[:]
if current_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
) )
tensor = tensor.contiguous().to(dtype)
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "model.decoder.embed_tokens.weight":
model.lm_head._parameters["weight"] = tensor
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
): ):
@ -386,10 +223,4 @@ class GalacticaSharded(Galactica):
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=True, use_cache=True,
) )
return outputs.logits, outputs.past_key_values
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
return logits, outputs.past_key_values

View File

@ -1,34 +1,22 @@
import torch import torch
import torch.distributed import torch.distributed
from typing import List, Optional from typing import Optional
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
AutoModelForCausalLM,
AutoConfig, AutoConfig,
) )
from transformers.models.gpt_neox.parallel_layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.models import CausalLM from text_generation_server.models import CausalLM
from text_generation_server.models.custom_modeling.neox_modeling import (
GPTNeoxForCausalLM,
)
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights,
) )
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
class GPTNeoxSharded(CausalLM): class GPTNeoxSharded(CausalLM):
def __init__( def __init__(
@ -58,28 +46,18 @@ class GPTNeoxSharded(CausalLM):
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
config.quantize = quantize
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
with init_empty_weights(): filenames, device=device, dtype=dtype, process_group=self.process_group
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
) )
torch.distributed.barrier(group=self.process_group) model = GPTNeoxForCausalLM(config, weights)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
model=model, model=model,
@ -91,143 +69,9 @@ class GPTNeoxSharded(CausalLM):
world_size=world_size, world_size=world_size,
) )
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
): ):
if self.model.gpt_neox.tp_embeddings:
outputs = self.model.forward( outputs = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
@ -236,16 +80,5 @@ class GPTNeoxSharded(CausalLM):
use_cache=True, use_cache=True,
) )
# Logits are sharded, so we need to gather them logits = outputs.logits
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(
logits, outputs.logits, group=self.process_group
)
logits = torch.cat(logits, dim=2)
return logits, outputs.past_key_values return logits, outputs.past_key_values
# While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
else:
return super(GPTNeoxSharded, self).forward(
input_ids, attention_mask, position_ids, past_key_values
)

View File

@ -1,52 +1,22 @@
import torch import torch
import torch.distributed import torch.distributed
from typing import List, Optional, Tuple from typing import Optional
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
AutoModelForCausalLM,
AutoConfig, AutoConfig,
) )
from transformers.models.opt.parallel_layers import ( from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.models import CausalLM from text_generation_server.models import CausalLM
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights,
) )
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
HAS_BITS_AND_BYTES = False
class OPTSharded(CausalLM):
class OPT(CausalLM):
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""Overwrite forward to ignore position_ids"""
# Model Forward
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
return outputs.logits, outputs.past_key_values
class OPTSharded(OPT):
def __init__( def __init__(
self, self,
model_id: str, model_id: str,
@ -73,29 +43,19 @@ class OPTSharded(OPT):
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
config.quantize = quantize
tokenizer.pad_token_id = config.pad_token_id tokenizer.pad_token_id = config.pad_token_id
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
with init_empty_weights(): filenames, device=device, dtype=dtype, process_group=self.process_group
model = AutoModelForCausalLM.from_config(
config, trust_remote_code=trust_remote_code
) )
torch.distributed.barrier(group=self.process_group) model = OPTForCausalLM(config, weights)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
model=model, model=model,
@ -107,128 +67,6 @@ class OPTSharded(OPT):
world_size=world_size, world_size=world_size,
) )
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
if name == "lm_head.weight":
continue
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_tensor = parameters[name]
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
tensor = slice_[:]
if current_tensor.shape != tensor.shape:
raise ValueError(
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
module._parameters[param_name] = tensor
if name == "model.decoder.embed_tokens.weight":
model.lm_head._parameters["weight"] = tensor
def forward( def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
): ):
@ -239,9 +77,4 @@ class OPTSharded(OPT):
use_cache=True, use_cache=True,
) )
# Logits are sharded, so we need to gather them return outputs.logits, outputs.past_key_values
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
return logits, outputs.past_key_values

View File

@ -3,31 +3,20 @@ import torch.distributed
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
AutoModelForSeq2SeqLM,
AutoConfig, AutoConfig,
) )
from text_generation_server.models import Seq2SeqLM from text_generation_server.models import Seq2SeqLM
from text_generation_server.models.custom_modeling.t5_modeling import (
T5ForConditionalGeneration,
)
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights,
) )
from transformers.models.t5.parallel_layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
)
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except ImportError as e:
HAS_BITS_AND_BYTES = False
class T5Sharded(Seq2SeqLM): class T5Sharded(Seq2SeqLM):
@ -46,6 +35,13 @@ class T5Sharded(Seq2SeqLM):
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.float32 dtype = torch.float32
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
config.quantize = quantize
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
@ -53,33 +49,16 @@ class T5Sharded(Seq2SeqLM):
truncation_side="left", truncation_side="left",
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
tp_parallel=True,
trust_remote_code=trust_remote_code,
)
tokenizer.bos_token_id = config.decoder_start_token_id tokenizer.bos_token_id = config.decoder_start_token_id
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
with init_empty_weights(): filenames, device=device, dtype=dtype, process_group=self.process_group
model = AutoModelForSeq2SeqLM.from_config(
config, trust_remote_code=trust_remote_code
) )
torch.distributed.barrier(group=self.process_group) model = T5ForConditionalGeneration(config, weights)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(Seq2SeqLM, self).__init__( super(Seq2SeqLM, self).__init__(
model=model, model=model,
@ -91,151 +70,6 @@ class T5Sharded(Seq2SeqLM):
world_size=world_size, world_size=world_size,
) )
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "lm_head.weight":
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif "relative_attention_bias.weight" in name:
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous()
# See: https://github.com/huggingface/transformers/blob/1fe1e3caa44617047f149bcc0c0b566343b714a7/src/transformers/models/t5/modeling_t5.py#LL316C15-L316C71
if module_name.endswith("wo"):
tensor = tensor.to(torch.float32)
else:
tensor = tensor.to(dtype)
if quantize == "bitsandbytes" and not module_name.endswith("wo"):
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
has_fp16_weights=False,
requires_grad=False,
).to(device)
state = bnb.MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = False
state.memory_efficient_backward = False
state.use_pool = True
state.CB = tensor.CB
state.SCB = tensor.SCB
tensor.CB = None
tensor.SCB = None
def replace_linear(state):
def linear(input, weight, bias):
out = bnb.matmul(
input,
weight,
state=state,
threshold=state.threshold,
bias=bias,
)
if state.CB is not None:
# we converted 8-bit row major to turing/ampere format
# in the first inference pass
# we no longer need the row-major weight
del state.CB
weight.data = state.CxB
return out
return linear
module.linear = replace_linear(state)
else:
tensor = tensor.to(device)
elif quantize == "gptq" and not module_name.endswith("wo"):
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None or module_name.endswith("wo"):
tensor = tensor.to(device)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
def forward( def forward(
self, self,
input_ids, input_ids,
@ -260,13 +94,8 @@ class T5Sharded(Seq2SeqLM):
use_cache=True, use_cache=True,
) )
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
return ( return (
logits, outputs.logits,
outputs.encoder_last_hidden_state, outputs.encoder_last_hidden_state,
outputs.past_key_values, outputs.past_key_values,
) )

View File

@ -1,5 +1,6 @@
from text_generation_server.utils.convert import convert_file, convert_files from text_generation_server.utils.convert import convert_file, convert_files
from text_generation_server.utils.dist import initialize_torch_distributed from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation_server.utils.weights import Weights
from text_generation_server.utils.hub import ( from text_generation_server.utils.hub import (
weight_files, weight_files,
weight_hub_files, weight_hub_files,
@ -35,4 +36,5 @@ __all__ = [
"StoppingCriteria", "StoppingCriteria",
"StopSequenceCriteria", "StopSequenceCriteria",
"FinishReason", "FinishReason",
"Weights",
] ]

View File

@ -4,6 +4,37 @@ import torch
from datetime import timedelta from datetime import timedelta
class FakeBarrier:
def wait(self):
pass
class FakeGroup:
def __init__(self, rank, size):
self._rank = rank
self._size = size
def allreduce(self, *args, **kwargs):
return FakeBarrier()
def allgather(self, inputs, local_tensor, **kwargs):
assert (
len(inputs[0]) == len(local_tensor) == 1
), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
for input_ in inputs:
input_[0].data = local_tensor[0].data
return FakeBarrier()
def barrier(self, *args, **kwargs):
return FakeBarrier()
def size(self):
return self._size
def rank(self):
return self._rank
def initialize_torch_distributed(): def initialize_torch_distributed():
rank = int(os.getenv("RANK", "0")) rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1")) world_size = int(os.getenv("WORLD_SIZE", "1"))
@ -23,6 +54,11 @@ def initialize_torch_distributed():
backend = "gloo" backend = "gloo"
options = None options = None
if world_size == 1:
return FakeGroup(rank, world_size), rank, world_size
else:
if os.getenv("DEBUG", None) == "1":
return FakeGroup(rank, world_size), rank, world_size
# Call the init process. # Call the init process.
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=backend, backend=backend,

View File

@ -1,176 +1,240 @@
import torch import torch
import torch.distributed
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from typing import Optional from typing import List
HAS_BITS_AND_BYTES = True HAS_BITS_AND_BYTES = True
try: try:
from bitsandbytes.nn import Linear8bitLt import bitsandbytes as bnb
except ImportError as e: from bitsandbytes.nn import Int8Params
except ImportError:
HAS_BITS_AND_BYTES = False HAS_BITS_AND_BYTES = False
from accelerate import init_empty_weights
class FastLinear(nn.Linear):
# Monkey patching
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
weight = weights.get_tensor(f"{prefix}.weight")
bias = weights.get_tensor(f"{prefix}.bias")
with init_empty_weights():
ln = cls(weight.shape, eps=eps)
ln.weight = nn.Parameter(weight)
ln.bias = nn.Parameter(bias)
return ln
torch.nn.LayerNorm.load = load_layer_norm
class FastLinear(nn.Module):
def __init__( def __init__(
self, self,
in_features: int, weight,
out_features: int, bias,
bias: bool = True,
device=None,
dtype=None,
) -> None: ) -> None:
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) super().__init__()
self.quantized = False self.weight = nn.Parameter(weight)
self.bnb_linear = None if bias is not None:
self.bias = nn.Parameter(bias)
def prepare_weights(self, quantize: Optional[str] = None): else:
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
self.quantized = True
self.bnb_linear = Linear8bitLt(
self.in_features,
self.out_features,
has_fp16_weights=False,
threshold=6.0,
bias=False,
)
# Copy data to bnb_linear
self.bnb_linear.weight.data = self.weight.data
if self.bias is not None:
self.bnb_linear.bias = nn.Parameter(self.bias)
# Delete reference to data
self.weight = None
self.bias = None self.bias = None
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now") @classmethod
elif quantize is None: def load(cls, config, prefix: str, weights, bias: bool):
self.weight = nn.Parameter(self.weight.T) weight = weights.get_tensor(f"{prefix}.weight")
if bias:
bias = weights.get_tensor(f"{prefix}.bias")
else: else:
raise ValueError(f"Unexpected quantize `{quantize}`") bias = None
return cls(weight, bias)
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.quantized: return F.linear(input, self.weight, self.bias)
return self.bnb_linear(input)
else:
if self.bias is not None:
return torch.addmm(self.bias, input, self.weight)
return torch.matmul(input, self.weight)
class TensorParallelColumnLinear(FastLinear): class Linear8bitLt(nn.Module):
def __init__( def __init__(
self, self,
in_features, weight,
out_features, bias,
process_group: torch.distributed.ProcessGroup, has_fp16_weights=True,
bias=True, memory_efficient_backward=False,
device=None, threshold=0.0,
dtype=None, index=None,
): ):
self.process_group = process_group super().__init__()
self.tp_world_size = process_group.size() assert (
assert out_features % self.tp_world_size == 0 not memory_efficient_backward
out_features = out_features // self.tp_world_size ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
self.state = bnb.MatmulLtState()
self.index = index
super().__init__( # Necessary for stacked layers
in_features=in_features, self.state.threshold = threshold
out_features=out_features, self.state.has_fp16_weights = has_fp16_weights
bias=bias, self.state.memory_efficient_backward = memory_efficient_backward
device=device, if threshold > 0.0 and not has_fp16_weights:
dtype=dtype, self.state.use_pool = True
self.weight = Int8Params(
weight.data,
has_fp16_weights=has_fp16_weights,
requires_grad=has_fp16_weights,
) )
self.weight.cuda(weight.device)
self.bias = bias
def init_8bit_state(self):
self.state.CB = self.weight.CB
self.state.SCB = self.weight.SCB
self.weight.CB = None
self.weight.SCB = None
class TensorParallelRowLinear(FastLinear): def forward(self, x: torch.Tensor):
def __init__( self.state.is_training = self.training
self, if self.weight.CB is not None:
in_features, self.init_8bit_state()
out_features,
process_group: torch.distributed.ProcessGroup,
reduce=True,
bias=True,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
self.reduce = reduce
assert in_features % self.tp_world_size == 0
in_features = in_features // self.tp_world_size
super().__init__( # weights are cast automatically as Int8Params, but the bias has to be cast manually
in_features=in_features, if self.bias is not None and self.bias.dtype != x.dtype:
out_features=out_features, self.bias.data = self.bias.data.to(x.dtype)
bias=bias,
device=device,
dtype=dtype,
)
def forward(self, input: torch.Tensor) -> torch.Tensor: out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
out = super(TensorParallelRowLinear, self).forward(input)
if self.reduce:
torch.distributed.all_reduce(out, group=self.process_group)
if not self.state.has_fp16_weights:
if self.state.CB is not None and self.state.CxB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight
del self.state.CB
self.weight.data = self.state.CxB
return out return out
class TensorParallelEmbedding(nn.Embedding): def get_linear(weight, bias, quantize):
def __init__( if quantize is None:
self, linear = FastLinear(weight, bias)
num_embeddings, elif quantize == "bitsandbytes":
embedding_dim, linear = Linear8bitLt(
process_group: torch.distributed.ProcessGroup, weight,
reduce=True, bias,
padding_idx=None, has_fp16_weights=False,
max_norm=None, threshold=6.0,
norm_type=2.0, )
scale_grad_by_freq=False, if bias is not None:
sparse=False, linear.bias = nn.Parameter(bias)
_weight=None, elif quantize == "gptq":
device=None, raise NotImplementedError("Soon")
dtype=None, else:
): raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
self.reduce = reduce return linear
class SuperLayer(nn.Module):
def __init__(self, linear):
super().__init__()
self.linear = linear
def forward(self, x):
return self.linear.forward(x)
class TensorParallelHead(SuperLayer):
def __init__(self, linear, process_group):
super().__init__(linear)
self.process_group = process_group self.process_group = process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.original_num_embeddings = num_embeddings @staticmethod
def load(config, prefix: str, weights):
assert num_embeddings % self.tp_world_size == 0 weight = weights.get_sharded(f"{prefix}.weight", dim=0)
block_size = num_embeddings // self.tp_world_size return TensorParallelHead(
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings get_linear(weight, bias=None, quantize=config.quantize),
self.min_id = self.tp_rank * block_size process_group=weights.process_group,
self.max_id = (self.tp_rank + 1) * block_size
# Additional entry that will map to zero
# Used for masking
self.null_idx = block_size
super().__init__(
block_size,
embedding_dim,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
_weight=_weight,
device=device,
dtype=dtype,
) )
def add_null_idx(self): def forward(self, input: torch.Tensor) -> torch.Tensor:
output = super().forward(input)
# Logits are sharded, so we need to gather them
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
return world_output
class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
return cls(get_linear(weight, bias, config.quantize))
@classmethod
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
if bias:
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
bias = torch.cat(b, dim=0)
else:
bias = None
return cls(get_linear(weight, bias, config.quantize))
class TensorParallelRowLinear(SuperLayer):
def __init__(self, linear, process_group):
super().__init__(linear)
self.process_group = process_group
@classmethod
def load(cls, config, prefix: str, weights, bias: bool):
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return cls(
get_linear(weight, bias, config.quantize),
process_group=weights.process_group,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = super().forward(input)
torch.distributed.all_reduce(out, group=self.process_group)
return out
class TensorParallelEmbedding(nn.Module):
def __init__(self, prefix: str, weights, reduce=True):
super().__init__()
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
process_group = weights.process_group
world_size = process_group.size()
rank = process_group.rank()
block_size = num_embeddings // world_size
self.min_id = rank * block_size
self.max_id = min(num_embeddings, (rank + 1) * block_size)
self.null_idx = block_size
self.process_group = weights.process_group
self.reduce = reduce
"""Additional 0 entry used for masking""" """Additional 0 entry used for masking"""
self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1))) self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor:
# default all out of bounds values to `self.null_idx` that will then be mapped to 0 # default all out of bounds values to `self.null_idx` that will then be mapped to 0
@ -180,7 +244,7 @@ class TensorParallelEmbedding(nn.Embedding):
self.null_idx, self.null_idx,
input - self.min_id, input - self.min_id,
) )
out = super().forward(input) out = torch.nn.functional.embedding(input, self.weight)
if self.reduce: if self.reduce:
torch.distributed.all_reduce(out, group=self.process_group) torch.distributed.all_reduce(out, group=self.process_group)
return out return out
@ -232,7 +296,34 @@ try:
from flash_attn.layers.rotary import RotaryEmbedding from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb import rotary_emb
class PositionRotaryEmbedding(RotaryEmbedding): class PositionRotaryEmbedding(nn.Module):
def __init__(self, inv_freq):
super().__init__()
self.register_buffer("inv_freq", inv_freq)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
@classmethod
def static(cls, dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return cls(inv_freq)
@classmethod
def load(cls, prefix, weights):
# XXX: Always load this in float32 !
dtype = weights.dtype
weights.dtype = torch.float32
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
weights.dtype = dtype
return cls(inv_freq)
def _update_cos_sin_cache(self, dtype, device, seqlen): def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed, # Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance) # or if we're on a new device (possibly due to tracing for instance)

View File

@ -0,0 +1,77 @@
from pathlib import Path
from typing import List
from safetensors import safe_open
class Weights:
def __init__(self, filenames: List[Path], device, dtype, process_group):
routing = {}
for filename in filenames:
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
routing[k] = filename
self.routing = routing
self.device = device
self.dtype = dtype
self.process_group = process_group
self._handles = {}
def _get_handle(self, filename):
if filename not in self._handles:
f = safe_open(filename, framework="pytorch")
self._handles[filename] = f
return self._handles[filename]
def get_filename(self, tensor_name: str) -> str:
filename = self.routing.get(tensor_name, None)
if filename is None:
raise RuntimeError(f"weight {tensor_name} does not exist")
return str(filename)
def _get_slice(self, tensor_name: str):
filename = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
return slice_
def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape()
def get_tensor(self, tensor_name: str):
filename = self.get_filename(tensor_name)
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
def get_sharded(self, tensor_name: str, dim: int):
filename = self.get_filename(tensor_name)
world_size = self.process_group.size()
rank = self.process_group.rank()
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
size = slice_.get_shape()[dim]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
assert (
size % world_size == 0
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
if dim == 0:
tensor = slice_[start:stop]
elif dim == 1:
tensor = slice_[:, start:stop]
else:
raise NotImplementedError("Let's make that generic when needed")
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor