feat(server): Rework model loading (#344)
# What does this PR do? Reworked the loading logic. Idea is to use cleaner loading code: - Remove need for `no_init_weights` - Remove all weird `bnb_linear` and `load_weights` and `post_load_weights`. New code layout: - New class `Weights` in charge of handling loading the weights from multiple files into appropiate tensors (potentially sharded) - TP layers now are "shells", they contain the code to know what kind of sharding we need + eventual `all_reduce`. They do not inherit from linear, but they contain some kind of Linear instead - the contained linear can be either FastLinear, BnbLinear or GPTq Linear next. - All modeling code is explictly made for sharding, process group is just no-ops for non sharded code (removes a lot of test cases) ![Screenshot from 2023-05-19 23-19-59](https://github.com/huggingface/text-generation-inference/assets/204321/9a802654-74a3-488c-87a8-073743a6143f) --------- Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.taildb5d.ts.net> Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal> Co-authored-by: OlivierDehaene <olivier@huggingface.co> Co-authored-by: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
This commit is contained in:
parent
19c41824cb
commit
abd58ff82c
|
@ -1,3 +1,4 @@
|
|||
.idea
|
||||
target
|
||||
router/tokenizer.json
|
||||
*__pycache__*
|
||||
|
|
17
Dockerfile
17
Dockerfile
|
@ -2,6 +2,8 @@
|
|||
FROM lukemathwalker/cargo-chef:latest-rust-1.69 AS chef
|
||||
WORKDIR /usr/src
|
||||
|
||||
ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
|
||||
|
||||
FROM chef as planner
|
||||
COPY Cargo.toml Cargo.toml
|
||||
COPY rust-toolchain.toml rust-toolchain.toml
|
||||
|
@ -98,14 +100,14 @@ COPY server/Makefile-flash-att Makefile
|
|||
RUN make build-flash-attention
|
||||
|
||||
# Build Transformers CUDA kernels
|
||||
FROM kernel-builder as transformers-builder
|
||||
FROM kernel-builder as custom-kernels-builder
|
||||
|
||||
WORKDIR /usr/src
|
||||
|
||||
COPY server/Makefile-transformers Makefile
|
||||
COPY server/custom_kernels/ .
|
||||
|
||||
# Build specific version of transformers
|
||||
RUN BUILD_EXTENSIONS="True" make build-transformers
|
||||
RUN python setup.py build
|
||||
|
||||
# Text Generation Inference base image
|
||||
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base
|
||||
|
@ -136,11 +138,10 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib
|
|||
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
|
||||
|
||||
# Copy build artifacts from transformers builder
|
||||
COPY --from=transformers-builder /usr/src/transformers /usr/src/transformers
|
||||
COPY --from=transformers-builder /usr/src/transformers/build/lib.linux-x86_64-cpython-39/transformers /usr/src/transformers/src/transformers
|
||||
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels
|
||||
|
||||
# Install transformers dependencies
|
||||
RUN cd /usr/src/transformers && pip install -e . --no-cache-dir && pip install einops --no-cache-dir
|
||||
# Install flash-attention dependencies
|
||||
RUN pip install einops --no-cache-dir
|
||||
|
||||
# Install server
|
||||
COPY proto proto
|
||||
|
@ -170,4 +171,4 @@ ENTRYPOINT ["./entrypoint.sh"]
|
|||
FROM base
|
||||
|
||||
ENTRYPOINT ["text-generation-launcher"]
|
||||
CMD ["--json-output"]
|
||||
CMD ["--json-output"]
|
||||
|
|
7
Makefile
7
Makefile
|
@ -1,6 +1,9 @@
|
|||
install-server:
|
||||
cd server && make install
|
||||
|
||||
install-custom-kernels:
|
||||
if [ "$$BUILD_EXTENSIONS" == "True" ]; then cd server/custom_kernels && python setup.py install; else echo "Custom kernels are disabled, you need set to BUILD_EXTENSION environment variable to 'True' in order to build them. (Please read the docs, kernels might not work on all hardware)"; fi
|
||||
|
||||
install-integration-tests:
|
||||
cd integration-tests && pip install -r requirements.txt
|
||||
cd clients/python && pip install .
|
||||
|
@ -14,7 +17,7 @@ install-launcher:
|
|||
install-benchmark:
|
||||
cd benchmark && cargo install --path .
|
||||
|
||||
install: install-server install-router install-launcher
|
||||
install: install-server install-router install-launcher install-custom-kernels
|
||||
|
||||
server-dev:
|
||||
cd server && make run-dev
|
||||
|
@ -52,4 +55,4 @@ run-bloom:
|
|||
text-generation-launcher --model-id bigscience/bloom --num-shard 8 --port 8080
|
||||
|
||||
run-bloom-quantize:
|
||||
text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080
|
||||
text-generation-launcher --model-id bigscience/bloom --num-shard 8 --quantize --port 8080
|
||||
|
|
|
@ -209,6 +209,7 @@ def launcher(event_loop):
|
|||
num_shard: Optional[int] = None,
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
use_flash_attention: bool = True,
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
master_port = random.randint(10_000, 20_000)
|
||||
|
@ -240,6 +241,9 @@ def launcher(event_loop):
|
|||
env = os.environ
|
||||
env["LOG_LEVEL"] = "info,text_generation_router=debug"
|
||||
|
||||
if not use_flash_attention:
|
||||
env["USE_FLASH_ATTENTION"] = "false"
|
||||
|
||||
with subprocess.Popen(
|
||||
args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
|
||||
) as process:
|
||||
|
@ -254,12 +258,16 @@ def launcher(event_loop):
|
|||
process.stdout.close()
|
||||
process.stderr.close()
|
||||
|
||||
if not use_flash_attention:
|
||||
del env["USE_FLASH_ATTENTION"]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def docker_launcher(
|
||||
model_id: str,
|
||||
num_shard: Optional[int] = None,
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
use_flash_attention: bool = True,
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
|
||||
|
@ -287,6 +295,9 @@ def launcher(event_loop):
|
|||
gpu_count = num_shard if num_shard is not None else 1
|
||||
|
||||
env = {"LOG_LEVEL": "info,text_generation_router=debug"}
|
||||
if not use_flash_attention:
|
||||
env["USE_FLASH_ATTENTION"] = "false"
|
||||
|
||||
if HUGGING_FACE_HUB_TOKEN is not None:
|
||||
env["HUGGING_FACE_HUB_TOKEN"] = HUGGING_FACE_HUB_TOKEN
|
||||
|
||||
|
@ -310,6 +321,9 @@ def launcher(event_loop):
|
|||
|
||||
yield ContainerLauncherHandle(client, container.name, port)
|
||||
|
||||
if not use_flash_attention:
|
||||
del env["USE_FLASH_ATTENTION"]
|
||||
|
||||
try:
|
||||
container.stop()
|
||||
container.wait()
|
||||
|
|
|
@ -11,17 +11,17 @@
|
|||
},
|
||||
{
|
||||
"id": 1459,
|
||||
"logprob": -5.6289062,
|
||||
"logprob": -5.6328125,
|
||||
"text": " print"
|
||||
},
|
||||
{
|
||||
"id": 81,
|
||||
"logprob": -1.6005859,
|
||||
"logprob": -1.6035156,
|
||||
"text": "_"
|
||||
},
|
||||
{
|
||||
"id": 7656,
|
||||
"logprob": -5.9921875,
|
||||
"logprob": -5.9882812,
|
||||
"text": "hello"
|
||||
}
|
||||
],
|
||||
|
@ -59,19 +59,19 @@
|
|||
},
|
||||
{
|
||||
"id": 10896,
|
||||
"logprob": -0.3659668,
|
||||
"logprob": -0.38549805,
|
||||
"special": false,
|
||||
"text": " World"
|
||||
},
|
||||
{
|
||||
"id": 657,
|
||||
"logprob": -0.49804688,
|
||||
"logprob": -0.5229492,
|
||||
"special": false,
|
||||
"text": "\")"
|
||||
},
|
||||
{
|
||||
"id": 203,
|
||||
"logprob": -0.11279297,
|
||||
"logprob": -0.10632324,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
|
@ -113,7 +113,7 @@
|
|||
},
|
||||
{
|
||||
"id": 426,
|
||||
"logprob": -0.051635742,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "name"
|
||||
},
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 50278,
|
||||
"logprob": null,
|
||||
"text": "<|USER|>"
|
||||
},
|
||||
{
|
||||
"id": 1276,
|
||||
"logprob": -4.5546875,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 434,
|
||||
"logprob": -4.1992188,
|
||||
"text": "'s"
|
||||
},
|
||||
{
|
||||
"id": 634,
|
||||
"logprob": -5.125,
|
||||
"text": " your"
|
||||
},
|
||||
{
|
||||
"id": 12315,
|
||||
"logprob": -9.8984375,
|
||||
"text": " mood"
|
||||
},
|
||||
{
|
||||
"id": 3063,
|
||||
"logprob": -4.0976562,
|
||||
"text": " today"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -0.14562988,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 50279,
|
||||
"logprob": -0.26733398,
|
||||
"text": "<|ASSISTANT|>"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 42,
|
||||
"logprob": -0.86279297,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 1353,
|
||||
"logprob": -0.94921875,
|
||||
"special": false,
|
||||
"text": "'m"
|
||||
},
|
||||
{
|
||||
"id": 7016,
|
||||
"logprob": -2.1835938,
|
||||
"special": false,
|
||||
"text": " sorry"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.074035645,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 1394,
|
||||
"logprob": -0.86376953,
|
||||
"special": false,
|
||||
"text": "You"
|
||||
},
|
||||
{
|
||||
"id": 452,
|
||||
"logprob": -1.2070312,
|
||||
"special": false,
|
||||
"text": " have"
|
||||
},
|
||||
{
|
||||
"id": 247,
|
||||
"logprob": -1.4365234,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 4327,
|
||||
"logprob": -1.109375,
|
||||
"special": false,
|
||||
"text": " choice"
|
||||
},
|
||||
{
|
||||
"id": 273,
|
||||
"logprob": -0.93408203,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 752,
|
||||
"logprob": -1.8808594,
|
||||
"special": false,
|
||||
"text": " what"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "I'm sorry,You have a choice of what"
|
||||
}
|
|
@ -0,0 +1,454 @@
|
|||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 50278,
|
||||
"logprob": null,
|
||||
"text": "<|USER|>"
|
||||
},
|
||||
{
|
||||
"id": 1276,
|
||||
"logprob": -4.5546875,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 434,
|
||||
"logprob": -4.1953125,
|
||||
"text": "'s"
|
||||
},
|
||||
{
|
||||
"id": 634,
|
||||
"logprob": -5.125,
|
||||
"text": " your"
|
||||
},
|
||||
{
|
||||
"id": 12315,
|
||||
"logprob": -9.8828125,
|
||||
"text": " mood"
|
||||
},
|
||||
{
|
||||
"id": 3063,
|
||||
"logprob": -3.9980469,
|
||||
"text": " today"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -0.14672852,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 50279,
|
||||
"logprob": -0.26489258,
|
||||
"text": "<|ASSISTANT|>"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 42,
|
||||
"logprob": -0.8618164,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 1353,
|
||||
"logprob": -0.9506836,
|
||||
"special": false,
|
||||
"text": "'m"
|
||||
},
|
||||
{
|
||||
"id": 7016,
|
||||
"logprob": -2.1738281,
|
||||
"special": false,
|
||||
"text": " sorry"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.0758667,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 1394,
|
||||
"logprob": -0.9135742,
|
||||
"special": false,
|
||||
"text": "You"
|
||||
},
|
||||
{
|
||||
"id": 452,
|
||||
"logprob": -1.1445312,
|
||||
"special": false,
|
||||
"text": " have"
|
||||
},
|
||||
{
|
||||
"id": 247,
|
||||
"logprob": -1.4375,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 4327,
|
||||
"logprob": -1.1103516,
|
||||
"special": false,
|
||||
"text": " choice"
|
||||
},
|
||||
{
|
||||
"id": 273,
|
||||
"logprob": -1.0058594,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 752,
|
||||
"logprob": -1.921875,
|
||||
"special": false,
|
||||
"text": " what"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "I'm sorry,You have a choice of what"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 50278,
|
||||
"logprob": null,
|
||||
"text": "<|USER|>"
|
||||
},
|
||||
{
|
||||
"id": 1276,
|
||||
"logprob": -4.5546875,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 434,
|
||||
"logprob": -4.1953125,
|
||||
"text": "'s"
|
||||
},
|
||||
{
|
||||
"id": 634,
|
||||
"logprob": -5.125,
|
||||
"text": " your"
|
||||
},
|
||||
{
|
||||
"id": 12315,
|
||||
"logprob": -9.8828125,
|
||||
"text": " mood"
|
||||
},
|
||||
{
|
||||
"id": 3063,
|
||||
"logprob": -3.9980469,
|
||||
"text": " today"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -0.14672852,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 50279,
|
||||
"logprob": -0.26489258,
|
||||
"text": "<|ASSISTANT|>"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 42,
|
||||
"logprob": -0.8618164,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 1353,
|
||||
"logprob": -0.9506836,
|
||||
"special": false,
|
||||
"text": "'m"
|
||||
},
|
||||
{
|
||||
"id": 7016,
|
||||
"logprob": -2.1738281,
|
||||
"special": false,
|
||||
"text": " sorry"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.0758667,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 1394,
|
||||
"logprob": -0.9135742,
|
||||
"special": false,
|
||||
"text": "You"
|
||||
},
|
||||
{
|
||||
"id": 452,
|
||||
"logprob": -1.1445312,
|
||||
"special": false,
|
||||
"text": " have"
|
||||
},
|
||||
{
|
||||
"id": 247,
|
||||
"logprob": -1.4375,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 4327,
|
||||
"logprob": -1.1103516,
|
||||
"special": false,
|
||||
"text": " choice"
|
||||
},
|
||||
{
|
||||
"id": 273,
|
||||
"logprob": -1.0058594,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 752,
|
||||
"logprob": -1.921875,
|
||||
"special": false,
|
||||
"text": " what"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "I'm sorry,You have a choice of what"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 50278,
|
||||
"logprob": null,
|
||||
"text": "<|USER|>"
|
||||
},
|
||||
{
|
||||
"id": 1276,
|
||||
"logprob": -4.5546875,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 434,
|
||||
"logprob": -4.1953125,
|
||||
"text": "'s"
|
||||
},
|
||||
{
|
||||
"id": 634,
|
||||
"logprob": -5.125,
|
||||
"text": " your"
|
||||
},
|
||||
{
|
||||
"id": 12315,
|
||||
"logprob": -9.8828125,
|
||||
"text": " mood"
|
||||
},
|
||||
{
|
||||
"id": 3063,
|
||||
"logprob": -3.9980469,
|
||||
"text": " today"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -0.14672852,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 50279,
|
||||
"logprob": -0.26489258,
|
||||
"text": "<|ASSISTANT|>"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 42,
|
||||
"logprob": -0.8618164,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 1353,
|
||||
"logprob": -0.9506836,
|
||||
"special": false,
|
||||
"text": "'m"
|
||||
},
|
||||
{
|
||||
"id": 7016,
|
||||
"logprob": -2.1738281,
|
||||
"special": false,
|
||||
"text": " sorry"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.0758667,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 1394,
|
||||
"logprob": -0.9135742,
|
||||
"special": false,
|
||||
"text": "You"
|
||||
},
|
||||
{
|
||||
"id": 452,
|
||||
"logprob": -1.1445312,
|
||||
"special": false,
|
||||
"text": " have"
|
||||
},
|
||||
{
|
||||
"id": 247,
|
||||
"logprob": -1.4375,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 4327,
|
||||
"logprob": -1.1103516,
|
||||
"special": false,
|
||||
"text": " choice"
|
||||
},
|
||||
{
|
||||
"id": 273,
|
||||
"logprob": -1.0058594,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 752,
|
||||
"logprob": -1.921875,
|
||||
"special": false,
|
||||
"text": " what"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "I'm sorry,You have a choice of what"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 50278,
|
||||
"logprob": null,
|
||||
"text": "<|USER|>"
|
||||
},
|
||||
{
|
||||
"id": 1276,
|
||||
"logprob": -4.5546875,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 434,
|
||||
"logprob": -4.1953125,
|
||||
"text": "'s"
|
||||
},
|
||||
{
|
||||
"id": 634,
|
||||
"logprob": -5.125,
|
||||
"text": " your"
|
||||
},
|
||||
{
|
||||
"id": 12315,
|
||||
"logprob": -9.8828125,
|
||||
"text": " mood"
|
||||
},
|
||||
{
|
||||
"id": 3063,
|
||||
"logprob": -3.9980469,
|
||||
"text": " today"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -0.14672852,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 50279,
|
||||
"logprob": -0.26489258,
|
||||
"text": "<|ASSISTANT|>"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 42,
|
||||
"logprob": -0.8618164,
|
||||
"special": false,
|
||||
"text": "I"
|
||||
},
|
||||
{
|
||||
"id": 1353,
|
||||
"logprob": -0.9506836,
|
||||
"special": false,
|
||||
"text": "'m"
|
||||
},
|
||||
{
|
||||
"id": 7016,
|
||||
"logprob": -2.1738281,
|
||||
"special": false,
|
||||
"text": " sorry"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -0.0758667,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 1394,
|
||||
"logprob": -0.9135742,
|
||||
"special": false,
|
||||
"text": "You"
|
||||
},
|
||||
{
|
||||
"id": 452,
|
||||
"logprob": -1.1445312,
|
||||
"special": false,
|
||||
"text": " have"
|
||||
},
|
||||
{
|
||||
"id": 247,
|
||||
"logprob": -1.4375,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 4327,
|
||||
"logprob": -1.1103516,
|
||||
"special": false,
|
||||
"text": " choice"
|
||||
},
|
||||
{
|
||||
"id": 273,
|
||||
"logprob": -1.0058594,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 752,
|
||||
"logprob": -1.921875,
|
||||
"special": false,
|
||||
"text": " what"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "I'm sorry,You have a choice of what"
|
||||
}
|
||||
]
|
|
@ -0,0 +1,163 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 50278,
|
||||
"logprob": null,
|
||||
"text": "<|prompter|>"
|
||||
},
|
||||
{
|
||||
"id": 1276,
|
||||
"logprob": -8.0234375,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 310,
|
||||
"logprob": -5.4179688,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 247,
|
||||
"logprob": -2.1542969,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1167,
|
||||
"logprob": -5.359375,
|
||||
"text": " mem"
|
||||
},
|
||||
{
|
||||
"id": 70,
|
||||
"logprob": -0.006038666,
|
||||
"text": "e"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -7.328125,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 285,
|
||||
"logprob": -0.3173828,
|
||||
"text": " and"
|
||||
},
|
||||
{
|
||||
"id": 752,
|
||||
"logprob": -2.0625,
|
||||
"text": " what"
|
||||
},
|
||||
{
|
||||
"id": 434,
|
||||
"logprob": -5.7734375,
|
||||
"text": "'s"
|
||||
},
|
||||
{
|
||||
"id": 253,
|
||||
"logprob": -0.74072266,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 2892,
|
||||
"logprob": -6.5898438,
|
||||
"text": " history"
|
||||
},
|
||||
{
|
||||
"id": 3212,
|
||||
"logprob": -2.2949219,
|
||||
"text": " behind"
|
||||
},
|
||||
{
|
||||
"id": 436,
|
||||
"logprob": -11.40625,
|
||||
"text": " this"
|
||||
},
|
||||
{
|
||||
"id": 3159,
|
||||
"logprob": -2.1113281,
|
||||
"text": " word"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -0.008056641,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 0,
|
||||
"logprob": -2.3300781,
|
||||
"text": "<|endoftext|>"
|
||||
},
|
||||
{
|
||||
"id": 50281,
|
||||
"logprob": -18.28125,
|
||||
"text": "<|assistant|>"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 510,
|
||||
"logprob": -0.5878906,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 3159,
|
||||
"logprob": -0.5449219,
|
||||
"special": false,
|
||||
"text": " word"
|
||||
},
|
||||
{
|
||||
"id": 346,
|
||||
"logprob": -0.05038452,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 6441,
|
||||
"logprob": -0.002292633,
|
||||
"special": false,
|
||||
"text": "mem"
|
||||
},
|
||||
{
|
||||
"id": 70,
|
||||
"logprob": -1.3828278e-05,
|
||||
"special": false,
|
||||
"text": "e"
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -0.0010242462,
|
||||
"special": false,
|
||||
"text": "\""
|
||||
},
|
||||
{
|
||||
"id": 369,
|
||||
"logprob": -0.090270996,
|
||||
"special": false,
|
||||
"text": " was"
|
||||
},
|
||||
{
|
||||
"id": 806,
|
||||
"logprob": -0.12719727,
|
||||
"special": false,
|
||||
"text": " first"
|
||||
},
|
||||
{
|
||||
"id": 908,
|
||||
"logprob": -0.016571045,
|
||||
"special": false,
|
||||
"text": " used"
|
||||
},
|
||||
{
|
||||
"id": 275,
|
||||
"logprob": -0.43432617,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "The word \"meme\" was first used in"
|
||||
}
|
|
@ -0,0 +1,654 @@
|
|||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 50278,
|
||||
"logprob": null,
|
||||
"text": "<|prompter|>"
|
||||
},
|
||||
{
|
||||
"id": 1276,
|
||||
"logprob": -8.0234375,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 310,
|
||||
"logprob": -5.4179688,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 247,
|
||||
"logprob": -2.1542969,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1167,
|
||||
"logprob": -5.359375,
|
||||
"text": " mem"
|
||||
},
|
||||
{
|
||||
"id": 70,
|
||||
"logprob": -0.006038666,
|
||||
"text": "e"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -7.328125,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 285,
|
||||
"logprob": -0.3173828,
|
||||
"text": " and"
|
||||
},
|
||||
{
|
||||
"id": 752,
|
||||
"logprob": -2.0625,
|
||||
"text": " what"
|
||||
},
|
||||
{
|
||||
"id": 434,
|
||||
"logprob": -5.7734375,
|
||||
"text": "'s"
|
||||
},
|
||||
{
|
||||
"id": 253,
|
||||
"logprob": -0.74072266,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 2892,
|
||||
"logprob": -6.5898438,
|
||||
"text": " history"
|
||||
},
|
||||
{
|
||||
"id": 3212,
|
||||
"logprob": -2.2949219,
|
||||
"text": " behind"
|
||||
},
|
||||
{
|
||||
"id": 436,
|
||||
"logprob": -11.40625,
|
||||
"text": " this"
|
||||
},
|
||||
{
|
||||
"id": 3159,
|
||||
"logprob": -2.1113281,
|
||||
"text": " word"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -0.008056641,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 0,
|
||||
"logprob": -2.3300781,
|
||||
"text": "<|endoftext|>"
|
||||
},
|
||||
{
|
||||
"id": 50281,
|
||||
"logprob": -18.28125,
|
||||
"text": "<|assistant|>"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 510,
|
||||
"logprob": -0.5878906,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 3159,
|
||||
"logprob": -0.5498047,
|
||||
"special": false,
|
||||
"text": " word"
|
||||
},
|
||||
{
|
||||
"id": 346,
|
||||
"logprob": -0.04815674,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 6441,
|
||||
"logprob": -0.002313614,
|
||||
"special": false,
|
||||
"text": "mem"
|
||||
},
|
||||
{
|
||||
"id": 70,
|
||||
"logprob": -1.2636185e-05,
|
||||
"special": false,
|
||||
"text": "e"
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -0.0010147095,
|
||||
"special": false,
|
||||
"text": "\""
|
||||
},
|
||||
{
|
||||
"id": 369,
|
||||
"logprob": -0.0859375,
|
||||
"special": false,
|
||||
"text": " was"
|
||||
},
|
||||
{
|
||||
"id": 806,
|
||||
"logprob": -0.12609863,
|
||||
"special": false,
|
||||
"text": " first"
|
||||
},
|
||||
{
|
||||
"id": 908,
|
||||
"logprob": -0.016601562,
|
||||
"special": false,
|
||||
"text": " used"
|
||||
},
|
||||
{
|
||||
"id": 275,
|
||||
"logprob": -0.38256836,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "The word \"meme\" was first used in"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 50278,
|
||||
"logprob": null,
|
||||
"text": "<|prompter|>"
|
||||
},
|
||||
{
|
||||
"id": 1276,
|
||||
"logprob": -8.0234375,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 310,
|
||||
"logprob": -5.421875,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 247,
|
||||
"logprob": -2.1640625,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1167,
|
||||
"logprob": -5.40625,
|
||||
"text": " mem"
|
||||
},
|
||||
{
|
||||
"id": 70,
|
||||
"logprob": -0.005420685,
|
||||
"text": "e"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -7.2226562,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 285,
|
||||
"logprob": -0.26879883,
|
||||
"text": " and"
|
||||
},
|
||||
{
|
||||
"id": 752,
|
||||
"logprob": -2.1992188,
|
||||
"text": " what"
|
||||
},
|
||||
{
|
||||
"id": 434,
|
||||
"logprob": -5.46875,
|
||||
"text": "'s"
|
||||
},
|
||||
{
|
||||
"id": 253,
|
||||
"logprob": -0.8017578,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 2892,
|
||||
"logprob": -6.6796875,
|
||||
"text": " history"
|
||||
},
|
||||
{
|
||||
"id": 3212,
|
||||
"logprob": -2.1972656,
|
||||
"text": " behind"
|
||||
},
|
||||
{
|
||||
"id": 436,
|
||||
"logprob": -11.4453125,
|
||||
"text": " this"
|
||||
},
|
||||
{
|
||||
"id": 3159,
|
||||
"logprob": -2.1933594,
|
||||
"text": " word"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -0.007858276,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 0,
|
||||
"logprob": -2.328125,
|
||||
"text": "<|endoftext|>"
|
||||
},
|
||||
{
|
||||
"id": 50281,
|
||||
"logprob": -18.21875,
|
||||
"text": "<|assistant|>"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 510,
|
||||
"logprob": -0.6201172,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 3159,
|
||||
"logprob": -0.546875,
|
||||
"special": false,
|
||||
"text": " word"
|
||||
},
|
||||
{
|
||||
"id": 346,
|
||||
"logprob": -0.051879883,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 6441,
|
||||
"logprob": -0.0020179749,
|
||||
"special": false,
|
||||
"text": "mem"
|
||||
},
|
||||
{
|
||||
"id": 70,
|
||||
"logprob": -9.059906e-06,
|
||||
"special": false,
|
||||
"text": "e"
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -0.00096797943,
|
||||
"special": false,
|
||||
"text": "\""
|
||||
},
|
||||
{
|
||||
"id": 369,
|
||||
"logprob": -0.07940674,
|
||||
"special": false,
|
||||
"text": " was"
|
||||
},
|
||||
{
|
||||
"id": 806,
|
||||
"logprob": -0.12182617,
|
||||
"special": false,
|
||||
"text": " first"
|
||||
},
|
||||
{
|
||||
"id": 908,
|
||||
"logprob": -0.017227173,
|
||||
"special": false,
|
||||
"text": " used"
|
||||
},
|
||||
{
|
||||
"id": 275,
|
||||
"logprob": -0.44482422,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "The word \"meme\" was first used in"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 50278,
|
||||
"logprob": null,
|
||||
"text": "<|prompter|>"
|
||||
},
|
||||
{
|
||||
"id": 1276,
|
||||
"logprob": -8.0234375,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 310,
|
||||
"logprob": -5.421875,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 247,
|
||||
"logprob": -2.1640625,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1167,
|
||||
"logprob": -5.40625,
|
||||
"text": " mem"
|
||||
},
|
||||
{
|
||||
"id": 70,
|
||||
"logprob": -0.005420685,
|
||||
"text": "e"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -7.2226562,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 285,
|
||||
"logprob": -0.26879883,
|
||||
"text": " and"
|
||||
},
|
||||
{
|
||||
"id": 752,
|
||||
"logprob": -2.1992188,
|
||||
"text": " what"
|
||||
},
|
||||
{
|
||||
"id": 434,
|
||||
"logprob": -5.46875,
|
||||
"text": "'s"
|
||||
},
|
||||
{
|
||||
"id": 253,
|
||||
"logprob": -0.8017578,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 2892,
|
||||
"logprob": -6.6796875,
|
||||
"text": " history"
|
||||
},
|
||||
{
|
||||
"id": 3212,
|
||||
"logprob": -2.1972656,
|
||||
"text": " behind"
|
||||
},
|
||||
{
|
||||
"id": 436,
|
||||
"logprob": -11.4453125,
|
||||
"text": " this"
|
||||
},
|
||||
{
|
||||
"id": 3159,
|
||||
"logprob": -2.1933594,
|
||||
"text": " word"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -0.007858276,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 0,
|
||||
"logprob": -2.328125,
|
||||
"text": "<|endoftext|>"
|
||||
},
|
||||
{
|
||||
"id": 50281,
|
||||
"logprob": -18.21875,
|
||||
"text": "<|assistant|>"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 510,
|
||||
"logprob": -0.6201172,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 3159,
|
||||
"logprob": -0.546875,
|
||||
"special": false,
|
||||
"text": " word"
|
||||
},
|
||||
{
|
||||
"id": 346,
|
||||
"logprob": -0.051879883,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 6441,
|
||||
"logprob": -0.0020179749,
|
||||
"special": false,
|
||||
"text": "mem"
|
||||
},
|
||||
{
|
||||
"id": 70,
|
||||
"logprob": -9.059906e-06,
|
||||
"special": false,
|
||||
"text": "e"
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -0.00096797943,
|
||||
"special": false,
|
||||
"text": "\""
|
||||
},
|
||||
{
|
||||
"id": 369,
|
||||
"logprob": -0.07940674,
|
||||
"special": false,
|
||||
"text": " was"
|
||||
},
|
||||
{
|
||||
"id": 806,
|
||||
"logprob": -0.12182617,
|
||||
"special": false,
|
||||
"text": " first"
|
||||
},
|
||||
{
|
||||
"id": 908,
|
||||
"logprob": -0.017227173,
|
||||
"special": false,
|
||||
"text": " used"
|
||||
},
|
||||
{
|
||||
"id": 275,
|
||||
"logprob": -0.44482422,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "The word \"meme\" was first used in"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 50278,
|
||||
"logprob": null,
|
||||
"text": "<|prompter|>"
|
||||
},
|
||||
{
|
||||
"id": 1276,
|
||||
"logprob": -8.0234375,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 310,
|
||||
"logprob": -5.421875,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 247,
|
||||
"logprob": -2.1640625,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 1167,
|
||||
"logprob": -5.40625,
|
||||
"text": " mem"
|
||||
},
|
||||
{
|
||||
"id": 70,
|
||||
"logprob": -0.005420685,
|
||||
"text": "e"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": -7.2226562,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 285,
|
||||
"logprob": -0.26879883,
|
||||
"text": " and"
|
||||
},
|
||||
{
|
||||
"id": 752,
|
||||
"logprob": -2.1992188,
|
||||
"text": " what"
|
||||
},
|
||||
{
|
||||
"id": 434,
|
||||
"logprob": -5.46875,
|
||||
"text": "'s"
|
||||
},
|
||||
{
|
||||
"id": 253,
|
||||
"logprob": -0.8017578,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 2892,
|
||||
"logprob": -6.6796875,
|
||||
"text": " history"
|
||||
},
|
||||
{
|
||||
"id": 3212,
|
||||
"logprob": -2.1972656,
|
||||
"text": " behind"
|
||||
},
|
||||
{
|
||||
"id": 436,
|
||||
"logprob": -11.4453125,
|
||||
"text": " this"
|
||||
},
|
||||
{
|
||||
"id": 3159,
|
||||
"logprob": -2.1933594,
|
||||
"text": " word"
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"logprob": -0.007858276,
|
||||
"text": "?"
|
||||
},
|
||||
{
|
||||
"id": 0,
|
||||
"logprob": -2.328125,
|
||||
"text": "<|endoftext|>"
|
||||
},
|
||||
{
|
||||
"id": 50281,
|
||||
"logprob": -18.21875,
|
||||
"text": "<|assistant|>"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 510,
|
||||
"logprob": -0.6201172,
|
||||
"special": false,
|
||||
"text": "The"
|
||||
},
|
||||
{
|
||||
"id": 3159,
|
||||
"logprob": -0.546875,
|
||||
"special": false,
|
||||
"text": " word"
|
||||
},
|
||||
{
|
||||
"id": 346,
|
||||
"logprob": -0.051879883,
|
||||
"special": false,
|
||||
"text": " \""
|
||||
},
|
||||
{
|
||||
"id": 6441,
|
||||
"logprob": -0.0020179749,
|
||||
"special": false,
|
||||
"text": "mem"
|
||||
},
|
||||
{
|
||||
"id": 70,
|
||||
"logprob": -1.04904175e-05,
|
||||
"special": false,
|
||||
"text": "e"
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"logprob": -0.0009560585,
|
||||
"special": false,
|
||||
"text": "\""
|
||||
},
|
||||
{
|
||||
"id": 369,
|
||||
"logprob": -0.08557129,
|
||||
"special": false,
|
||||
"text": " was"
|
||||
},
|
||||
{
|
||||
"id": 806,
|
||||
"logprob": -0.12084961,
|
||||
"special": false,
|
||||
"text": " first"
|
||||
},
|
||||
{
|
||||
"id": 908,
|
||||
"logprob": -0.01737976,
|
||||
"special": false,
|
||||
"text": " used"
|
||||
},
|
||||
{
|
||||
"id": 275,
|
||||
"logprob": -0.4025879,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "The word \"meme\" was first used in"
|
||||
}
|
||||
]
|
|
@ -37,8 +37,8 @@ async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
|
|||
generated_texts = [r.generated_text for r in responses]
|
||||
|
||||
assert len(generated_texts) == 4
|
||||
assert generated_texts, all(
|
||||
assert all(
|
||||
[text == generated_texts[0] for text in generated_texts]
|
||||
)
|
||||
), generated_texts
|
||||
|
||||
assert responses == response_snapshot
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def neox_handle(launcher):
|
||||
with launcher(
|
||||
"stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def neox(neox_handle):
|
||||
await neox_handle.health(300)
|
||||
return neox_handle.client
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_neox(neox, response_snapshot):
|
||||
response = await neox.generate(
|
||||
"<|USER|>What's your mood today?<|ASSISTANT|>",
|
||||
max_new_tokens=10,
|
||||
decoder_input_details=True,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_neox_load(neox, generate_load, response_snapshot):
|
||||
responses = await generate_load(
|
||||
neox,
|
||||
"<|USER|>What's your mood today?<|ASSISTANT|>",
|
||||
max_new_tokens=10,
|
||||
n=4,
|
||||
)
|
||||
|
||||
generated_texts = [r.generated_text for r in responses]
|
||||
|
||||
assert len(generated_texts) == 4
|
||||
assert generated_texts, all(
|
||||
[text == generated_texts[0] for text in generated_texts]
|
||||
)
|
||||
|
||||
assert responses == response_snapshot
|
|
@ -0,0 +1,44 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def neox_sharded_handle(launcher):
|
||||
with launcher(
|
||||
"OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def neox_sharded(neox_sharded_handle):
|
||||
await neox_sharded_handle.health(300)
|
||||
return neox_sharded_handle.client
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_neox(neox_sharded, response_snapshot):
|
||||
response = await neox_sharded.generate(
|
||||
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
|
||||
max_new_tokens=10,
|
||||
decoder_input_details=True,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_neox_load(neox_sharded, generate_load, response_snapshot):
|
||||
responses = await generate_load(
|
||||
neox_sharded,
|
||||
"<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
|
||||
max_new_tokens=10,
|
||||
n=4,
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
|
@ -1,4 +1,5 @@
|
|||
[pytest]
|
||||
addopts = --snapshot-warn-unused
|
||||
asyncio_mode = auto
|
||||
markers =
|
||||
private: marks tests as requiring an admin hf token (deselect with '-m "not private"')
|
|
@ -1,4 +1,3 @@
|
|||
include Makefile-transformers
|
||||
include Makefile-flash-att
|
||||
|
||||
unit-tests:
|
||||
|
@ -17,7 +16,7 @@ install-torch:
|
|||
# Install specific version of torch
|
||||
pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
|
||||
|
||||
install: gen-server install-torch install-transformers
|
||||
install: gen-server install-torch
|
||||
pip install pip --upgrade
|
||||
pip install -r requirements.txt
|
||||
pip install -e ".[bnb, accelerate]"
|
||||
|
@ -26,4 +25,4 @@ run-dev:
|
|||
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
|
||||
|
||||
export-requirements:
|
||||
poetry export -o requirements.txt -E bnb --without-hashes
|
||||
poetry export -o requirements.txt -E bnb --without-hashes
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
transformers_commit := 69009822aa7897ffab97afb814e38126b83f639e
|
||||
|
||||
transformers:
|
||||
# Clone fork of transformers with custom CUDA kernels and sharding logic
|
||||
pip install --upgrade setuptools
|
||||
git clone https://github.com/OlivierDehaene/transformers.git
|
||||
|
||||
build-transformers: transformers
|
||||
cd transformers && git fetch && git checkout $(transformers_commit) && python setup.py build
|
||||
|
||||
install-transformers: build-transformers
|
||||
pip uninstall transformers -y || true
|
||||
cd transformers && python setup.py install
|
|
@ -0,0 +1,250 @@
|
|||
#include <ATen/Dispatch.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
#include <optional>
|
||||
|
||||
/**
|
||||
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
|
||||
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
|
||||
**/
|
||||
|
||||
// Available in pytorch main
|
||||
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
|
||||
/*
|
||||
* Forward passes
|
||||
*/
|
||||
|
||||
/**
|
||||
* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype
|
||||
**/
|
||||
template<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>
|
||||
__global__ void forward_masked_softmax_kernel(
|
||||
const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]
|
||||
const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]
|
||||
torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]
|
||||
const int64_t effective_kv_length,
|
||||
const dim3 blockDim,
|
||||
const int64_t rows_per_block,
|
||||
const int64_t kv_length,
|
||||
const int64_t batch_size
|
||||
) {
|
||||
const auto row_id = threadIdx.x / effective_kv_length;
|
||||
const auto effective_kv_length_id = threadIdx.x % effective_kv_length;
|
||||
const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;
|
||||
auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;
|
||||
kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_;
|
||||
const auto kv_length_end = kv_length_end_;
|
||||
|
||||
const auto batch_id = blockIdx.x * rows_per_block + row_id;
|
||||
|
||||
// We need 2 float storage for each row, one for max computation, the other for normalizing exponential
|
||||
extern __shared__ float temp_storage[];
|
||||
const auto row_id_mem_offset = row_id * 2;
|
||||
if (effective_kv_length_id == 0) {
|
||||
temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();
|
||||
temp_storage[row_id_mem_offset + 1] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Compute mask and max
|
||||
if (batch_id < batch_size) {
|
||||
float thread_max = -std::numeric_limits<float>::infinity();
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
if (mask[batch_id][kv_length_id] == 0) {
|
||||
const float candidate = attention_scores[batch_id][kv_length_id];
|
||||
thread_max = (thread_max < candidate) ? candidate : thread_max;
|
||||
}
|
||||
}
|
||||
if (thread_max != -std::numeric_limits<float>::infinity()) {
|
||||
// TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot
|
||||
gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Compute exp(elt - max) masked
|
||||
float exponential[min_kv_length_shard_size_per_thread];
|
||||
if (batch_id < batch_size) {
|
||||
float thread_add = 0;
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
if (mask[batch_id][kv_length_id] == 0) {
|
||||
exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);
|
||||
thread_add = thread_add + exponential[kv_length_id - kv_length_start];
|
||||
} else {
|
||||
exponential[kv_length_id - kv_length_start] = 0.;
|
||||
}
|
||||
}
|
||||
if (thread_add > 0) {
|
||||
// TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot
|
||||
gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Compute softmax
|
||||
if (batch_id < batch_size) {
|
||||
// If sum of all exponential is 0, we set the softmax values to 0
|
||||
if (temp_storage[row_id_mem_offset + 1] == 0.) {
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
result[batch_id][kv_length_id] = 0.;
|
||||
}
|
||||
} else {
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
std::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(
|
||||
const at::Tensor query,
|
||||
const at::Tensor key,
|
||||
const at::Tensor value,
|
||||
const std::optional<std::vector<at::Tensor>> layer_past,
|
||||
const at::Tensor attention_mask,
|
||||
const std::optional<at::Tensor> head_mask,
|
||||
const float inv_norm_factor,
|
||||
const int num_heads,
|
||||
const bool use_cache
|
||||
) {
|
||||
auto query_layer = query;
|
||||
auto key_layer = key;
|
||||
auto value_layer = value;
|
||||
|
||||
if (layer_past) {
|
||||
const auto past_key = (*layer_past).at(0);
|
||||
const auto past_value = (*layer_past).at(1);
|
||||
key_layer = at::cat({past_key, key_layer}, 2);
|
||||
value_layer = at::cat({past_value, value_layer}, 2);
|
||||
}
|
||||
|
||||
std::optional<std::vector<at::Tensor>> present;
|
||||
if (use_cache) {
|
||||
present = {key_layer, value_layer};
|
||||
} else {
|
||||
present = {};
|
||||
}
|
||||
|
||||
const auto batch_size = query_layer.size(0);
|
||||
const auto q_length = query_layer.size(2);
|
||||
const auto attn_head_size = query_layer.size(3);
|
||||
const auto batch_size_times_num_heads = batch_size * num_heads;
|
||||
const auto kv_length = key_layer.size(2);
|
||||
|
||||
const auto query_view = query_layer.reshape({batch_size_times_num_heads, q_length, attn_head_size});
|
||||
auto key_view = key_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size}).transpose(1, 2);
|
||||
auto value_view = value_layer.reshape({batch_size_times_num_heads, kv_length, attn_head_size});
|
||||
|
||||
auto query_scaled = query_view * inv_norm_factor;
|
||||
auto attention_scores = at::bmm(query_scaled, key_view);
|
||||
|
||||
// Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype`
|
||||
at::Tensor attention_probs;
|
||||
if (true) {
|
||||
// TODO @thomasw21: it's easier to think of attention_scores as 2D tensors
|
||||
const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length});
|
||||
const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length});
|
||||
|
||||
// Custom kernel
|
||||
attention_probs = at::empty_like(attention_scores_2d);
|
||||
|
||||
// Check that inputs and contiguous + cuda tensors
|
||||
CHECK_INPUT(attention_scores_2d);
|
||||
CHECK_INPUT(attention_mask_2d);
|
||||
|
||||
// TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out
|
||||
// DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] {
|
||||
/*
|
||||
* Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
|
||||
* A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
|
||||
* - SMs: 108
|
||||
* - TPCs: 56 (What's that?)
|
||||
* - Memory size: 40 GB
|
||||
* - L2 Cache size: 40960 KB (shared across all SMs)
|
||||
* - L1/Shared memory size: 192 KB (shared across all threads within a SM)
|
||||
* - Max Threads / SM: 2048
|
||||
* - Max Thread Blocks / SM: 32
|
||||
*/
|
||||
|
||||
/*
|
||||
* We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block
|
||||
* with multiple threads as we need to `sync_threads` to run exponential sum.
|
||||
* We maximise the usage of threads within a single block
|
||||
*/
|
||||
// TODO @thomasw21 figure out everything warp related:
|
||||
// - why do they have to be power of 2
|
||||
// TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
|
||||
const auto MAX_THREADS_PER_SM = 1024;
|
||||
// TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`
|
||||
const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;
|
||||
// `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`
|
||||
const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;
|
||||
const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;
|
||||
const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;
|
||||
|
||||
const dim3 gridDim(num_blocks); // Number of blocks that run
|
||||
const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block
|
||||
const int shared_mem_forward = rows_per_block * 2 * sizeof(float);
|
||||
|
||||
// 192 * 2 ** 10
|
||||
// const auto MAX_L1_MEMORY = 196608;
|
||||
// const auto MAX_SMs = 108;
|
||||
// TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation.");
|
||||
// TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger.");
|
||||
// TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher.");
|
||||
|
||||
forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(
|
||||
attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
|
||||
attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
|
||||
attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
|
||||
effective_kv_length,
|
||||
blockDim,
|
||||
rows_per_block,
|
||||
kv_length,
|
||||
batch_size_times_num_heads * q_length
|
||||
);
|
||||
});
|
||||
attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});
|
||||
} else {
|
||||
// Pytorch C++ API
|
||||
auto input_dtype = attention_scores.scalar_type();
|
||||
if (input_dtype == at::ScalarType::Float) {
|
||||
attention_scores = attention_scores.to(at::ScalarType::Float);
|
||||
};
|
||||
// TODO @thomasw21 Figure out how to get minimum value
|
||||
auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34);
|
||||
attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);
|
||||
}
|
||||
|
||||
auto context_layer = attention_probs.bmm(value_view);
|
||||
|
||||
// `_merge_heads`
|
||||
context_layer = context_layer.view({batch_size, num_heads, q_length, attn_head_size});
|
||||
context_layer = context_layer.permute({0, 2, 1, 3});
|
||||
context_layer = context_layer.reshape({batch_size, q_length, attn_head_size * num_heads});
|
||||
|
||||
return std::make_tuple(context_layer, present, attention_probs);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"forward",
|
||||
&forward,
|
||||
"GPT-Neox attention mechanism forward (CUDA)"
|
||||
);
|
||||
}
|
|
@ -0,0 +1,250 @@
|
|||
#include <ATen/Dispatch.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
#include <optional>
|
||||
|
||||
/**
|
||||
* Friendly reminder of how multithreading works in CUDA: https://developer.nvidia.com/blog/even-easier-introduction-cuda
|
||||
* Check example at https://github.com/thomasw21/LinearTransformers/blob/main/model/attention/fast_weight/fast_weight_cuda.cu
|
||||
**/
|
||||
|
||||
// Available in pytorch main
|
||||
//#define DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
// at::AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
|
||||
/*
|
||||
* Forward passes
|
||||
*/
|
||||
|
||||
/**
|
||||
* cast to fp32 if in fp16 + mask + softmax computation in fp32 + cast back to original dtype
|
||||
**/
|
||||
template<typename attention_scores_scalar, int64_t min_kv_length_shard_size_per_thread>
|
||||
__global__ void forward_masked_softmax_kernel(
|
||||
const torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> attention_scores, // [B, KV]
|
||||
const torch::PackedTensorAccessor32<bool, 2, torch::RestrictPtrTraits> mask, // [B, KV]
|
||||
torch::PackedTensorAccessor32<attention_scores_scalar, 2, torch::RestrictPtrTraits> result, // [B, KV]
|
||||
const int64_t effective_kv_length,
|
||||
const dim3 blockDim,
|
||||
const int64_t rows_per_block,
|
||||
const int64_t kv_length,
|
||||
const int64_t batch_size
|
||||
) {
|
||||
const auto row_id = threadIdx.x / effective_kv_length;
|
||||
const auto effective_kv_length_id = threadIdx.x % effective_kv_length;
|
||||
const auto kv_length_start = effective_kv_length_id * min_kv_length_shard_size_per_thread;
|
||||
auto kv_length_end_ = (effective_kv_length_id + 1) * min_kv_length_shard_size_per_thread;
|
||||
kv_length_end_ = (kv_length_end_ > kv_length) ? kv_length : kv_length_end_;
|
||||
const auto kv_length_end = kv_length_end_;
|
||||
|
||||
const auto batch_id = blockIdx.x * rows_per_block + row_id;
|
||||
|
||||
// We need 2 float storage for each row, one for max computation, the other for normalizing exponential
|
||||
extern __shared__ float temp_storage[];
|
||||
const auto row_id_mem_offset = row_id * 2;
|
||||
if (effective_kv_length_id == 0) {
|
||||
temp_storage[row_id_mem_offset] = -std::numeric_limits<float>::infinity();
|
||||
temp_storage[row_id_mem_offset + 1] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Compute mask and max
|
||||
if (batch_id < batch_size) {
|
||||
float thread_max = -std::numeric_limits<float>::infinity();
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
if (mask[batch_id][kv_length_id] == 0) {
|
||||
const float candidate = attention_scores[batch_id][kv_length_id];
|
||||
thread_max = (thread_max < candidate) ? candidate : thread_max;
|
||||
}
|
||||
}
|
||||
if (thread_max != -std::numeric_limits<float>::infinity()) {
|
||||
// TODO @thomasw21 with more memory we can probably compute a much faster `max-reduce` in parallel O(ln(n)) operations in each memory slot
|
||||
gpuAtomicMax(&temp_storage[row_id_mem_offset], thread_max);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Compute exp(elt - max) masked
|
||||
float exponential[min_kv_length_shard_size_per_thread];
|
||||
if (batch_id < batch_size) {
|
||||
float thread_add = 0;
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
if (mask[batch_id][kv_length_id] == 0) {
|
||||
exponential[kv_length_id - kv_length_start] = std::exp(static_cast<float>(attention_scores[batch_id][kv_length_id]) - temp_storage[row_id_mem_offset]);
|
||||
thread_add = thread_add + exponential[kv_length_id - kv_length_start];
|
||||
} else {
|
||||
exponential[kv_length_id - kv_length_start] = 0.;
|
||||
}
|
||||
}
|
||||
if (thread_add > 0) {
|
||||
// TODO @thomasw21 with more memory we can probably compute a much faster `sum-reduce` in parallel O(ln(n)) operations in each memory slot
|
||||
gpuAtomicAdd(&temp_storage[row_id_mem_offset + 1], thread_add);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Compute softmax
|
||||
if (batch_id < batch_size) {
|
||||
// If sum of all exponential is 0, we set the softmax values to 0
|
||||
if (temp_storage[row_id_mem_offset + 1] == 0.) {
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
result[batch_id][kv_length_id] = 0.;
|
||||
}
|
||||
} else {
|
||||
for (int kv_length_id = kv_length_start; kv_length_id < kv_length_end; ++kv_length_id) {
|
||||
result[batch_id][kv_length_id] = static_cast<attention_scores_scalar>(exponential[kv_length_id - kv_length_start] / temp_storage[row_id_mem_offset + 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
std::tuple<at::Tensor, std::optional<std::vector<at::Tensor>>, at::Tensor> forward(
|
||||
const at::Tensor fused_qkv,
|
||||
const std::optional<std::vector<at::Tensor>> layer_past,
|
||||
const at::Tensor alibi,
|
||||
const at::Tensor attention_mask,
|
||||
const std::optional<at::Tensor> head_mask,
|
||||
const float beta,
|
||||
const float inv_norm_factor,
|
||||
const int num_heads,
|
||||
const bool use_cache
|
||||
) {
|
||||
const auto batch_size = fused_qkv.size(0);
|
||||
const auto q_length = fused_qkv.size(1);
|
||||
const auto three_times_hidden_size = fused_qkv.size(2);
|
||||
const auto head_dim = three_times_hidden_size / (3 * num_heads);
|
||||
const auto batch_size_times_num_heads = batch_size * num_heads;
|
||||
|
||||
// `split_heads`
|
||||
const auto fused_qkv_view = fused_qkv.view({batch_size, q_length, num_heads, 3 * head_dim});
|
||||
const auto tensor_list = fused_qkv_view.split(head_dim, -1);
|
||||
const auto query_layer = tensor_list[0].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});
|
||||
auto key_layer = tensor_list[1].permute({0, 2, 3, 1}).reshape({batch_size_times_num_heads, head_dim, q_length});
|
||||
auto value_layer = tensor_list[2].transpose(1, 2).reshape({batch_size_times_num_heads, q_length, head_dim});
|
||||
|
||||
if (layer_past) {
|
||||
const auto past_key = (*layer_past).at(0);
|
||||
const auto past_value = (*layer_past).at(1);
|
||||
key_layer = at::cat({past_key, key_layer}, 2);
|
||||
value_layer = at::cat({past_value, value_layer}, 1);
|
||||
}
|
||||
|
||||
std::optional<std::vector<at::Tensor>> present;
|
||||
if (use_cache) {
|
||||
present = {key_layer, value_layer};
|
||||
} else {
|
||||
present = {};
|
||||
}
|
||||
|
||||
auto attention_scores = alibi.baddbmm(query_layer, key_layer, beta, inv_norm_factor);
|
||||
|
||||
// Computing `optionally_cast_fp16_to_fp32 + masked_fill + softmax + cast_to_intial_dtype`
|
||||
at::Tensor attention_probs;
|
||||
if (true) {
|
||||
const auto kv_length = key_layer.size(2);
|
||||
|
||||
// TODO @thomasw21: it's easier to think of attention_scores as 2D tensors
|
||||
const auto attention_scores_2d = attention_scores.view({batch_size_times_num_heads * q_length, kv_length});
|
||||
const auto attention_mask_2d = attention_mask.view({batch_size_times_num_heads * q_length, kv_length});
|
||||
|
||||
// Custom kernel
|
||||
attention_probs = at::empty_like(attention_scores_2d);
|
||||
|
||||
// Check that inputs and contiguous + cuda tensors
|
||||
CHECK_INPUT(attention_scores_2d);
|
||||
CHECK_INPUT(attention_mask_2d);
|
||||
|
||||
// TODO @thomas21: change by to this as it's cleaner when pytorch 1.13 comes out
|
||||
// DISPATCH_CASE_FLOATING_TYPES(attention_scores.scalar_type(), "masked_softmax", [&] {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, attention_scores.scalar_type(), "masked_softmax", [&] {
|
||||
/*
|
||||
* Understanding how GPUs work: https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/
|
||||
* A100 specifications: https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf
|
||||
* - SMs: 108
|
||||
* - TPCs: 56 (What's that?)
|
||||
* - Memory size: 40 GB
|
||||
* - L2 Cache size: 40960 KB (shared across all SMs)
|
||||
* - L1/Shared memory size: 192 KB (shared across all threads within a SM)
|
||||
* - Max Threads / SM: 2048
|
||||
* - Max Thread Blocks / SM: 32
|
||||
*/
|
||||
|
||||
/*
|
||||
* We should split [batch_size_times_num_heads_block, q_length] in seperate blocks and [batch_size_times_num_heads_block_size, kv_length] a single block
|
||||
* with multiple threads as we need to `sync_threads` to run exponential sum.
|
||||
* We maximise the usage of threads within a single block
|
||||
*/
|
||||
// TODO @thomasw21 figure out everything warp related:
|
||||
// - why do they have to be power of 2
|
||||
// TODO @thomas21 check why everyone is setting 1024 when officially it's 2048
|
||||
const auto MAX_THREADS_PER_SM = 1024;
|
||||
// TODO @thomasw21 figure out how to have longer sequences, currently the maximum is `max_kv_length = MAX_THREADS_PER_SM * MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD`
|
||||
const auto MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD = 4;
|
||||
// `effective_kv_length = ceil(kv_length / MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD)`
|
||||
const auto effective_kv_length = (kv_length - 1)/ MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD + 1;
|
||||
const auto rows_per_block = MAX_THREADS_PER_SM / effective_kv_length;
|
||||
const auto num_blocks = (batch_size_times_num_heads * q_length - 1) / rows_per_block + 1;
|
||||
|
||||
const dim3 gridDim(num_blocks); // Number of blocks that run
|
||||
const dim3 blockDim(MAX_THREADS_PER_SM); // Number of threads that run per block
|
||||
const int shared_mem_forward = rows_per_block * 2 * sizeof(float);
|
||||
|
||||
// 192 * 2 ** 10
|
||||
// const auto MAX_L1_MEMORY = 196608;
|
||||
// const auto MAX_SMs = 108;
|
||||
// TORCH_CHECK(batch_size_times_num_heads * q_length <= MAX_L1_MEMORY, "Shared memory exceeds 192KB limitation.");
|
||||
// TORCH_CHECK(gridDim.x * gridDim.y * gridDim.z <= MAX_SMs, "A100s only have 108 SMs. Raising as require blocks is bigger.");
|
||||
// TORCH_CHECK(blockDim.x * blockDim.y * blockDim.z <= MAX_THREADS_PER_SM, "A100s only have 2048 threads per block. Raising as require requested threads is higher.");
|
||||
|
||||
forward_masked_softmax_kernel<scalar_t, MIN_KV_LENGTH_SHARD_SIZE_PER_THREAD><<<gridDim, blockDim, shared_mem_forward>>>(
|
||||
attention_scores_2d.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
|
||||
attention_mask_2d.packed_accessor32<bool, 2, torch::RestrictPtrTraits>(),
|
||||
attention_probs.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
|
||||
effective_kv_length,
|
||||
blockDim,
|
||||
rows_per_block,
|
||||
kv_length,
|
||||
batch_size_times_num_heads * q_length
|
||||
);
|
||||
});
|
||||
attention_probs = attention_probs.view({batch_size_times_num_heads, q_length, kv_length});
|
||||
} else {
|
||||
// Pytorch C++ API
|
||||
auto input_dtype = attention_scores.scalar_type();
|
||||
if (input_dtype == at::ScalarType::Float) {
|
||||
attention_scores = attention_scores.to(at::ScalarType::Float);
|
||||
};
|
||||
// TODO @thomasw21 Figure out how to get minimum value
|
||||
auto attn_weights = attention_scores.masked_fill_(attention_mask, -1e34);
|
||||
attention_probs = attn_weights.softmax(-1, at::ScalarType::Float).to(input_dtype);
|
||||
}
|
||||
|
||||
auto context_layer = attention_probs.bmm(value_layer);
|
||||
|
||||
// `_merge_heads`
|
||||
context_layer = context_layer.view({batch_size, num_heads, q_length, head_dim});
|
||||
context_layer = context_layer.permute({0, 2, 1, 3});
|
||||
context_layer = context_layer.reshape({batch_size, q_length, three_times_hidden_size / 3});
|
||||
|
||||
return std::make_tuple(context_layer, present, attention_probs);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"forward",
|
||||
&forward,
|
||||
"Bloom attention mechanism forward (CUDA)"
|
||||
);
|
||||
}
|
|
@ -0,0 +1,19 @@
|
|||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
setup(
|
||||
name="custom_kernels",
|
||||
ext_modules=[
|
||||
CUDAExtension(
|
||||
name="custom_kernels.fused_bloom_attention_cuda",
|
||||
sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
|
||||
extra_compile_args=["-arch=compute_80", "-std=c++17"],
|
||||
),
|
||||
CUDAExtension(
|
||||
name="custom_kernels.fused_attention_cuda",
|
||||
sources=["custom_kernels/fused_attention_cuda.cu"],
|
||||
extra_compile_args=["-arch=compute_80", "-std=c++17"],
|
||||
),
|
||||
],
|
||||
cmdclass={"build_ext": BuildExtension},
|
||||
)
|
|
@ -25,7 +25,8 @@ opentelemetry-instrumentation-grpc = "^0.36b0"
|
|||
hf-transfer = "^0.1.2"
|
||||
sentencepiece = "^0.1.97"
|
||||
tokenizers = "0.13.3"
|
||||
huggingface-hub = "0.14.0"
|
||||
huggingface-hub = "^0.14.1"
|
||||
transformers = "^4.29.2"
|
||||
|
||||
[tool.poetry.extras]
|
||||
accelerate = ["accelerate"]
|
||||
|
|
|
@ -13,8 +13,8 @@ grpcio-reflection==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
|
|||
grpcio-status==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
grpcio==1.55.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
hf-transfer==0.1.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
huggingface-hub==0.14.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
idna==3.4 ; python_version >= "3.9" and python_version < "4.0"
|
||||
huggingface-hub==0.14.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
idna==3.4 ; python_version >= "3.9" and python_version < "4"
|
||||
loguru==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
|
@ -33,6 +33,7 @@ safetensors==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
|
|||
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "4.0"
|
||||
setuptools==67.8.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "4.0"
|
||||
transformers==4.29.2 ; python_version >= "3.9" and python_version < "4.0"
|
||||
tqdm==4.65.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
typer==0.6.1 ; python_version >= "3.9" and python_version < "4.0"
|
||||
typing-extensions==4.6.0 ; python_version >= "3.9" and python_version < "4.0"
|
||||
|
|
|
@ -6,12 +6,17 @@ from transformers import AutoTokenizer
|
|||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOM
|
||||
from text_generation_server.utils import weight_hub_files, download_weights
|
||||
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def default_bloom():
|
||||
return BLOOM("bigscience/bloom-560m")
|
||||
model_id = "bigscience/bloom-560m"
|
||||
revision = "main"
|
||||
filenames = weight_hub_files(model_id, revision, ".safetensors")
|
||||
download_weights(filenames, model_id, revision)
|
||||
return BLOOMSharded(model_id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import torch
|
||||
|
||||
from loguru import logger
|
||||
|
@ -8,17 +9,20 @@ from typing import Optional
|
|||
from text_generation_server.models.model import Model
|
||||
from text_generation_server.models.causal_lm import CausalLM
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
|
||||
from text_generation_server.models.bloom import BLOOMSharded
|
||||
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
||||
from text_generation_server.models.rw import RW
|
||||
from text_generation_server.models.opt import OPT, OPTSharded
|
||||
from text_generation_server.models.galactica import Galactica, GalacticaSharded
|
||||
from text_generation_server.models.opt import OPTSharded
|
||||
from text_generation_server.models.galactica import GalacticaSharded
|
||||
from text_generation_server.models.santacoder import SantaCoder
|
||||
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
||||
from text_generation_server.models.t5 import T5Sharded
|
||||
from text_generation_server.models.gpt_neox import GPTNeoxSharded
|
||||
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
|
||||
):
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
is_sm75 = major == 7 and minor == 5
|
||||
is_sm8x = major == 8 and minor >= 0
|
||||
|
@ -30,14 +34,12 @@ try:
|
|||
f"GPU with CUDA capability {major} {minor} is not supported"
|
||||
)
|
||||
|
||||
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
|
||||
from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded
|
||||
from text_generation_server.models.flash_rw import FlashRWSharded
|
||||
from text_generation_server.models.flash_neox import FlashNeoXSharded
|
||||
from text_generation_server.models.flash_llama import (
|
||||
FlashLlama,
|
||||
FlashLlamaSharded,
|
||||
)
|
||||
from text_generation_server.models.flash_santacoder import (
|
||||
FlashSantacoder,
|
||||
FlashSantacoderSharded,
|
||||
)
|
||||
|
||||
|
@ -52,30 +54,22 @@ except ImportError:
|
|||
|
||||
__all__ = [
|
||||
"Model",
|
||||
"BLOOM",
|
||||
"BLOOMSharded",
|
||||
"CausalLM",
|
||||
"FlashCausalLM",
|
||||
"Galactica",
|
||||
"GalacticaSharded",
|
||||
"GPTNeoxSharded",
|
||||
"Seq2SeqLM",
|
||||
"SantaCoder",
|
||||
"OPT",
|
||||
"OPTSharded",
|
||||
"T5Sharded",
|
||||
"get_model",
|
||||
]
|
||||
|
||||
if FLASH_ATTENTION:
|
||||
__all__.append(FlashNeoX)
|
||||
__all__.append(FlashNeoXSharded)
|
||||
__all__.append(FlashRW)
|
||||
__all__.append(FlashRWSharded)
|
||||
__all__.append(FlashSantacoder)
|
||||
__all__.append(FlashSantacoderSharded)
|
||||
__all__.append(FlashLlama)
|
||||
__all__.append(FlashLlamaSharded)
|
||||
|
||||
FLASH_ATT_ERROR_MESSAGE = (
|
||||
"{} requires Flash Attention CUDA kernels to be installed.\n"
|
||||
|
@ -102,36 +96,24 @@ def get_model(
|
|||
trust_remote_code: bool,
|
||||
) -> Model:
|
||||
if "facebook/galactica" in model_id:
|
||||
if sharded:
|
||||
return GalacticaSharded(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
return Galactica(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
return GalacticaSharded(
|
||||
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
if model_id.startswith("bigcode/"):
|
||||
if sharded:
|
||||
if not FLASH_ATTENTION:
|
||||
raise NotImplementedError(
|
||||
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
|
||||
)
|
||||
if FLASH_ATTENTION:
|
||||
return FlashSantacoderSharded(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(
|
||||
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
|
||||
)
|
||||
else:
|
||||
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
|
||||
return santacoder_cls(
|
||||
return SantaCoder(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -144,20 +126,19 @@ def get_model(
|
|||
model_type = config_dict["model_type"]
|
||||
|
||||
if model_type == "gpt_bigcode":
|
||||
if sharded:
|
||||
if not FLASH_ATTENTION:
|
||||
raise NotImplementedError(
|
||||
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
|
||||
)
|
||||
if FLASH_ATTENTION:
|
||||
return FlashSantacoderSharded(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(
|
||||
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
|
||||
)
|
||||
else:
|
||||
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
|
||||
return santacoder_cls(
|
||||
return SantaCoder(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -165,33 +146,45 @@ def get_model(
|
|||
)
|
||||
|
||||
if model_type == "bloom":
|
||||
if sharded:
|
||||
return BLOOMSharded(
|
||||
return BLOOMSharded(
|
||||
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
elif model_type == "gpt_neox":
|
||||
if FLASH_ATTENTION:
|
||||
return FlashNeoXSharded(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif sharded:
|
||||
return GPTNeoxSharded(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
return BLOOM(
|
||||
return CausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == "gpt_neox":
|
||||
if sharded:
|
||||
neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
|
||||
return neox_cls(
|
||||
elif model_type == "llama":
|
||||
if FLASH_ATTENTION:
|
||||
return FlashLlama(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
|
||||
else:
|
||||
neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
|
||||
return neox_cls(
|
||||
return CausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -217,7 +210,7 @@ def get_model(
|
|||
)
|
||||
else:
|
||||
if FLASH_ATTENTION and not config_dict.get("alibi", False):
|
||||
return FlashRW(
|
||||
return FlashRWSharded(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
|
@ -231,42 +224,12 @@ def get_model(
|
|||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == "llama":
|
||||
if sharded:
|
||||
if FLASH_ATTENTION:
|
||||
return FlashLlamaSharded(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama"))
|
||||
else:
|
||||
llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
|
||||
return llama_cls(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif model_type == "opt":
|
||||
return OPTSharded(
|
||||
model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
if model_type == "opt":
|
||||
if sharded:
|
||||
return OPTSharded(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
return OPT(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == "t5":
|
||||
elif model_type == "t5":
|
||||
if sharded:
|
||||
return T5Sharded(
|
||||
model_id,
|
||||
|
|
|
@ -1,37 +1,26 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from typing import List, Optional, Type
|
||||
from typing import Optional, Type
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors import safe_open
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
AutoConfig,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
from transformers.models.bloom.parallel_layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
from text_generation_server.models.custom_modeling.bloom_modeling import (
|
||||
BloomForCausalLM,
|
||||
)
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Int8Params
|
||||
except Exception as e:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
||||
|
||||
class BloomCausalLMBatch(CausalLMBatch):
|
||||
@classmethod
|
||||
|
@ -42,34 +31,12 @@ class BloomCausalLMBatch(CausalLMBatch):
|
|||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "CausalLMBatch":
|
||||
batch = super(BloomCausalLMBatch, cls).from_pb(
|
||||
pb=pb, tokenizer=tokenizer, dtype=dtype, device=device
|
||||
)
|
||||
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
|
||||
batch.keys_head_dim_last = False
|
||||
return batch
|
||||
|
||||
|
||||
class BLOOM(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
super(BLOOM, self).__init__(
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return BloomCausalLMBatch
|
||||
|
||||
|
||||
class BLOOMSharded(BLOOM):
|
||||
class BLOOMSharded(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -101,25 +68,16 @@ class BLOOMSharded(BLOOM):
|
|||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.pad_token_id = 3
|
||||
config.quantize = quantize
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
|
||||
with init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
)
|
||||
|
||||
model = BloomForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
|
@ -131,132 +89,9 @@ class BLOOMSharded(BLOOM):
|
|||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: Optional[str],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
parameters = dict(model.named_parameters())
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
if name.startswith("transformer.") or name.startswith("lm_head."):
|
||||
full_name = name
|
||||
else:
|
||||
full_name = f"transformer.{name}"
|
||||
|
||||
module_name, param_name = full_name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
current_tensor = parameters[full_name]
|
||||
|
||||
slice_ = f.get_slice(name)
|
||||
|
||||
if isinstance(module, TensorParallelColumnLinear):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif isinstance(module, TensorParallelRowLinear):
|
||||
if param_name == "weight":
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[:, start:stop]
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
# XXX: Hack for Rowlinear to add the bias only once.
|
||||
if rank != 0:
|
||||
tensor = torch.zeros_like(tensor)
|
||||
elif (
|
||||
isinstance(module, TensorParallelEmbedding)
|
||||
or name == "lm_head.weight"
|
||||
):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
|
||||
if current_tensor.shape != tensor.shape:
|
||||
raise ValueError(
|
||||
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
tensor = tensor.contiguous().to(dtype)
|
||||
|
||||
if quantize == "bitsandbytes":
|
||||
if not HAS_BITS_AND_BYTES:
|
||||
raise ImportError(
|
||||
"bitsandbytes is not available on your machine either because it is not installed "
|
||||
"or you don't have a GPU.\n"
|
||||
"You can install it with `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
if (
|
||||
type(module)
|
||||
in [TensorParallelRowLinear, TensorParallelColumnLinear]
|
||||
and param_name == "weight"
|
||||
):
|
||||
tensor = Int8Params(
|
||||
tensor,
|
||||
has_fp16_weights=False,
|
||||
requires_grad=False,
|
||||
).to(device)
|
||||
state = bnb.MatmulLtState()
|
||||
state.threshold = 6.0
|
||||
state.has_fp16_weights = False
|
||||
state.memory_efficient_backward = False
|
||||
state.use_pool = True
|
||||
state.CB = tensor.CB
|
||||
state.SCB = tensor.SCB
|
||||
tensor.CB = None
|
||||
tensor.SCB = None
|
||||
|
||||
def replace_linear(state):
|
||||
def linear(input, weight, bias):
|
||||
out = bnb.matmul(
|
||||
input,
|
||||
weight,
|
||||
state=state,
|
||||
threshold=state.threshold,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
if state.CB is not None:
|
||||
# we converted 8-bit row major to turing/ampere format
|
||||
# in the first inference pass
|
||||
# we no longer need the row-major weight
|
||||
del state.CB
|
||||
weight.data = state.CxB
|
||||
|
||||
return out
|
||||
|
||||
return linear
|
||||
|
||||
module.linear = replace_linear(state)
|
||||
else:
|
||||
tensor = tensor.to(device)
|
||||
elif quantize == "gptq":
|
||||
raise NotImplementedError("`gptq` is not implemented for now")
|
||||
elif quantize is None:
|
||||
tensor = tensor.to(device)
|
||||
else:
|
||||
raise ValueError(f"Unexpected quantize `{quantize}`")
|
||||
|
||||
module._parameters[param_name] = tensor
|
||||
if name == "word_embeddings.weight":
|
||||
model.lm_head._parameters["weight"] = tensor
|
||||
@property
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return BloomCausalLMBatch
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
|
@ -269,9 +104,5 @@ class BLOOMSharded(BLOOM):
|
|||
use_cache=True,
|
||||
)
|
||||
|
||||
# Logits are sharded, so we need to gather them
|
||||
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
|
||||
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
|
||||
logits = torch.cat(logits, dim=2)
|
||||
|
||||
logits = outputs.logits
|
||||
return logits, outputs.past_key_values
|
||||
|
|
|
@ -0,0 +1,912 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch BLOOM model."""
|
||||
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
from torch.nn import functional as F
|
||||
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
)
|
||||
from transformers import BloomConfig, PreTrainedModel
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelHead,
|
||||
)
|
||||
|
||||
CUSTOM_KERNELS_ENABLED = False
|
||||
if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
|
||||
try:
|
||||
from custom_kernels import fused_bloom_attention_cuda
|
||||
|
||||
CUSTOM_KERNELS_ENABLED = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
|
||||
_CONFIG_FOR_DOC = "BloomConfig"
|
||||
|
||||
BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"bigscience/bigscience-small-testing",
|
||||
"bigscience/bloom-560m",
|
||||
"bigscience/bloom-1b1",
|
||||
"bigscience/bloom-1b7",
|
||||
"bigscience/bloom-3b",
|
||||
"bigscience/bloom-7b1",
|
||||
"bigscience/bloom",
|
||||
]
|
||||
|
||||
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
|
||||
) -> torch.BoolTensor:
|
||||
"""
|
||||
Make causal mask used for self-attention.
|
||||
"""
|
||||
batch_size, target_length = input_ids_shape
|
||||
mask = torch.ones(
|
||||
(target_length, target_length + past_key_values_length),
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
mask = mask.triu(1 + past_key_values_length)
|
||||
|
||||
expanded_mask = mask.unsqueeze(0).expand(
|
||||
batch_size, target_length, target_length + past_key_values_length
|
||||
)
|
||||
return expanded_mask
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
|
||||
"""
|
||||
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
|
||||
"""
|
||||
batch_size, src_length = mask.shape
|
||||
tgt_length = tgt_length if tgt_length is not None else src_length
|
||||
|
||||
expanded_mask = ~(mask[:, None, :].to(torch.bool))
|
||||
return expanded_mask.expand(batch_size, tgt_length, src_length)
|
||||
|
||||
|
||||
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int) -> torch.Tensor:
|
||||
"""
|
||||
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
|
||||
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
|
||||
`softmax(l+a) = softmax(l)`. Based on
|
||||
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
|
||||
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
|
||||
|
||||
Args:
|
||||
Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
|
||||
attention_mask (`torch.Tensor`):
|
||||
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
|
||||
num_heads (`int`, *required*):
|
||||
number of heads
|
||||
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
|
||||
dtype of the output tensor
|
||||
"""
|
||||
batch_size, seq_length = attention_mask.shape
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
|
||||
base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
|
||||
device=attention_mask.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
powers = torch.arange(
|
||||
1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32
|
||||
)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
|
||||
device=attention_mask.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(
|
||||
1,
|
||||
1 + 2 * num_remaining_heads,
|
||||
2,
|
||||
device=attention_mask.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
|
||||
# Note: alibi will added to the attention bias that will be applied to the query, key product of attention
|
||||
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
|
||||
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
|
||||
# => the query_length dimension will then be broadcasted correctly
|
||||
# This is more or less identical to T5's relative position bias:
|
||||
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
|
||||
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
|
||||
alibi = slopes[..., None] * arange_tensor
|
||||
return alibi
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
def dropout_add(
|
||||
x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Dropout add function
|
||||
|
||||
Args:
|
||||
x (`torch.tensor`, *required*):
|
||||
input tensor
|
||||
residual (`torch.tensor`, *required*):
|
||||
esidual tensor
|
||||
prob (`float`, *required*):
|
||||
dropout probability
|
||||
training (`bool`, *required*):
|
||||
training mode
|
||||
"""
|
||||
out = F.dropout(x, p=prob, training=training)
|
||||
out = residual + out
|
||||
return out
|
||||
|
||||
|
||||
# @torch.jit.script # this is shit for unknow reasons.
|
||||
def _split_heads(
|
||||
fused_qkv: torch.Tensor, num_heads: int, head_dim: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
|
||||
storage as `fused_qkv`
|
||||
|
||||
Args:
|
||||
fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
|
||||
|
||||
Returns:
|
||||
query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
|
||||
value: [batch_size, seq_length, num_heads, head_dim]
|
||||
"""
|
||||
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
|
||||
fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
|
||||
query_layer, key_layer, value_layer = fused_qkv.split(head_dim, dim=-1)
|
||||
|
||||
query_layer = query_layer.transpose(1, 2).reshape(
|
||||
batch_size * num_heads, seq_length, head_dim
|
||||
)
|
||||
key_layer = key_layer.permute(0, 2, 3, 1).reshape(
|
||||
batch_size * num_heads, head_dim, seq_length
|
||||
)
|
||||
value_layer = value_layer.transpose(1, 2).reshape(
|
||||
batch_size * num_heads, seq_length, head_dim
|
||||
)
|
||||
|
||||
return query_layer, key_layer, value_layer
|
||||
|
||||
|
||||
# @torch.jit.script
|
||||
def _merge_heads(x: torch.Tensor, num_heads: int, head_dim: int) -> torch.Tensor:
|
||||
"""
|
||||
Merge heads together over the last dimenstion
|
||||
|
||||
Args:
|
||||
x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
|
||||
|
||||
Returns:
|
||||
torch.tensor: [batch_size, seq_length, num_heads * head_dim]
|
||||
"""
|
||||
# What we want to achieve is:
|
||||
# batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
|
||||
batch_size_and_num_heads, seq_length, _ = x.shape
|
||||
batch_size = batch_size_and_num_heads // num_heads
|
||||
|
||||
# First view to decompose the batch size
|
||||
# batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
|
||||
x = x.view(batch_size, num_heads, seq_length, head_dim)
|
||||
|
||||
# batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
|
||||
x = x.permute(0, 2, 1, 3)
|
||||
|
||||
# batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
|
||||
return x.reshape(batch_size, seq_length, num_heads * head_dim)
|
||||
|
||||
|
||||
class BloomAttention(nn.Module):
|
||||
def __init__(self, prefix, config: BloomConfig, weights):
|
||||
super().__init__()
|
||||
|
||||
self.pretraining_tp = config.pretraining_tp
|
||||
self.slow_but_exact = config.slow_but_exact
|
||||
|
||||
self.process_group = weights.process_group
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.n_head
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.split_size = self.hidden_size
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
|
||||
if self.head_dim * self.num_heads != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
|
||||
# Layer-wise attention scaling
|
||||
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
||||
self.beta = 1.0
|
||||
|
||||
process_group = weights.process_group
|
||||
self.num_heads = self.num_heads // process_group.size()
|
||||
self.query_key_value = TensorParallelColumnLinear.load(
|
||||
config=config,
|
||||
prefix=f"{prefix}.query_key_value",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
self.dense = TensorParallelRowLinear.load(
|
||||
config=config, prefix=f"{prefix}.dense", weights=weights, bias=True
|
||||
)
|
||||
self.attention_dropout = nn.Dropout(config.attention_dropout)
|
||||
|
||||
@staticmethod
|
||||
def compute_attention(
|
||||
fused_qkv: torch.Tensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
alibi: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
head_mask: Optional[torch.Tensor],
|
||||
beta: float,
|
||||
inv_norm_factor: float,
|
||||
num_heads: int,
|
||||
use_cache: bool,
|
||||
):
|
||||
batch_size, q_length, three_times_hidden_size = fused_qkv.shape
|
||||
head_dim = three_times_hidden_size // (3 * num_heads)
|
||||
batch_size * num_heads
|
||||
|
||||
### TODO @thomasw21: this takes quite a bit of time, how do I accelerate that?
|
||||
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
||||
(query_layer, key_layer, value_layer) = _split_heads(
|
||||
fused_qkv, num_heads=num_heads, head_dim=head_dim
|
||||
)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
# concatenate along seq_length dimension:
|
||||
# - key: [batch_size * self.num_heads, head_dim, kv_length]
|
||||
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
past_key = past_key.view(-1, *past_key.shape[-2:])
|
||||
key_layer = torch.cat((past_key, key_layer), dim=2)
|
||||
past_value = past_value.view(-1, *past_value.shape[-2:])
|
||||
value_layer = torch.cat((past_value, value_layer), dim=1)
|
||||
|
||||
_, _, kv_length = key_layer.shape
|
||||
|
||||
if use_cache is True:
|
||||
present = (key_layer, value_layer)
|
||||
else:
|
||||
present = None
|
||||
###
|
||||
|
||||
# [batch_size * num_heads, q_length, kv_length]
|
||||
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
|
||||
attention_scores = alibi.baddbmm(
|
||||
batch1=query_layer,
|
||||
batch2=key_layer,
|
||||
beta=beta,
|
||||
alpha=inv_norm_factor,
|
||||
)
|
||||
|
||||
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
|
||||
input_dtype = attention_scores.dtype
|
||||
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
|
||||
if input_dtype == torch.float16:
|
||||
attention_scores = attention_scores.to(torch.float)
|
||||
# torch.finfo not supported by torch.jit, we temporarily remplace with `-1e34`
|
||||
attn_weights = attention_scores.masked_fill_(
|
||||
attention_mask, torch.finfo(attention_scores.dtype).min
|
||||
)
|
||||
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
|
||||
input_dtype
|
||||
)
|
||||
|
||||
# # [batch_size, num_heads, q_length, kv_length]
|
||||
# attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
# matmul: [batch_size * num_heads, q_length, head_dim]
|
||||
context_layer = torch.bmm(attention_probs, value_layer, out=query_layer)
|
||||
|
||||
# change view [batch_size, num_heads, q_length, head_dim]
|
||||
context_layer = _merge_heads(
|
||||
context_layer, num_heads=num_heads, head_dim=head_dim
|
||||
)
|
||||
|
||||
return context_layer, present, attention_probs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
alibi: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
fused_qkv = self.query_key_value(
|
||||
hidden_states
|
||||
) # [batch_size, seq_length, 3 x hidden_size]
|
||||
batch_size, q_length, _ = fused_qkv.shape
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
layer_past = (
|
||||
past_key.view(-1, *past_key.shape[-2:]),
|
||||
past_value.view(-1, *past_value.shape[-2:]),
|
||||
)
|
||||
|
||||
if CUSTOM_KERNELS_ENABLED:
|
||||
assert self.training is False, "Only foward pass was implemented"
|
||||
assert (
|
||||
attention_mask.shape[-1] < 4096
|
||||
), "Custom kernel support only up to 4096 tokens"
|
||||
(
|
||||
context_layer,
|
||||
present,
|
||||
attention_probs,
|
||||
) = fused_bloom_attention_cuda.forward(
|
||||
fused_qkv,
|
||||
layer_past,
|
||||
alibi,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
self.beta,
|
||||
self.inv_norm_factor,
|
||||
self.num_heads,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
context_layer, present, attention_probs = self.compute_attention(
|
||||
fused_qkv=fused_qkv,
|
||||
layer_past=layer_past,
|
||||
alibi=alibi,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
beta=self.beta,
|
||||
inv_norm_factor=self.inv_norm_factor,
|
||||
num_heads=self.num_heads,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
|
||||
if self.pretraining_tp > 1 and self.slow_but_exact:
|
||||
slices = self.hidden_size / self.pretraining_tp
|
||||
output_tensor = torch.zeros_like(context_layer)
|
||||
for i in range(self.pretraining_tp):
|
||||
output_tensor = output_tensor + F.linear(
|
||||
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
|
||||
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
|
||||
)
|
||||
else:
|
||||
output_tensor = self.dense(context_layer)
|
||||
|
||||
# output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
|
||||
output_tensor += residual
|
||||
|
||||
outputs = (output_tensor, present)
|
||||
if output_attentions:
|
||||
outputs += (attention_probs,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class BloomMLP(nn.Module):
|
||||
def __init__(self, prefix, config: BloomConfig, weights):
|
||||
super().__init__()
|
||||
|
||||
self.pretraining_tp = config.pretraining_tp
|
||||
self.slow_but_exact = config.slow_but_exact
|
||||
self.dense_h_to_4h = TensorParallelColumnLinear.load(
|
||||
config=config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
|
||||
)
|
||||
self.dense_4h_to_h = TensorParallelRowLinear.load(
|
||||
config=config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
|
||||
)
|
||||
self.gelu_impl = torch.nn.GELU(approximate="tanh")
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, residual: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
|
||||
|
||||
if self.pretraining_tp > 1 and self.slow_but_exact:
|
||||
intermediate_output = torch.zeros_like(residual)
|
||||
slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
|
||||
for i in range(self.pretraining_tp):
|
||||
intermediate_output = intermediate_output + F.linear(
|
||||
hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
|
||||
self.dense_4h_to_h.weight[
|
||||
:, int(i * slices) : int((i + 1) * slices)
|
||||
],
|
||||
)
|
||||
else:
|
||||
intermediate_output = self.dense_4h_to_h(hidden_states)
|
||||
|
||||
# output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
|
||||
intermediate_output += residual
|
||||
|
||||
return intermediate_output
|
||||
|
||||
|
||||
class BloomBlock(nn.Module):
|
||||
def __init__(self, layer_id: int, config: BloomConfig, weights):
|
||||
super().__init__()
|
||||
|
||||
prefix = f"h.{layer_id}"
|
||||
self.input_layernorm = LayerNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
self.num_heads = config.n_head
|
||||
self.self_attention = BloomAttention(
|
||||
prefix=f"{prefix}.self_attention", config=config, weights=weights
|
||||
)
|
||||
self.post_attention_layernorm = LayerNorm.load(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
|
||||
self.mlp = BloomMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
config.apply_residual_connection_post_layernorm
|
||||
)
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
# hidden_states: [batch_size, seq_length, hidden_size]
|
||||
|
||||
# Layer norm at the beginning of the transformer layer.
|
||||
layernorm_output = self.input_layernorm(hidden_states)
|
||||
|
||||
# Layer norm post the self attention.
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = hidden_states
|
||||
|
||||
# Self attention.
|
||||
attn_outputs = self.self_attention(
|
||||
layernorm_output,
|
||||
residual,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
alibi=alibi,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
attention_output = attn_outputs[0]
|
||||
|
||||
outputs = attn_outputs[1:]
|
||||
|
||||
layernorm_output = self.post_attention_layernorm(attention_output)
|
||||
|
||||
# Get residual
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = attention_output
|
||||
|
||||
# MLP.
|
||||
output = self.mlp(layernorm_output, residual)
|
||||
|
||||
if use_cache:
|
||||
outputs = (output,) + outputs
|
||||
else:
|
||||
outputs = (output,) + outputs[1:]
|
||||
|
||||
return outputs # hidden_states, present, attentions
|
||||
|
||||
|
||||
class BloomPreTrainedModel(PreTrainedModel):
|
||||
config_class = BloomConfig
|
||||
base_model_prefix = "transformer"
|
||||
_no_split_modules = ["BloomBlock"]
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_standard_cache(
|
||||
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
|
||||
num_heads, ...]))
|
||||
"""
|
||||
batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
|
||||
num_heads = batch_size_times_num_heads // batch_size
|
||||
# key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
|
||||
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
|
||||
return tuple(
|
||||
(
|
||||
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
|
||||
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
|
||||
)
|
||||
for layer_past in past_key_value
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_bloom_cache(
|
||||
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
|
||||
"""
|
||||
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
|
||||
batch_size_times_num_heads = batch_size * num_heads
|
||||
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
|
||||
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
|
||||
return tuple(
|
||||
(
|
||||
layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
|
||||
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
|
||||
)
|
||||
for layer_past in past_key_value
|
||||
)
|
||||
|
||||
|
||||
class BloomModel(BloomPreTrainedModel):
|
||||
def __init__(self, config: BloomConfig, weights):
|
||||
super().__init__(config)
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.n_head
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
|
||||
self.word_embeddings = TensorParallelEmbedding(
|
||||
prefix="word_embeddings", weights=weights
|
||||
)
|
||||
|
||||
self.word_embeddings_layernorm = LayerNorm.load(
|
||||
prefix="word_embeddings_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
|
||||
# Transformer blocks
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
BloomBlock(layer_id=layer_id, config=config, weights=weights)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# Final Layer Norm
|
||||
self.ln_f = LayerNorm.load(
|
||||
prefix="ln_f", weights=weights, eps=config.layer_norm_epsilon
|
||||
)
|
||||
|
||||
def _prepare_attn_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_shape: Tuple[int, int],
|
||||
past_key_values_length: int,
|
||||
) -> torch.BoolTensor:
|
||||
# create causal mask
|
||||
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
|
||||
combined_attention_mask = None
|
||||
device = attention_mask.device
|
||||
_, src_length = input_shape
|
||||
|
||||
if src_length > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape,
|
||||
device=device,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
|
||||
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask
|
||||
if combined_attention_mask is None
|
||||
else expanded_attn_mask | combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
def set_input_embeddings(self, new_embeddings: torch.Tensor):
|
||||
self.word_embeddings = new_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**deprecated_arguments,
|
||||
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
if deprecated_arguments.pop("position_ids", False) is not False:
|
||||
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
|
||||
warnings.warn(
|
||||
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
|
||||
" passing `position_ids`.",
|
||||
FutureWarning,
|
||||
)
|
||||
if len(deprecated_arguments) > 0:
|
||||
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time"
|
||||
)
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if past_key_values is None:
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape batch_size x num_heads x N x N
|
||||
# head_mask has shape n_layer x batch x num_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
|
||||
# Compute alibi tensor: check build_alibi_tensor documentation
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
if past_key_values[0] is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[-1]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), device=hidden_states.device
|
||||
)
|
||||
else:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
|
||||
alibi = build_alibi_tensor(attention_mask, self.num_heads)
|
||||
|
||||
causal_mask = self._prepare_attn_mask(
|
||||
attention_mask,
|
||||
input_shape=(batch_size, seq_length),
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
if hasattr(self, "tp_rank"):
|
||||
assert self.num_heads % self.tp_world_size == 0
|
||||
block_size = self.num_heads // self.tp_world_size
|
||||
alibi = alibi[
|
||||
:, self.tp_rank * block_size : (self.tp_rank + 1) * block_size
|
||||
]
|
||||
alibi = alibi.reshape(batch_size * block_size, 1, seq_length_with_past)
|
||||
causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
|
||||
else:
|
||||
alibi = alibi.reshape(batch_size * self.num_heads, 1, seq_length_with_past)
|
||||
causal_mask = torch.repeat_interleave(causal_mask, self.num_heads, dim=0)
|
||||
|
||||
alibi = alibi.to(hidden_states.dtype)
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=causal_mask,
|
||||
head_mask=head_mask[i],
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
alibi=alibi,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (
|
||||
outputs[2 if use_cache else 1],
|
||||
)
|
||||
|
||||
# Add last hidden state
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
presents,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
|
||||
class BloomForCausalLM(BloomPreTrainedModel):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__(config)
|
||||
self.transformer = BloomModel(config, weights)
|
||||
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config,
|
||||
prefix="word_embeddings",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
# only last token for input_ids if past is not None
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
# the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
|
||||
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
|
||||
past_key_values = self._convert_to_bloom_cache(past_key_values)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**deprecated_arguments,
|
||||
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
"""
|
||||
if deprecated_arguments.pop("position_ids", False) is not False:
|
||||
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
|
||||
warnings.warn(
|
||||
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
|
||||
" passing `position_ids`.",
|
||||
FutureWarning,
|
||||
)
|
||||
if len(deprecated_arguments) > 0:
|
||||
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
loss = None
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=loss,
|
||||
logits=lm_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
|
@ -30,21 +30,23 @@ import flash_attn_cuda
|
|||
import dropout_layer_norm
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
FastLinear,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
PositionRotaryEmbedding,
|
||||
TensorParallelHead,
|
||||
)
|
||||
|
||||
|
||||
class LlamaRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
def __init__(self, prefix, weights, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
self.weight = nn.Parameter(weight)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states, residual=None):
|
||||
|
@ -91,35 +93,35 @@ class LlamaRMSNorm(nn.Module):
|
|||
class FlashLlamaAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
hidden_size,
|
||||
process_group=None,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = hidden_size // num_heads
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.load(
|
||||
prefix=f"{prefix}.rotary_emb", weights=weights
|
||||
)
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
if process_group is None:
|
||||
self.query_key_value = FastLinear(hidden_size, 3 * hidden_size, bias=False)
|
||||
self.o_proj = FastLinear(hidden_size, hidden_size, bias=False)
|
||||
else:
|
||||
self.num_heads = self.num_heads // process_group.size()
|
||||
self.query_key_value = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
bias=False,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.o_proj = TensorParallelRowLinear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.query_key_value = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -195,8 +197,9 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
if "gelu" not in act
|
||||
|
@ -207,32 +210,23 @@ class LlamaMLP(nn.Module):
|
|||
else "none",
|
||||
)
|
||||
)
|
||||
|
||||
if process_group is None:
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = FastLinear(
|
||||
hidden_size, 2 * intermediate_size, bias=False
|
||||
)
|
||||
self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False)
|
||||
self.intermediate_size = intermediate_size
|
||||
else:
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
process_group=process_group,
|
||||
reduce=True,
|
||||
)
|
||||
self.intermediate_size = self.down_proj.in_features
|
||||
|
||||
self.process_group = process_group
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=False,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
|
@ -241,22 +235,22 @@ class LlamaMLP(nn.Module):
|
|||
|
||||
|
||||
class FlashLlamaLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
act,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
rms_norm_eps,
|
||||
process_group=None,
|
||||
):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = FlashLlamaAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
self.self_attn = FlashLlamaAttention(num_heads, hidden_size, process_group)
|
||||
self.mlp = LlamaMLP(act, hidden_size, intermediate_size, process_group)
|
||||
|
||||
self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.input_layernorm = LlamaRMSNorm(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = LlamaRMSNorm(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -295,54 +289,35 @@ class FlashLlamaLayer(nn.Module):
|
|||
|
||||
|
||||
class FlashLlamaModel(torch.nn.Module):
|
||||
def __init__(self, config, process_group=None):
|
||||
super(FlashLlamaModel, self).__init__()
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.tp_embeddings = False
|
||||
if process_group is not None:
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
if config.vocab_size % self.tp_world_size == 0:
|
||||
self.tp_embeddings = True
|
||||
|
||||
if self.tp_embeddings:
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
config.vocab_size, config.hidden_size, process_group=process_group
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashLlamaLayer(
|
||||
config.num_attention_heads,
|
||||
config.hidden_act,
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
config.rms_norm_eps,
|
||||
process_group,
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.norm = LlamaRMSNorm(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
|
||||
def post_load_weights(self, quantize: Optional[str] = None):
|
||||
if isinstance(self.embed_tokens, TensorParallelEmbedding):
|
||||
self.embed_tokens.add_null_idx()
|
||||
for layer in self.layers:
|
||||
layer: FlashLlamaLayer
|
||||
layer.self_attn.query_key_value.prepare_weights(quantize)
|
||||
layer.self_attn.o_proj.prepare_weights(quantize)
|
||||
layer.mlp.gate_up_proj.prepare_weights(quantize)
|
||||
layer.mlp.down_proj.prepare_weights(quantize)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
|
@ -410,29 +385,15 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
|
||||
|
||||
class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, process_group=None):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.process_group = process_group
|
||||
if self.process_group is not None:
|
||||
self.world_size = self.process_group.size()
|
||||
else:
|
||||
self.world_size = 1
|
||||
|
||||
self.model = FlashLlamaModel(config, process_group)
|
||||
|
||||
if self.model.tp_embeddings:
|
||||
self.lm_head = FastLinear(
|
||||
config.hidden_size,
|
||||
config.vocab_size // process_group.size(),
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def post_load_weights(self, quantize: Optional[str] = None):
|
||||
self.model.post_load_weights(quantize)
|
||||
self.lm_head.prepare_weights()
|
||||
self.model = FlashLlamaModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -457,12 +418,4 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
if self.model.tp_embeddings:
|
||||
# Logits are sharded, so we need to gather them
|
||||
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
|
||||
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
|
||||
world_logits = torch.cat(world_logits, dim=1)
|
||||
|
||||
return world_logits, present
|
||||
return logits, present
|
||||
|
|
|
@ -31,61 +31,81 @@ from typing import Optional
|
|||
import flash_attn_cuda
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
FastLinear,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelHead,
|
||||
FastLayerNorm,
|
||||
PositionRotaryEmbedding,
|
||||
get_linear,
|
||||
)
|
||||
|
||||
|
||||
def load_row(config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
if config.use_parallel_residual:
|
||||
return linear
|
||||
else:
|
||||
return TensorParallelRowLinear(linear, process_group=weights.process_group)
|
||||
|
||||
|
||||
def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||
|
||||
weight = (
|
||||
weight.view(
|
||||
num_heads,
|
||||
3,
|
||||
head_size,
|
||||
hidden_size,
|
||||
)
|
||||
.permute(1, 0, 2, 3)
|
||||
.reshape(-1, hidden_size)
|
||||
)
|
||||
bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
|
||||
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
if config.use_parallel_residual:
|
||||
return linear
|
||||
else:
|
||||
return TensorParallelColumnLinear(linear)
|
||||
|
||||
|
||||
class FlashNeoxAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
hidden_size,
|
||||
rotary_pct,
|
||||
rotary_emb_base,
|
||||
process_group=None,
|
||||
reduce=True,
|
||||
):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
num_heads = config.num_attention_heads
|
||||
hidden_size = config.hidden_size
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = hidden_size // num_heads
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.load(
|
||||
prefix=f"{prefix}.rotary_emb", weights=weights
|
||||
)
|
||||
|
||||
rotary_ndims = int(self.head_size * rotary_pct)
|
||||
self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
if process_group is None:
|
||||
self.query_key_value = FastLinear(hidden_size, 3 * hidden_size)
|
||||
self.dense = FastLinear(hidden_size, hidden_size)
|
||||
else:
|
||||
self.num_heads = self.num_heads // process_group.size()
|
||||
self.query_key_value = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.dense = TensorParallelRowLinear(
|
||||
hidden_size, hidden_size, process_group=process_group, reduce=reduce
|
||||
)
|
||||
|
||||
def shuffle_qkv_dims(self):
|
||||
"""Swap dims to avoid an additional permute"""
|
||||
self.query_key_value.weight = torch.nn.Parameter(
|
||||
self.query_key_value.weight.view(
|
||||
self.num_heads, 3, self.head_size, self.hidden_size
|
||||
)
|
||||
.permute(1, 0, 2, 3)
|
||||
.reshape(-1, self.hidden_size)
|
||||
self.query_key_value = load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.query_key_value",
|
||||
weights=weights,
|
||||
num_heads=self.num_heads,
|
||||
head_size=self.head_size,
|
||||
hidden_size=self.hidden_size,
|
||||
)
|
||||
self.query_key_value.bias = torch.nn.Parameter(
|
||||
self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
|
||||
.permute(1, 0, 2)
|
||||
.reshape(-1)
|
||||
self.dense = load_row(
|
||||
config, prefix=f"{prefix}.dense", weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
@ -162,10 +182,9 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
|
||||
|
||||
class FlashMLP(nn.Module):
|
||||
def __init__(
|
||||
self, act, hidden_size, intermediate_size, process_group=None, reduce=True
|
||||
):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
if "gelu" not in act
|
||||
|
@ -177,22 +196,12 @@ class FlashMLP(nn.Module):
|
|||
)
|
||||
)
|
||||
|
||||
if process_group is None:
|
||||
self.dense_h_to_4h = FastLinear(hidden_size, intermediate_size)
|
||||
self.dense_4h_to_h = FastLinear(intermediate_size, hidden_size)
|
||||
else:
|
||||
self.dense_h_to_4h = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.dense_4h_to_h = TensorParallelRowLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
process_group=process_group,
|
||||
reduce=reduce,
|
||||
)
|
||||
self.process_group = process_group
|
||||
self.dense_h_to_4h = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
|
||||
)
|
||||
self.dense_4h_to_h = load_row(
|
||||
config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense_h_to_4h(hidden_states)
|
||||
|
@ -202,38 +211,28 @@ class FlashMLP(nn.Module):
|
|||
|
||||
|
||||
class FlashNeoXLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
act,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
rotary_pct,
|
||||
rotary_emb_base,
|
||||
layer_norm_eps,
|
||||
use_parallel_residual,
|
||||
process_group=None,
|
||||
):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
self.use_parallel_residual = use_parallel_residual
|
||||
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
|
||||
layer_norm_eps = config.layer_norm_eps
|
||||
|
||||
prefix = f"gpt_neox.layers.{layer_id}"
|
||||
|
||||
self.use_parallel_residual = config.use_parallel_residual
|
||||
self.input_layernorm = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=layer_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=layer_norm_eps,
|
||||
)
|
||||
self.attention = FlashNeoxAttention(
|
||||
num_heads,
|
||||
hidden_size,
|
||||
rotary_pct,
|
||||
rotary_emb_base,
|
||||
process_group,
|
||||
reduce=not use_parallel_residual,
|
||||
config, prefix=f"{prefix}.attention", weights=weights
|
||||
)
|
||||
self.mlp = FlashMLP(
|
||||
act,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
process_group,
|
||||
reduce=not use_parallel_residual,
|
||||
)
|
||||
self.process_group = process_group
|
||||
|
||||
self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
|
||||
self.process_group = weights.process_group
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -266,9 +265,7 @@ class FlashNeoXLayer(nn.Module):
|
|||
mlp_output = self.mlp(ln2_hidden_states)
|
||||
intermediate = mlp_output + attn_output
|
||||
|
||||
# Only reduce once and after the addition instead of once per layer
|
||||
if self.process_group is not None:
|
||||
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
||||
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
||||
|
||||
return intermediate + hidden_states, None
|
||||
else:
|
||||
|
@ -302,42 +299,24 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel):
|
|||
|
||||
|
||||
class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
||||
def __init__(self, config, process_group=None):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.tp_embeddings = False
|
||||
if process_group is not None:
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
if config.vocab_size % self.tp_world_size == 0:
|
||||
self.tp_embeddings = True
|
||||
|
||||
if self.tp_embeddings:
|
||||
self.embed_in = TensorParallelEmbedding(
|
||||
config.vocab_size, config.hidden_size, process_group=process_group
|
||||
)
|
||||
else:
|
||||
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.embed_in = TensorParallelEmbedding(
|
||||
prefix="gpt_neox.embed_in", weights=weights
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashNeoXLayer(
|
||||
config.num_attention_heads,
|
||||
config.hidden_act,
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
config.rotary_pct,
|
||||
config.rotary_emb_base,
|
||||
config.layer_norm_eps,
|
||||
config.use_parallel_residual,
|
||||
process_group,
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
FlashNeoXLayer(layer_id, config, weights)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.final_layer_norm = FastLayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
self.final_layer_norm = FastLayerNorm.load(
|
||||
prefix="gpt_neox.final_layer_norm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
@ -345,29 +324,6 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||
self.head_size = self.layers[0].attention.head_size
|
||||
self.num_heads = self.layers[0].attention.num_heads
|
||||
|
||||
def post_load_weights(self, quantize: Optional[str] = None):
|
||||
if isinstance(self.embed_in, TensorParallelEmbedding):
|
||||
self.embed_in.add_null_idx()
|
||||
for layer in self.layers:
|
||||
layer: FlashNeoXLayer
|
||||
layer.attention.shuffle_qkv_dims()
|
||||
layer.attention.query_key_value.prepare_weights(quantize)
|
||||
layer.attention.dense.prepare_weights(quantize)
|
||||
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
|
||||
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
|
||||
# to do it for us
|
||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
model = super(FlashGPTNeoXModel, cls).from_pretrained(
|
||||
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
|
||||
)
|
||||
|
||||
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
|
||||
return model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
|
@ -435,42 +391,13 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||
|
||||
|
||||
class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
||||
def __init__(self, config, process_group=None):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__(config)
|
||||
self.gpt_neox = FlashGPTNeoXModel(config, weights)
|
||||
|
||||
self.process_group = process_group
|
||||
if self.process_group is not None:
|
||||
self.world_size = self.process_group.size()
|
||||
else:
|
||||
self.world_size = 1
|
||||
|
||||
self.gpt_neox = FlashGPTNeoXModel(config, process_group)
|
||||
|
||||
if self.gpt_neox.tp_embeddings:
|
||||
self.embed_out = FastLinear(
|
||||
config.hidden_size,
|
||||
config.vocab_size // process_group.size(),
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
self.embed_out = FastLinear(
|
||||
config.hidden_size, config.vocab_size, bias=False
|
||||
)
|
||||
|
||||
def post_load_weights(self, quantize: Optional[str] = None):
|
||||
self.gpt_neox.post_load_weights(quantize)
|
||||
self.embed_out.prepare_weights()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
|
||||
# to do it for us
|
||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
|
||||
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
|
||||
self.embed_out = TensorParallelHead.load(
|
||||
config, prefix="embed_out", weights=weights
|
||||
)
|
||||
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
|
||||
return model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -495,12 +422,4 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
|||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.embed_out(hidden_states)
|
||||
|
||||
if self.gpt_neox.tp_embeddings:
|
||||
# Logits are sharded, so we need to gather them
|
||||
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
|
||||
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
|
||||
world_logits = torch.cat(world_logits, dim=1)
|
||||
|
||||
return world_logits, present
|
||||
return logits, present
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
|
@ -12,15 +10,31 @@ from typing import Optional
|
|||
import flash_attn_cuda
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
FastLinear,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelHead,
|
||||
FastLayerNorm,
|
||||
PositionRotaryEmbedding,
|
||||
get_linear,
|
||||
)
|
||||
|
||||
|
||||
def load_row(config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
|
||||
linear = get_linear(weight, bias, config.quantize)
|
||||
if config.parallel_attn:
|
||||
return linear
|
||||
else:
|
||||
return TensorParallelRowLinear(linear, process_group=weights.process_group)
|
||||
|
||||
|
||||
class RWConfig(PretrainedConfig):
|
||||
attribute_map = {
|
||||
"num_hidden_layers": "n_layer",
|
||||
|
@ -85,44 +99,31 @@ class RWConfig(PretrainedConfig):
|
|||
class FlashRWAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
process_group=None,
|
||||
reduce=True,
|
||||
config,
|
||||
prefix,
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.num_heads_kv = num_heads_kv
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = hidden_size // num_heads
|
||||
self.num_heads = config.n_head
|
||||
self.num_heads_kv = config.n_head_kv
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
dim=self.head_size, base=10000.0, device=weights.device
|
||||
)
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
|
||||
if process_group is None:
|
||||
self.query_key_value = FastLinear(
|
||||
hidden_size,
|
||||
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
|
||||
bias=bias,
|
||||
)
|
||||
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
|
||||
else:
|
||||
self.query_key_value = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.dense = TensorParallelRowLinear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
reduce=reduce,
|
||||
)
|
||||
self.num_heads = self.num_heads // process_group.size()
|
||||
self.query_key_value = TensorParallelColumnLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.query_key_value",
|
||||
weights=weights,
|
||||
bias=config.bias,
|
||||
)
|
||||
self.dense = load_row(
|
||||
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -212,57 +213,48 @@ class FlashRWAttention(torch.nn.Module):
|
|||
class FlashRWLargeAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
process_group=None,
|
||||
reduce=True,
|
||||
config,
|
||||
prefix,
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
hidden_size = config.hidden_size
|
||||
num_heads = config.n_head
|
||||
num_heads_kv = config.n_head_kv
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = hidden_size // num_heads
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
self.head_size, base=10000.0, device=weights.device
|
||||
)
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
self.num_groups = num_heads // (num_heads_kv * 2)
|
||||
self.num_heads = num_heads // self.num_groups
|
||||
self.num_heads_kv = num_heads_kv // self.num_groups
|
||||
process_group = weights.process_group
|
||||
|
||||
if process_group is None:
|
||||
self.query_key_value = FastLinear(
|
||||
hidden_size,
|
||||
self.num_groups
|
||||
* self.head_size
|
||||
* (self.num_heads + 2 * self.num_heads_kv),
|
||||
bias=bias,
|
||||
if process_group.size() > self.num_groups:
|
||||
raise NotImplementedError(
|
||||
f"Tensor Parallelism is not implemented for world_size > n groups"
|
||||
)
|
||||
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
|
||||
else:
|
||||
if process_group.size() > self.num_groups:
|
||||
raise NotImplementedError(
|
||||
f"Tensor Parallelism is not implemented for world_size > n groups"
|
||||
)
|
||||
if self.num_groups % process_group.size() != 0:
|
||||
raise NotImplementedError(
|
||||
f"Tensor Parallelism is not implemented for {self.num_groups} not divisible by {process_group.size()}"
|
||||
)
|
||||
self.num_groups = self.num_groups // process_group.size()
|
||||
|
||||
self.query_key_value = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
self.num_groups
|
||||
* self.head_size
|
||||
* (self.num_heads + 2 * self.num_heads_kv),
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.dense = TensorParallelRowLinear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
reduce=reduce,
|
||||
)
|
||||
|
||||
self.num_groups = self.num_groups // process_group.size()
|
||||
self.query_key_value = TensorParallelColumnLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.query_key_value",
|
||||
weights=weights,
|
||||
bias=config.bias,
|
||||
)
|
||||
self.dense = load_row(
|
||||
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -359,28 +351,16 @@ class FlashRWLargeAttention(torch.nn.Module):
|
|||
|
||||
|
||||
class FlashMLP(nn.Module):
|
||||
def __init__(self, hidden_size, bias, process_group=None, reduce=True):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
self.act = torch.nn.functional.gelu
|
||||
|
||||
if process_group is None:
|
||||
self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias)
|
||||
self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias)
|
||||
else:
|
||||
self.dense_h_to_4h = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
4 * hidden_size,
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.dense_4h_to_h = TensorParallelRowLinear(
|
||||
4 * hidden_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
reduce=reduce,
|
||||
)
|
||||
self.process_group = process_group
|
||||
self.dense_h_to_4h = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=config.bias
|
||||
)
|
||||
self.dense_4h_to_h = load_row(
|
||||
config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=config.bias
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense_h_to_4h(hidden_states)
|
||||
|
@ -392,38 +372,44 @@ class FlashMLP(nn.Module):
|
|||
class FlashRWLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
layer_norm_eps,
|
||||
parallel_attn,
|
||||
process_group=None,
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
parallel_attn = config.parallel_attn
|
||||
self.parallel_attn = parallel_attn
|
||||
|
||||
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
prefix = f"transformer.h.{layer_id}"
|
||||
|
||||
self.input_layernorm = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
self.self_attention = FlashRWAttention(
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
process_group=process_group,
|
||||
reduce=False,
|
||||
config,
|
||||
prefix=f"{prefix}.self_attention",
|
||||
weights=weights,
|
||||
)
|
||||
self.post_attention_layernorm = (
|
||||
FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
FastLayerNorm.load(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
if not parallel_attn
|
||||
else None
|
||||
)
|
||||
|
||||
self.mlp = FlashMLP(
|
||||
hidden_size, bias, process_group=process_group, reduce=False
|
||||
config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.process_group = process_group
|
||||
self.process_group = weights.process_group
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -454,9 +440,7 @@ class FlashRWLayer(nn.Module):
|
|||
mlp_output = self.mlp(ln_hidden_states)
|
||||
intermediate = mlp_output + attn_output
|
||||
|
||||
# Only reduce once and after the addition instead of once per layer
|
||||
if self.process_group is not None:
|
||||
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
||||
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
||||
|
||||
return intermediate, residual
|
||||
else:
|
||||
|
@ -483,33 +467,30 @@ class FlashRWLayer(nn.Module):
|
|||
|
||||
|
||||
class FlashRWLargeLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
layer_norm_eps,
|
||||
process_group=None,
|
||||
):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
prefix = f"transformer.h.{layer_id}"
|
||||
self.ln_attn = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.ln_attn",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
self.ln_mlp = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.ln_mlp",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
|
||||
self.self_attention = FlashRWLargeAttention(
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
process_group=process_group,
|
||||
reduce=False,
|
||||
config,
|
||||
prefix=f"{prefix}.self_attention",
|
||||
weights=weights,
|
||||
)
|
||||
assert config.parallel_attn, "This version doesn't support non parallel_attn"
|
||||
|
||||
self.mlp = FlashMLP(
|
||||
hidden_size, bias, process_group=process_group, reduce=False
|
||||
)
|
||||
self.mlp = FlashMLP(config, prefix=f"{prefix}.mlp", weights=weights)
|
||||
|
||||
self.process_group = process_group
|
||||
self.process_group = weights.process_group
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -543,9 +524,7 @@ class FlashRWLargeLayer(nn.Module):
|
|||
|
||||
intermediate = attn_output + mlp_output
|
||||
|
||||
# Only reduce once and after the addition instead of once per layer
|
||||
if self.process_group is not None:
|
||||
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
||||
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
||||
|
||||
return intermediate, residual
|
||||
|
||||
|
@ -555,37 +534,18 @@ class FlashRWPreTrainedModel(PreTrainedModel):
|
|||
|
||||
|
||||
class FlashRWModel(FlashRWPreTrainedModel):
|
||||
def __init__(self, config, process_group=None):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.tp_embeddings = False
|
||||
if process_group is not None:
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
if config.vocab_size % self.tp_world_size == 0:
|
||||
self.tp_embeddings = True
|
||||
|
||||
if self.tp_embeddings:
|
||||
self.word_embeddings = TensorParallelEmbedding(
|
||||
config.vocab_size, config.hidden_size, process_group=process_group
|
||||
)
|
||||
else:
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
|
||||
self.word_embeddings = TensorParallelEmbedding(
|
||||
prefix="transformer.word_embeddings", weights=weights
|
||||
)
|
||||
if config.model_type == "RefinedWebModel":
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
FlashRWLayer(
|
||||
config.n_head,
|
||||
config.n_head_kv,
|
||||
config.hidden_size,
|
||||
config.bias,
|
||||
config.layer_norm_epsilon,
|
||||
config.parallel_attn,
|
||||
process_group,
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
FlashRWLayer(layer_id, config, weights)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.cache_size = (
|
||||
|
@ -596,15 +556,8 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
|||
elif config.model_type == "RefinedWeb":
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
FlashRWLargeLayer(
|
||||
config.n_head,
|
||||
config.n_head_kv,
|
||||
config.hidden_size,
|
||||
config.bias,
|
||||
config.layer_norm_epsilon,
|
||||
process_group,
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
FlashRWLargeLayer(layer_id, config, weights)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.cache_size = (
|
||||
|
@ -617,31 +570,13 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
|||
f"model_type {config.model_type} is not supported."
|
||||
)
|
||||
|
||||
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.head_size = self.h[0].self_attention.head_size
|
||||
|
||||
def post_load_weights(self, quantize: Optional[str] = None):
|
||||
if isinstance(self.word_embeddings, TensorParallelEmbedding):
|
||||
self.word_embeddings.add_null_idx()
|
||||
for layer in self.h:
|
||||
layer: FlashRWLayer
|
||||
layer.self_attention.query_key_value.prepare_weights(quantize)
|
||||
layer.self_attention.dense.prepare_weights(quantize)
|
||||
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
|
||||
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
|
||||
# to do it for us
|
||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
model = super(FlashRWModel, cls).from_pretrained(
|
||||
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
|
||||
self.ln_f = FastLayerNorm.load(
|
||||
prefix="transformer.ln_f",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
|
||||
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
|
||||
return model
|
||||
self.head_size = self.h[0].self_attention.head_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -708,40 +643,14 @@ class FlashRWModel(FlashRWPreTrainedModel):
|
|||
|
||||
|
||||
class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
||||
def __init__(self, config, process_group=None):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__(config)
|
||||
|
||||
self.process_group = process_group
|
||||
if self.process_group is not None:
|
||||
self.world_size = self.process_group.size()
|
||||
else:
|
||||
self.world_size = 1
|
||||
self.transformer = FlashRWModel(config, weights)
|
||||
|
||||
self.transformer = FlashRWModel(config, process_group)
|
||||
|
||||
if self.transformer.tp_embeddings:
|
||||
self.lm_head = FastLinear(
|
||||
config.hidden_size,
|
||||
config.vocab_size // process_group.size(),
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def post_load_weights(self, quantize: Optional[str] = None):
|
||||
self.transformer.post_load_weights(quantize)
|
||||
self.lm_head.prepare_weights()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
|
||||
# to do it for us
|
||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
model = super(FlashRWForCausalLM, cls).from_pretrained(
|
||||
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config, prefix="lm_head", weights=weights
|
||||
)
|
||||
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
|
||||
return model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -766,12 +675,4 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
|||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
if self.transformer.tp_embeddings:
|
||||
# Logits are sharded, so we need to gather them
|
||||
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
|
||||
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
|
||||
world_logits = torch.cat(world_logits, dim=1)
|
||||
|
||||
return world_logits, present
|
||||
return logits, present
|
||||
|
|
|
@ -8,39 +8,142 @@ from typing import Optional
|
|||
# Flash attention imports
|
||||
import flash_attn_cuda
|
||||
from text_generation_server.utils.layers import (
|
||||
FastLinear,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelHead,
|
||||
TensorParallelEmbedding,
|
||||
FastLayerNorm,
|
||||
get_linear,
|
||||
)
|
||||
|
||||
|
||||
class FlashMQAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
def load_multi_mqa(
|
||||
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
|
||||
):
|
||||
if any("c_attn" in k for k in weights.routing.keys()):
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
|
||||
shape = slice_.get_shape()
|
||||
world_size = weights.process_group.size()
|
||||
rank = weights.process_group.rank()
|
||||
if config.transpose:
|
||||
block_size = (shape[1] - 2 * head_size) // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
assert (shape[1] - 2 * head_size) % world_size == 0
|
||||
q_tensor = slice_[:, start:stop]
|
||||
kv_tensor = slice_[:, -2 * head_size :]
|
||||
weight = torch.cat([q_tensor, kv_tensor], dim=1).T
|
||||
else:
|
||||
block_size = (shape[0] - 2 * head_size) // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
assert (shape[0] - 2 * head_size) % world_size == 0
|
||||
q_tensor = slice_[start:stop]
|
||||
kv_tensor = slice_[-2 * head_size :]
|
||||
weight = torch.cat([q_tensor, kv_tensor], dim=0)
|
||||
if bias:
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
|
||||
shape = slice_.get_shape()
|
||||
block_size = (shape[0] - 2 * head_size) // world_size
|
||||
assert (shape[0] - 2 * head_size) % world_size == 0
|
||||
q_tensor = slice_[start:stop]
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
q_tensor = slice_[start:stop]
|
||||
kv_tensor = slice_[-2 * head_size :]
|
||||
bias = torch.cat([q_tensor, kv_tensor], dim=0)
|
||||
else:
|
||||
if config.transpose:
|
||||
w = [
|
||||
weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T,
|
||||
weights.get_tensor(f"{prefix}.kv_attn.weight").T,
|
||||
]
|
||||
weight = torch.cat(w, dim=0)
|
||||
else:
|
||||
w = [
|
||||
weights.get_sharded(f"{prefix}.q_attn.weight", dim=0),
|
||||
weights.get_tensor(f"{prefix}.kv_attn.weight"),
|
||||
]
|
||||
weight = torch.cat(w, dim=1)
|
||||
|
||||
if bias:
|
||||
b = [
|
||||
weights.get_sharded(f"{prefix}.q_attn.bias", dim=0),
|
||||
weights.get_tensor(f"{prefix}.kv_attn.bias"),
|
||||
]
|
||||
bias = torch.cat(b, dim=0)
|
||||
else:
|
||||
bias = None
|
||||
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
assert list(weight.shape) == [
|
||||
(num_heads + 2) * head_size,
|
||||
hidden_size,
|
||||
process_group=None,
|
||||
):
|
||||
], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
|
||||
if bias is not None:
|
||||
bias = bias.to(dtype=weights.dtype).to(device=weights.device)
|
||||
assert list(bias.shape) == [
|
||||
(num_heads + 2) * head_size
|
||||
], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
|
||||
|
||||
def load_col(config, prefix: str, weights, bias: bool):
|
||||
if config.transpose:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
|
||||
else:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
|
||||
if bias:
|
||||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||
else:
|
||||
bias = None
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
|
||||
|
||||
def load_row(config, prefix: str, weights, bias: bool):
|
||||
if config.transpose:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
|
||||
else:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
return TensorParallelRowLinear(
|
||||
get_linear(weight, bias, config.quantize), process_group=weights.process_group
|
||||
)
|
||||
|
||||
|
||||
class FlashMQAttention(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
num_heads = config.num_attention_heads
|
||||
hidden_size = config.hidden_size
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = hidden_size // num_heads
|
||||
|
||||
assert self.num_heads % weights.process_group.size() == 0
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
if process_group is None:
|
||||
self.c_attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
|
||||
self.c_proj = FastLinear(hidden_size, hidden_size)
|
||||
else:
|
||||
self.num_heads = self.num_heads // process_group.size()
|
||||
self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
|
||||
self.c_proj = TensorParallelRowLinear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.c_attn = load_multi_mqa(
|
||||
config,
|
||||
prefix=prefix,
|
||||
weights=weights,
|
||||
bias=True,
|
||||
head_size=self.head_size,
|
||||
hidden_size=hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
)
|
||||
self.c_proj = load_row(
|
||||
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -121,8 +224,9 @@ class FlashMQAttention(torch.nn.Module):
|
|||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
act = config.activation_function
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
if "gelu" not in act
|
||||
|
@ -134,20 +238,12 @@ class MLP(nn.Module):
|
|||
)
|
||||
)
|
||||
|
||||
if process_group is None:
|
||||
self.c_fc = FastLinear(hidden_size, intermediate_size)
|
||||
self.c_proj = FastLinear(intermediate_size, hidden_size)
|
||||
else:
|
||||
self.c_fc = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.c_proj = TensorParallelRowLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.c_fc = load_col(
|
||||
config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
|
||||
)
|
||||
self.c_proj = load_row(
|
||||
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.c_fc(hidden_states)
|
||||
|
@ -157,28 +253,24 @@ class MLP(nn.Module):
|
|||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
act,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
layer_norm_eps,
|
||||
process_group=None,
|
||||
):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
self.ln_1 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
self.ln_2 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
prefix = f"transformer.h.{layer_id}"
|
||||
self.ln_1 = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
|
||||
)
|
||||
self.ln_2 = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
|
||||
)
|
||||
self.attn = FlashMQAttention(
|
||||
num_heads,
|
||||
hidden_size,
|
||||
process_group,
|
||||
prefix=f"{prefix}.attn",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
self.mlp = MLP(
|
||||
act,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
process_group,
|
||||
prefix=f"{prefix}.mlp",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
@ -210,66 +302,39 @@ class Block(nn.Module):
|
|||
|
||||
|
||||
class FlashSantacoderModel(nn.Module):
|
||||
def __init__(self, config, process_group=None):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.process_group = process_group
|
||||
self.tp_embeddings = False
|
||||
if process_group is not None:
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
if config.vocab_size % self.tp_world_size == 0:
|
||||
self.tp_embeddings = True
|
||||
|
||||
if self.tp_embeddings:
|
||||
self.wte = TensorParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
reduce=False,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.wpe = TensorParallelEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.hidden_size,
|
||||
reduce=False,
|
||||
process_group=process_group,
|
||||
)
|
||||
else:
|
||||
self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
self.process_group = weights.process_group
|
||||
self.wte = TensorParallelEmbedding(
|
||||
prefix="transformer.wte",
|
||||
weights=weights,
|
||||
reduce=False,
|
||||
)
|
||||
self.wpe = TensorParallelEmbedding(
|
||||
prefix="transformer.wpe",
|
||||
weights=weights,
|
||||
reduce=False,
|
||||
)
|
||||
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
Block(
|
||||
config.num_attention_heads,
|
||||
config.activation_function,
|
||||
config.hidden_size,
|
||||
config.n_inner
|
||||
if config.n_inner is not None
|
||||
else 4 * config.hidden_size,
|
||||
config.layer_norm_epsilon,
|
||||
process_group,
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.ln_f = FastLayerNorm.load(
|
||||
prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
|
||||
)
|
||||
|
||||
self.head_size = self.h[0].attn.head_size
|
||||
self.num_heads = self.h[0].attn.num_heads
|
||||
|
||||
def post_load_weights(self, quantize: Optional[str] = None):
|
||||
if self.tp_embeddings:
|
||||
self.wte.add_null_idx()
|
||||
self.wpe.add_null_idx()
|
||||
for layer in self.h:
|
||||
layer: Block
|
||||
layer.attn.c_attn.prepare_weights(quantize)
|
||||
layer.attn.c_proj.prepare_weights(quantize)
|
||||
layer.mlp.c_fc.prepare_weights(quantize)
|
||||
layer.mlp.c_proj.prepare_weights(quantize)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
|
@ -281,8 +346,7 @@ class FlashSantacoderModel(nn.Module):
|
|||
pre_allocate_past_size: Optional[int] = None,
|
||||
):
|
||||
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
||||
if self.tp_embeddings:
|
||||
torch.distributed.all_reduce(hidden_states, group=self.process_group)
|
||||
torch.distributed.all_reduce(hidden_states, group=self.process_group)
|
||||
|
||||
# Prefill
|
||||
if past_key_values is None:
|
||||
|
@ -331,23 +395,12 @@ class FlashSantacoderModel(nn.Module):
|
|||
|
||||
|
||||
class FlashSantacoderForCausalLM(nn.Module):
|
||||
def __init__(self, config, process_group=None):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.transformer = FlashSantacoderModel(config, process_group)
|
||||
|
||||
if self.transformer.tp_embeddings:
|
||||
self.lm_head = FastLinear(
|
||||
config.hidden_size,
|
||||
config.vocab_size // process_group.size(),
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def post_load_weights(self, quantize: Optional[str] = None):
|
||||
self.transformer.post_load_weights(quantize)
|
||||
self.lm_head.prepare_weights()
|
||||
self.transformer = FlashSantacoderModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config, prefix="transformer.wte", weights=weights
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -372,29 +425,4 @@ class FlashSantacoderForCausalLM(nn.Module):
|
|||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
if self.transformer.tp_embeddings:
|
||||
# Logits are sharded, so we need to gather them
|
||||
if logits.shape[0] == 1:
|
||||
# Fast path when batch size is 1
|
||||
world_logits = logits.new_empty(
|
||||
(logits.shape[1] * self.transformer.tp_world_size)
|
||||
)
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
world_logits, logits.view(-1), group=self.transformer.process_group
|
||||
)
|
||||
world_logits = world_logits.view(1, -1)
|
||||
else:
|
||||
# We cannot use all_gather_into_tensor as it only support concatenating on the first dim
|
||||
world_logits = [
|
||||
torch.empty_like(logits)
|
||||
for _ in range(self.transformer.tp_world_size)
|
||||
]
|
||||
torch.distributed.all_gather(
|
||||
world_logits, logits, group=self.transformer.process_group
|
||||
)
|
||||
world_logits = torch.cat(world_logits, dim=1)
|
||||
|
||||
return world_logits, present
|
||||
|
||||
return logits, present
|
||||
|
|
|
@ -0,0 +1,794 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch GPTNeoX model."""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.file_utils import (
|
||||
add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
QuestionAnsweringModelOutput,
|
||||
SequenceClassifierOutputWithPast,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers import GPTNeoXConfig
|
||||
from loguru import logger
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelHead,
|
||||
)
|
||||
|
||||
|
||||
CUSTOM_KERNELS_ENABLED = False
|
||||
if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
|
||||
try:
|
||||
from custom_kernels import fused_attention_cuda
|
||||
|
||||
CUSTOM_KERNELS_ENABLED = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if not CUSTOM_KERNELS_ENABLED:
|
||||
logger.warning("We're not using custom kernels.")
|
||||
|
||||
|
||||
def make_causal_mask(
|
||||
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
|
||||
) -> torch.BoolTensor:
|
||||
"""
|
||||
Make causal mask used for self-attention.
|
||||
"""
|
||||
batch_size, target_length = input_ids_shape
|
||||
mask = torch.ones(
|
||||
(target_length, target_length + past_key_values_length),
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
mask = mask.triu(1 + past_key_values_length)
|
||||
|
||||
expanded_mask = mask.unsqueeze(0).expand(
|
||||
batch_size, target_length, target_length + past_key_values_length
|
||||
)
|
||||
return expanded_mask
|
||||
|
||||
|
||||
def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
|
||||
"""
|
||||
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
|
||||
"""
|
||||
batch_size, src_length = mask.shape
|
||||
tgt_length = tgt_length if tgt_length is not None else src_length
|
||||
|
||||
expanded_mask = ~(mask[:, None, :].to(torch.bool))
|
||||
return expanded_mask.expand(batch_size, tgt_length, src_length)
|
||||
|
||||
|
||||
def prepare_attn_mask(
|
||||
attention_mask: torch.Tensor,
|
||||
input_shape: Tuple[int, int],
|
||||
past_key_values_length: int,
|
||||
) -> torch.BoolTensor:
|
||||
# create causal mask
|
||||
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
|
||||
combined_attention_mask = None
|
||||
device = attention_mask.device
|
||||
_, src_length = input_shape
|
||||
|
||||
if src_length > 1:
|
||||
combined_attention_mask = make_causal_mask(
|
||||
input_shape, device=device, past_key_values_length=past_key_values_length
|
||||
)
|
||||
|
||||
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
|
||||
expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask
|
||||
if combined_attention_mask is None
|
||||
else expanded_attn_mask | combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
|
||||
class GPTNeoXPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
|
||||
class GPTNeoXAttention(nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_attention_heads
|
||||
self.rotary_ndims = int(self.head_size * config.rotary_pct)
|
||||
max_positions = config.max_position_embeddings
|
||||
# ??? TODO
|
||||
# self.register_buffer(
|
||||
# "bias",
|
||||
# torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||
# 1, 1, max_positions, max_positions
|
||||
# ),
|
||||
# )
|
||||
# self.register_buffer("masked_bias", torch.tensor(-1e9))
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
self.rotary_ndims,
|
||||
config.max_position_embeddings,
|
||||
base=config.rotary_emb_base,
|
||||
)
|
||||
self.rotary_emb.inv_freq = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
|
||||
)
|
||||
self.inv_norm_factor = 1.0 / torch.sqrt(
|
||||
torch.tensor(self.head_size, dtype=torch.float32)
|
||||
).to(torch.get_default_dtype())
|
||||
|
||||
assert self.num_attention_heads % weights.process_group.size() == 0
|
||||
self.num_attention_heads = (
|
||||
self.num_attention_heads // weights.process_group.size()
|
||||
)
|
||||
self.query_key_value = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True
|
||||
)
|
||||
self.dense = TensorParallelRowLinear.load(
|
||||
config, prefix=f"{prefix}.dense", weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
head_mask=None,
|
||||
layer_past=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
):
|
||||
has_layer_past = layer_past is not None
|
||||
|
||||
# Compute QKV
|
||||
# Attention heads [batch, seq_len, hidden_size]
|
||||
# --> [batch, seq_len, (np * 3 * head_size)]
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
|
||||
# [batch, seq_len, (num_heads * 3 * head_size)]
|
||||
# --> [batch, seq_len, num_heads, 3 * head_size]
|
||||
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
|
||||
qkv = qkv.view(*new_qkv_shape).permute(0, 2, 1, 3)
|
||||
# [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
|
||||
query, key, value = qkv.split(self.head_size, -1)
|
||||
|
||||
# Compute token offset for rotary embeddings (when decoding)
|
||||
seq_len = key.shape[-2]
|
||||
if has_layer_past:
|
||||
seq_len += layer_past[0].shape[-2]
|
||||
|
||||
# Compute rotary embeddings on rotary_ndims
|
||||
query_rot = query[..., : self.rotary_ndims]
|
||||
key_rot = key[..., : self.rotary_ndims]
|
||||
|
||||
query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len)
|
||||
|
||||
query[..., : self.rotary_ndims] = query_rot
|
||||
key[..., : self.rotary_ndims] = key_rot
|
||||
|
||||
if CUSTOM_KERNELS_ENABLED:
|
||||
attn_output, present, attn_weights = fused_attention_cuda.forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
layer_past,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
self.inv_norm_factor,
|
||||
self.num_attention_heads,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
# Cache QKV values
|
||||
if has_layer_past:
|
||||
past_key = layer_past[0]
|
||||
past_value = layer_past[1]
|
||||
key = torch.cat((past_key, key), dim=-2)
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
present = (key, value) if use_cache else None
|
||||
|
||||
# Compute attention
|
||||
attn_output, attn_weights = self._attn(
|
||||
query, key, value, attention_mask, head_mask
|
||||
)
|
||||
|
||||
# Reshape outputs
|
||||
attn_output = self._merge_heads(
|
||||
attn_output, self.num_attention_heads, self.head_size
|
||||
)
|
||||
|
||||
attn_output = self.dense(attn_output)
|
||||
|
||||
outputs = (attn_output, present)
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
@classmethod
|
||||
def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
|
||||
"""
|
||||
Splits hidden dim into attn_head_size and num_attention_heads
|
||||
"""
|
||||
# tensor: [bs, seq_len, hidden_size]
|
||||
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
|
||||
# -> [bs, seq_len, num_attention_heads, attn_head_size]
|
||||
tensor = tensor.view(new_shape)
|
||||
# -> [bs, num_attention_heads, seq_len, attn_head_size]
|
||||
tensor = tensor.permute(0, 2, 1, 3)
|
||||
return tensor
|
||||
|
||||
@classmethod
|
||||
def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
|
||||
"""
|
||||
Merges attn_head_size dim and num_attn_heads dim into hidden dim
|
||||
"""
|
||||
# tensor [bs, num_attention_heads, seq_len, attn_head_size]
|
||||
tensor = tensor.permute(0, 2, 1, 3).contiguous()
|
||||
# -> [bs, seq_len, num_attention_heads, attn_head_size]
|
||||
tensor = tensor.view(
|
||||
tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
|
||||
)
|
||||
# -> [bs, seq_len, hidden_size]
|
||||
return tensor
|
||||
|
||||
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
||||
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
|
||||
# compute causal mask from causal mask buffer
|
||||
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
|
||||
key_length = key.size(-2)
|
||||
|
||||
query = query.view(
|
||||
batch_size * num_attention_heads, query_length, attn_head_size
|
||||
)
|
||||
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
|
||||
attn_scores = torch.zeros(
|
||||
1,
|
||||
dtype=query.dtype,
|
||||
device=key.device,
|
||||
).expand(batch_size * num_attention_heads, query_length, key_length)
|
||||
attn_scores = torch.baddbmm(
|
||||
attn_scores,
|
||||
query,
|
||||
key.transpose(1, 2),
|
||||
beta=1.0,
|
||||
alpha=self.inv_norm_factor,
|
||||
)
|
||||
|
||||
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
|
||||
input_dtype = attn_scores.dtype
|
||||
if input_dtype in [torch.float16, torch.bfloat16]:
|
||||
attn_scores = attn_scores.to(torch.float)
|
||||
attn_scores = torch.where(
|
||||
attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores
|
||||
)
|
||||
attn_scores = attn_scores.view(
|
||||
batch_size, num_attention_heads, query_length, key_length
|
||||
)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
|
||||
attn_weights = attn_weights.to(value.dtype)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
|
||||
super().__init__()
|
||||
self.true_inv_freq = 1.0 / (
|
||||
base ** (torch.arange(0, dim, 2).float().to(device) / dim)
|
||||
)
|
||||
self.register_buffer("inv_freq", self.true_inv_freq)
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
self.cos_cached = None
|
||||
self.sin_cached = None
|
||||
|
||||
@staticmethod
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
@staticmethod
|
||||
def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):
|
||||
t = torch.arange(
|
||||
max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype
|
||||
)
|
||||
freqs = torch.einsum("i,j->ij", t, inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype)
|
||||
|
||||
def forward(self, q, k, position_ids, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if (
|
||||
seq_len > self.max_seq_len_cached
|
||||
or self.cos_cached is None
|
||||
or self.sin_cached is None
|
||||
):
|
||||
if seq_len > self.max_seq_len_cached:
|
||||
self.max_seq_len_cached = seq_len
|
||||
self.cos_cached, self.sin_cached = self._create_cos_sin(
|
||||
self.true_inv_freq, self.max_seq_len_cached, q.dtype, q.device
|
||||
)
|
||||
return rotary_forward(q, k, self.cos_cached, self.sin_cached, position_ids)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def rotary_forward(q, k, cos, sin, position_ids):
|
||||
cos = cos[position_ids].unsqueeze(1)
|
||||
sin = sin[position_ids].unsqueeze(1)
|
||||
|
||||
chunk_size = q.shape[-1] // 2
|
||||
q1, q2 = q.split(chunk_size, -1)
|
||||
q_rotated = torch.cat((-q2, q1), dim=-1)
|
||||
k1, k2 = k.split(chunk_size, -1)
|
||||
k_rotated = torch.cat((-k2, k1), dim=-1)
|
||||
|
||||
q_embed = (q * cos) + (q_rotated * sin)
|
||||
k_embed = (k * cos) + (k_rotated * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class GPTNeoXMLP(nn.Module):
|
||||
def __init__(self, config, prefix, weights):
|
||||
super().__init__()
|
||||
self.act = (
|
||||
ACT2FN[config.hidden_act]
|
||||
if "gelu_fast" not in config.hidden_act
|
||||
else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
|
||||
)
|
||||
|
||||
self.dense_h_to_4h = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
|
||||
)
|
||||
self.dense_4h_to_h = TensorParallelRowLinear.load(
|
||||
config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense_h_to_4h(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.dense_4h_to_h(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GPTNeoXLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
self.use_parallel_residual = config.use_parallel_residual
|
||||
self.input_layernorm = nn.LayerNorm.load(
|
||||
prefix=f"gpt_neox.layers.{layer_id}.input_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
self.post_attention_layernorm = nn.LayerNorm.load(
|
||||
prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
self.attention = GPTNeoXAttention(
|
||||
config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights
|
||||
)
|
||||
self.mlp = GPTNeoXMLP(
|
||||
config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
position_ids,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
use_cache=False,
|
||||
layer_past=None,
|
||||
output_attentions=False,
|
||||
):
|
||||
attention_layer_outputs = self.attention(
|
||||
self.input_layernorm(hidden_states),
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
layer_past=layer_past,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attn_output = attention_layer_outputs[
|
||||
0
|
||||
] # output_attn: attn_output, present, (attn_weights)
|
||||
outputs = attention_layer_outputs[1:]
|
||||
|
||||
if self.use_parallel_residual:
|
||||
# pseudocode:
|
||||
# x = x + attn(ln1(x)) + mlp(ln2(x))
|
||||
mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
|
||||
hidden_states = mlp_output + attn_output + hidden_states
|
||||
else:
|
||||
# pseudocode:
|
||||
# x = x + attn(ln1(x))
|
||||
# x = x + mlp(ln2(x))
|
||||
attn_output = attn_output + hidden_states
|
||||
mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
|
||||
hidden_states = mlp_output + attn_output
|
||||
|
||||
if use_cache:
|
||||
outputs = (
|
||||
hidden_states,
|
||||
) + outputs # hidden_states, present, (attn_weights)
|
||||
else:
|
||||
outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
|
||||
self.embed_in = TensorParallelEmbedding(
|
||||
prefix="gpt_neox.embed_in", weights=weights
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
GPTNeoXLayer(layer_id, config, weights)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm.load(
|
||||
prefix="gpt_neox.final_layer_norm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
self.tp_world_size = weights.process_group.size()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids=None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
r"""
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time"
|
||||
)
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * self.config.num_hidden_layers)
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_length, seq_length + past_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_in(input_ids)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# Attention mask.
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
if past_key_values[0] is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[-1]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), device=hidden_states.device
|
||||
)
|
||||
else:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
|
||||
causal_mask = prepare_attn_mask(
|
||||
attention_mask,
|
||||
input_shape=(batch_size, seq_length),
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
assert self.num_attention_heads % self.tp_world_size == 0
|
||||
block_size = self.num_attention_heads // self.tp_world_size
|
||||
causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_attentions = () if output_attentions else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = layer(
|
||||
hidden_states,
|
||||
position_ids=position_ids,
|
||||
attention_mask=causal_mask,
|
||||
head_mask=head_mask[i],
|
||||
layer_past=layer_past,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
if output_attentions:
|
||||
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
|
||||
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
# Add last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, presents, all_hidden_states, all_attentions]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_attentions,
|
||||
)
|
||||
|
||||
|
||||
class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
|
||||
def __init__(self, config, weights):
|
||||
super().__init__(config)
|
||||
self.gpt_neox = GPTNeoXModel(config, weights)
|
||||
self.embed_out = TensorParallelHead.load(
|
||||
config, prefix="embed_out", weights=weights
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
|
||||
only required when the model is used as a decoder in a Sequence to Sequence model.
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
||||
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
|
||||
ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
|
||||
>>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b")
|
||||
>>> config.is_decoder = True
|
||||
>>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config)
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
|
||||
>>> prediction_logits = outputs.logits
|
||||
```"""
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
outputs = self.gpt_neox(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
lm_logits = self.embed_out(hidden_states)
|
||||
|
||||
lm_loss = None
|
||||
if labels is not None:
|
||||
# move labels to correct device to enable model parallelism
|
||||
labels = labels.to(lm_logits.device)
|
||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss()
|
||||
lm_loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + outputs[1:]
|
||||
return ((lm_loss,) + output) if lm_loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=lm_loss,
|
||||
logits=lm_logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
**kwargs,
|
||||
):
|
||||
input_shape = input_ids.shape
|
||||
|
||||
# cut decoder_input_ids if past is used
|
||||
if past_key_values and past_key_values[0] is not None:
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_shape)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past_key_values,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def _reorder_cache(self, past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(
|
||||
past_state.index_select(0, beam_idx)
|
||||
for past_state in layer_past[:2]
|
||||
)
|
||||
+ layer_past[2:],
|
||||
)
|
||||
return reordered_past
|
|
@ -0,0 +1,837 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch OPT model."""
|
||||
import random
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
)
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers import OPTConfig
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelHead,
|
||||
)
|
||||
|
||||
EPS = 1e-5
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
past_key_values_length: int = 0,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for bi-directional self-attention.
|
||||
"""
|
||||
bsz, tgt_len = input_ids_shape
|
||||
mask = torch.full(
|
||||
(tgt_len, tgt_len),
|
||||
torch.tensor(torch.finfo(dtype).min, device=device),
|
||||
device=device,
|
||||
)
|
||||
mask_cond = torch.arange(mask.size(-1), device=device)
|
||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
||||
mask = mask.to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
tgt_len, past_key_values_length, dtype=dtype, device=device
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return mask[None, None, :, :].expand(
|
||||
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
||||
)
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
||||
|
||||
inverted_mask = 1.0 - expanded_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
||||
)
|
||||
|
||||
|
||||
class OPTLearnedPositionalEmbedding(nn.Module):
|
||||
"""
|
||||
This module learns positional embeddings up to a fixed maximum size.
|
||||
"""
|
||||
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
self.offset = 2
|
||||
self.weight = nn.Parameter(
|
||||
weights.get_tensor("model.decoder.embed_positions.weight")
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, attention_mask: torch.LongTensor, past_key_values_length: int = 0
|
||||
):
|
||||
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
||||
attention_mask = attention_mask.long()
|
||||
|
||||
# create positions depending on attention_mask
|
||||
positions = (
|
||||
torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
|
||||
).long() - 1
|
||||
|
||||
# cut positions if `past_key_values_length` is > 0
|
||||
positions = positions[:, past_key_values_length:]
|
||||
|
||||
return torch.nn.functional.embedding(positions + self.offset, self.weight)
|
||||
|
||||
|
||||
class OPTAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
prefix,
|
||||
weights,
|
||||
is_decoder: bool = False,
|
||||
bias: bool = True,
|
||||
process_group=None,
|
||||
):
|
||||
super().__init__()
|
||||
embed_dim = config.embed_dim
|
||||
num_heads = config.num_attention_heads
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = config.dropout
|
||||
self.head_dim = embed_dim // num_heads
|
||||
|
||||
if (self.head_dim * num_heads) != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {num_heads})."
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.is_decoder = is_decoder
|
||||
|
||||
process_group = weights.process_group
|
||||
assert self.num_heads % process_group.size() == 0
|
||||
self.num_heads = self.num_heads // process_group.size()
|
||||
self.embed_dim = self.embed_dim // process_group.size()
|
||||
|
||||
self.q_proj = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias
|
||||
)
|
||||
self.k_proj = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.k_proj", weights=weights, bias=bias
|
||||
)
|
||||
self.v_proj = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.v_proj", weights=weights, bias=bias
|
||||
)
|
||||
self.out_proj = TensorParallelRowLinear.load(
|
||||
config, prefix=f"{prefix}.out_proj", weights=weights, bias=bias
|
||||
)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return (
|
||||
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scaling
|
||||
# get key, value proj
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = (
|
||||
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
+ attention_mask
|
||||
)
|
||||
attn_weights = torch.max(
|
||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
|
||||
if attn_weights.dtype == torch.float16:
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(torch.float16)
|
||||
else:
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
if layer_head_mask.size() != (self.num_heads,):
|
||||
raise ValueError(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
||||
f" {layer_head_mask.size()}"
|
||||
)
|
||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
|
||||
bsz, self.num_heads, tgt_len, src_len
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to be reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(
|
||||
bsz, self.num_heads, tgt_len, src_len
|
||||
)
|
||||
attn_weights = attn_weights_reshaped.view(
|
||||
bsz * self.num_heads, tgt_len, src_len
|
||||
)
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
|
||||
attn_probs = nn.functional.dropout(
|
||||
attn_weights, p=self.dropout, training=self.training
|
||||
)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned aross GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped, past_key_value
|
||||
|
||||
|
||||
class OPTDecoderLayer(nn.Module):
|
||||
def __init__(self, layer_id: int, config: OPTConfig, weights):
|
||||
super().__init__()
|
||||
self.process_group = weights.process_group
|
||||
self.embed_dim = config.hidden_size
|
||||
prefix = f"model.decoder.layers.{layer_id}"
|
||||
self.self_attn = OPTAttention(
|
||||
config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
weights=weights,
|
||||
is_decoder=True,
|
||||
bias=config.enable_bias,
|
||||
)
|
||||
self.do_layer_norm_before = config.do_layer_norm_before
|
||||
self.dropout = config.dropout
|
||||
self.activation_fn = ACT2FN[config.activation_function]
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.self_attn_layer_norm", weights=weights, eps=EPS
|
||||
)
|
||||
self.fc1 = TensorParallelColumnLinear.load(
|
||||
config, prefix=f"{prefix}.fc1", weights=weights, bias=config.enable_bias
|
||||
)
|
||||
self.fc2 = TensorParallelRowLinear.load(
|
||||
config, prefix=f"{prefix}.fc2", weights=weights, bias=config.enable_bias
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.final_layer_norm", weights=weights, eps=EPS
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
) -> Tuple[
|
||||
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
||||
]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
|
||||
`(encoder_attention_heads,)`.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
|
||||
if self.do_layer_norm_before:
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
layer_head_mask=layer_head_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(
|
||||
hidden_states, p=self.dropout, training=self.training
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# 350m applies layer norm AFTER attention
|
||||
if not self.do_layer_norm_before:
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
|
||||
residual = hidden_states
|
||||
|
||||
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
|
||||
if self.do_layer_norm_before:
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
hidden_states = nn.functional.dropout(
|
||||
hidden_states, p=self.dropout, training=self.training
|
||||
)
|
||||
|
||||
hidden_states = (residual + hidden_states).view(hidden_states_shape)
|
||||
|
||||
# 350m applies layer norm AFTER attention
|
||||
if not self.do_layer_norm_before:
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class OPTPreTrainedModel(PreTrainedModel):
|
||||
config_class = OPTConfig
|
||||
|
||||
|
||||
class OPTDecoder(OPTPreTrainedModel):
|
||||
def __init__(self, config: OPTConfig, weights):
|
||||
super().__init__(config)
|
||||
self.dropout = config.dropout
|
||||
self.layerdrop = config.layerdrop
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.decoder.embed_tokens", weights=weights
|
||||
)
|
||||
self.embed_positions = OPTLearnedPositionalEmbedding(weights)
|
||||
|
||||
if config.word_embed_proj_dim != config.hidden_size:
|
||||
self.project_out = FastLinear.load(
|
||||
config, prefix="model.decoder.project_out", bias=False
|
||||
)
|
||||
else:
|
||||
self.project_out = None
|
||||
|
||||
if config.word_embed_proj_dim != config.hidden_size:
|
||||
self.project_in = FastLinear.load(
|
||||
config, prefix="model.decoder.project_in", bias=False
|
||||
)
|
||||
else:
|
||||
self.project_in = None
|
||||
|
||||
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
|
||||
# with checkpoints that have been fine-tuned before transformers v4.20.1
|
||||
# see https://github.com/facebookresearch/metaseq/pull/164
|
||||
if config.do_layer_norm_before and not config._remove_final_layer_norm:
|
||||
self.final_layer_norm = nn.LayerNorm.load(
|
||||
prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS
|
||||
)
|
||||
else:
|
||||
self.final_layer_norm = None
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
OPTDecoderLayer(layer_id, config, weights)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
):
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape,
|
||||
inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(
|
||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
).to(inputs_embeds.device)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask
|
||||
if combined_attention_mask is None
|
||||
else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
||||
provide it.
|
||||
|
||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
||||
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
||||
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
||||
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
||||
than the model's internal embedding lookup matrix.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
elif input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
elif inputs_embeds is not None:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
past_key_values_length = (
|
||||
past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||
)
|
||||
# required mask seq length can be calculated via length of past
|
||||
mask_seq_length = past_key_values_length + seq_length
|
||||
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
batch_size, mask_seq_length, device=inputs_embeds.device
|
||||
)
|
||||
causal_attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
)
|
||||
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
|
||||
|
||||
if self.project_in is not None:
|
||||
inputs_embeds = self.project_in(inputs_embeds)
|
||||
|
||||
hidden_states = inputs_embeds + pos_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
# check if head_mask has a correct number of layers specified if desired
|
||||
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
|
||||
if attn_mask is not None:
|
||||
if attn_mask.size()[0] != (len(self.layers)):
|
||||
raise ValueError(
|
||||
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
||||
f" {head_mask.size()[0]}."
|
||||
)
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
dropout_probability = random.uniform(0, 1)
|
||||
if self.training and (dropout_probability < self.layerdrop):
|
||||
continue
|
||||
|
||||
past_key_value = (
|
||||
past_key_values[idx] if past_key_values is not None else None
|
||||
)
|
||||
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_attention_mask,
|
||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if self.final_layer_norm is not None:
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
if self.project_out is not None:
|
||||
hidden_states = self.project_out(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class OPTModel(OPTPreTrainedModel):
|
||||
def __init__(self, config: OPTConfig, weights):
|
||||
super().__init__(config)
|
||||
self.decoder = OPTDecoder(config, weights)
|
||||
# Initialize weights and apply final processing
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return decoder_outputs
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=decoder_outputs.last_hidden_state,
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
hidden_states=decoder_outputs.hidden_states,
|
||||
attentions=decoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
class OPTForCausalLM(OPTPreTrainedModel):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__(config)
|
||||
|
||||
self.model = OPTModel(config, weights)
|
||||
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config, prefix="model.decoder.embed_tokens", weights=weights
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model.decoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
logits = self.lm_head(outputs[0]).contiguous()
|
||||
|
||||
loss = None
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
inputs_embeds=None,
|
||||
**kwargs,
|
||||
):
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values, beam_idx):
|
||||
reordered_past = ()
|
||||
for layer_past in past_key_values:
|
||||
reordered_past += (
|
||||
tuple(
|
||||
past_state.index_select(0, beam_idx) for past_state in layer_past
|
||||
),
|
||||
)
|
||||
return reordered_past
|
File diff suppressed because it is too large
Load Diff
|
@ -1,154 +1,25 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from opentelemetry import trace
|
||||
from pathlib import Path
|
||||
from safetensors import safe_open
|
||||
from transformers import AutoConfig
|
||||
from transformers.models.llama import LlamaTokenizer
|
||||
from typing import Optional, List
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
FlashLlamaForCausalLM,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
download_weights,
|
||||
weight_hub_files,
|
||||
LocalEntryNotFoundError,
|
||||
Weights,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashLlama(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
raise NotImplementedError("FlashLlama is only available on GPU")
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
# We do not use from_pretrained as we modified the model internal module layout
|
||||
try:
|
||||
filenames = weight_files(model_id, revision, ".bin")
|
||||
# Local files not found
|
||||
except LocalEntryNotFoundError:
|
||||
hub_files = weight_hub_files(model_id, revision, ".bin")
|
||||
filenames = download_weights(hub_files, model_id, revision)
|
||||
|
||||
with init_empty_weights():
|
||||
model = FlashLlamaForCausalLM(config)
|
||||
|
||||
self.load_weights(model, filenames, quantize, device, dtype)
|
||||
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[Path],
|
||||
quantize: Optional[str],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
for filename in filenames:
|
||||
state_dict = torch.load(filename, map_location="cpu")
|
||||
for key, value in state_dict.items():
|
||||
value = value.to(device if quantize is None else "cpu").to(dtype)
|
||||
|
||||
layer_name = ".".join(key.split(".")[:4])
|
||||
|
||||
# Fused qkv
|
||||
if "q_proj" in key or "k_proj" in key or "v_proj" in key:
|
||||
final_key = layer_name + ".query_key_value.weight"
|
||||
|
||||
# Fused gate and up projs
|
||||
elif "gate_proj" in key or "up_proj" in key:
|
||||
final_key = layer_name + ".gate_up_proj.weight"
|
||||
else:
|
||||
final_key = key
|
||||
|
||||
module_name, param_name = final_key.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
try:
|
||||
current_parameter_tensor = module._parameters[param_name]
|
||||
except KeyError:
|
||||
current_parameter_tensor = None
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
if current_parameter_tensor.device == torch.device("meta"):
|
||||
# Init qkv
|
||||
if "query_key_value" in final_key:
|
||||
module._parameters[param_name] = value.new_empty(
|
||||
(value.shape[0] * 3, value.shape[1])
|
||||
)
|
||||
# Init gate and up proj
|
||||
elif "gate_up_proj" in final_key:
|
||||
module._parameters[param_name] = value.new_empty(
|
||||
(value.shape[0] * 2, value.shape[1])
|
||||
)
|
||||
|
||||
# Copy to correct slice
|
||||
if "q_proj" in key:
|
||||
module._parameters[param_name][: value.shape[0]] = value
|
||||
elif "k_proj" in key:
|
||||
module._parameters[param_name][
|
||||
value.shape[0] : value.shape[0] * 2
|
||||
] = value
|
||||
elif "v_proj" in key:
|
||||
module._parameters[param_name][value.shape[0] * 2 :] = value
|
||||
elif "gate_proj" in key:
|
||||
module._parameters[param_name][: value.shape[0]] = value
|
||||
elif "up_proj" in key:
|
||||
module._parameters[param_name][value.shape[0] :] = value
|
||||
else:
|
||||
if current_parameter_tensor.shape != value.shape:
|
||||
raise ValueError(
|
||||
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
|
||||
)
|
||||
module._parameters[param_name] = value
|
||||
else:
|
||||
module._buffers[param_name] = value
|
||||
|
||||
del value
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
||||
|
||||
class FlashLlamaSharded(FlashLlama):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -176,24 +47,16 @@ class FlashLlamaSharded(FlashLlama):
|
|||
)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
|
||||
with init_empty_weights():
|
||||
model = FlashLlamaForCausalLM(config, process_group=self.process_group)
|
||||
config.quantize = quantize
|
||||
model = FlashLlamaForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model=model.to(device),
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
|
@ -201,114 +64,3 @@ class FlashLlamaSharded(FlashLlama):
|
|||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: Optional[str],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
slice_ = f.get_slice(name)
|
||||
|
||||
layer_name = ".".join(name.split(".")[:4])
|
||||
|
||||
# Fused qkv
|
||||
if "q_proj" in name or "k_proj" in name or "v_proj" in name:
|
||||
final_name = layer_name + ".query_key_value.weight"
|
||||
|
||||
# Fused gate and up projs
|
||||
elif "gate_proj" in name or "up_proj" in name:
|
||||
final_name = layer_name + ".gate_up_proj.weight"
|
||||
else:
|
||||
final_name = name
|
||||
|
||||
module_name, param_name = final_name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
if isinstance(module, TensorParallelColumnLinear):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif isinstance(module, TensorParallelRowLinear):
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[:, start:stop]
|
||||
elif isinstance(module, TensorParallelEmbedding):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif name == "lm_head.weight" and model.model.tp_embeddings:
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
else:
|
||||
try:
|
||||
tensor = slice_[:]
|
||||
except:
|
||||
tensor = f.get_tensor(name)
|
||||
|
||||
tensor = tensor.contiguous().to(dtype)
|
||||
|
||||
try:
|
||||
current_parameter_tensor = module._parameters[param_name]
|
||||
except KeyError:
|
||||
current_parameter_tensor = None
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
if current_parameter_tensor.device == torch.device("meta"):
|
||||
# Init qkv
|
||||
if "query_key_value" in final_name:
|
||||
module._parameters[param_name] = tensor.new_empty(
|
||||
(tensor.shape[0] * 3, tensor.shape[1])
|
||||
)
|
||||
# Init gate and up proj
|
||||
elif "gate_up_proj" in final_name:
|
||||
module._parameters[param_name] = tensor.new_empty(
|
||||
(tensor.shape[0] * 2, tensor.shape[1])
|
||||
)
|
||||
|
||||
# Init gate and up proj
|
||||
if "q_proj" in name:
|
||||
module._parameters[param_name][: tensor.shape[0]] = tensor
|
||||
elif "k_proj" in name:
|
||||
module._parameters[param_name][
|
||||
tensor.shape[0] : tensor.shape[0] * 2
|
||||
] = tensor
|
||||
elif "v_proj" in name:
|
||||
module._parameters[param_name][
|
||||
tensor.shape[0] * 2 :
|
||||
] = tensor
|
||||
elif "gate_proj" in name:
|
||||
module._parameters[param_name][: tensor.shape[0]] = tensor
|
||||
elif "up_proj" in name:
|
||||
module._parameters[param_name][tensor.shape[0] :] = tensor
|
||||
else:
|
||||
if current_parameter_tensor.shape != tensor.shape:
|
||||
raise ValueError(
|
||||
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
module._parameters[param_name] = tensor
|
||||
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
|
|
@ -1,45 +1,24 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from opentelemetry import trace
|
||||
from safetensors import safe_open
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional, List
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
||||
FlashGPTNeoXForCausalLM,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashNeoX(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
super(FlashNeoX, self).__init__(
|
||||
FlashGPTNeoXForCausalLM,
|
||||
model_id,
|
||||
revision,
|
||||
quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
class FlashNeoXSharded(FlashNeoX):
|
||||
class FlashNeoXSharded(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -65,23 +44,16 @@ class FlashNeoXSharded(FlashNeoX):
|
|||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
|
||||
with init_empty_weights():
|
||||
model = FlashGPTNeoXForCausalLM(config, self.process_group)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
)
|
||||
|
||||
model = FlashGPTNeoXForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model=model.to(device),
|
||||
|
@ -92,79 +64,3 @@ class FlashNeoXSharded(FlashNeoX):
|
|||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: Optional[str],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
parameters = dict(model.named_parameters())
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
module_name, param_name = name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
current_parameter_tensor = parameters.get(name, None)
|
||||
|
||||
slice_ = f.get_slice(name)
|
||||
|
||||
if isinstance(module, TensorParallelColumnLinear):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif isinstance(module, TensorParallelRowLinear):
|
||||
if param_name == "weight":
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[:, start:stop]
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
# XXX: Hack for Rowlinear to add the bias only once.
|
||||
if rank != 0:
|
||||
tensor = torch.zeros_like(tensor)
|
||||
elif isinstance(module, TensorParallelEmbedding):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
else:
|
||||
try:
|
||||
tensor = slice_[:]
|
||||
except:
|
||||
tensor = f.get_tensor(name)
|
||||
|
||||
if (
|
||||
current_parameter_tensor is not None
|
||||
and current_parameter_tensor.shape != tensor.shape
|
||||
):
|
||||
raise ValueError(
|
||||
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
tensor = tensor.contiguous().to(dtype)
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
module._parameters[param_name] = tensor
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
model.post_load_weights(quantize)
|
||||
|
|
|
@ -1,119 +1,25 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from pathlib import Path
|
||||
from accelerate import init_empty_weights
|
||||
from opentelemetry import trace
|
||||
from safetensors import safe_open
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional, List
|
||||
from transformers import AutoTokenizer
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
||||
RWConfig,
|
||||
FlashRWForCausalLM,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
download_weights,
|
||||
weight_hub_files,
|
||||
LocalEntryNotFoundError,
|
||||
Weights,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashRW(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
raise NotImplementedError("RW is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = RWConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
# We do not use from_pretrained as it is too slow
|
||||
try:
|
||||
filenames = weight_files(model_id, revision, ".bin")
|
||||
# Local files not found
|
||||
except LocalEntryNotFoundError:
|
||||
hub_files = weight_hub_files(model_id, revision, ".bin")
|
||||
filenames = download_weights(hub_files, model_id, revision)
|
||||
|
||||
with init_empty_weights():
|
||||
model = FlashRWForCausalLM(config)
|
||||
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize,
|
||||
device,
|
||||
dtype,
|
||||
)
|
||||
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model: FlashRWForCausalLM,
|
||||
filenames: List[Path],
|
||||
quantize: Optional[str],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
for filename in filenames:
|
||||
state_dict = torch.load(filename, map_location="cpu")
|
||||
for key, value in state_dict.items():
|
||||
value = value.to(device if quantize is None else "cpu").to(dtype)
|
||||
|
||||
module_name, param_name = key.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
try:
|
||||
current_parameter_tensor = module._parameters[param_name]
|
||||
if current_parameter_tensor.shape != value.shape:
|
||||
raise ValueError(
|
||||
f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
|
||||
)
|
||||
module._parameters[param_name] = value
|
||||
except KeyError:
|
||||
module._buffers[param_name] = value
|
||||
|
||||
del value
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
||||
|
||||
class FlashRWSharded(FlashRW):
|
||||
class FlashRWSharded(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -142,20 +48,12 @@ class FlashRWSharded(FlashRW):
|
|||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
|
||||
with init_empty_weights():
|
||||
model = FlashRWForCausalLM(config, self.process_group)
|
||||
config.quantize = quantize
|
||||
|
||||
model = FlashRWForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model=model.to(device),
|
||||
|
@ -166,79 +64,3 @@ class FlashRWSharded(FlashRW):
|
|||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: Optional[str],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
parameters = dict(model.named_parameters())
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
module_name, param_name = name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
current_parameter_tensor = parameters.get(name, None)
|
||||
|
||||
slice_ = f.get_slice(name)
|
||||
|
||||
if isinstance(module, TensorParallelColumnLinear):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif isinstance(module, TensorParallelRowLinear):
|
||||
if param_name == "weight":
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[:, start:stop]
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
# XXX: Hack for Rowlinear to add the bias only once.
|
||||
if rank != 0:
|
||||
tensor = torch.zeros_like(tensor)
|
||||
elif isinstance(module, TensorParallelEmbedding):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif name == "lm_head.weight" and model.transformer.tp_embeddings:
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
else:
|
||||
try:
|
||||
tensor = slice_[:]
|
||||
except:
|
||||
tensor = f.get_tensor(name)
|
||||
|
||||
if (
|
||||
current_parameter_tensor is not None
|
||||
and current_parameter_tensor.shape != tensor.shape
|
||||
):
|
||||
raise ValueError(
|
||||
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
tensor = tensor.contiguous().to(dtype)
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
module._parameters[param_name] = tensor
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
model.post_load_weights(quantize)
|
||||
|
|
|
@ -1,197 +1,24 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from opentelemetry import trace
|
||||
from safetensors import safe_open
|
||||
from pathlib import Path
|
||||
from transformers import AutoTokenizer, GPT2Config
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional, List
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
||||
FlashSantacoderForCausalLM,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
download_weights,
|
||||
weight_hub_files,
|
||||
LocalEntryNotFoundError,
|
||||
Weights,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashSantacoder(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
raise NotImplementedError("FlashSantacoder is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = GPT2Config.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
# We do not use from_pretrained as we modified the model internal module layout
|
||||
filenames = weight_files(model_id, revision, ".safetensors")
|
||||
|
||||
with init_empty_weights():
|
||||
model = FlashSantacoderForCausalLM(config)
|
||||
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize,
|
||||
device,
|
||||
dtype,
|
||||
config.architectures[0].startswith("GPT2"),
|
||||
)
|
||||
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model: FlashSantacoderForCausalLM,
|
||||
filenames: List[Path],
|
||||
quantize: Optional[str],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
transpose: bool,
|
||||
):
|
||||
for filename in filenames:
|
||||
with safe_open(
|
||||
filename,
|
||||
framework="pt",
|
||||
device=str(device) if quantize is None else "cpu",
|
||||
) as f:
|
||||
for key in f.keys():
|
||||
value = f.get_tensor(key)
|
||||
value = value.to(device if quantize is None else "cpu").to(dtype)
|
||||
|
||||
layer_name = ".".join(key.split(".")[:4])
|
||||
|
||||
# Fused qkv
|
||||
if "q_attn.weight" in key or "kv_attn.weight" in key:
|
||||
final_key = layer_name + ".c_attn.weight"
|
||||
elif "q_attn.bias" in key or "kv_attn.bias" in key:
|
||||
final_key = layer_name + ".c_attn.bias"
|
||||
|
||||
else:
|
||||
final_key = key
|
||||
|
||||
module_name, param_name = final_key.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
try:
|
||||
current_parameter_tensor = module._parameters[param_name]
|
||||
except KeyError:
|
||||
current_parameter_tensor = None
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
if transpose and (
|
||||
"c_fc.weight" in key
|
||||
or "c_proj.weight" in key
|
||||
or "q_attn.weight" in key
|
||||
or "kv_attn.weight" in key
|
||||
or "c_attn.weight" in key
|
||||
):
|
||||
# Tranpose as we use nn.Linear instead of Conv1D
|
||||
value = value.T
|
||||
|
||||
if current_parameter_tensor.device == torch.device("meta"):
|
||||
# Init qkv
|
||||
if "c_attn.weight" in final_key:
|
||||
module._parameters[param_name] = value.new_empty(
|
||||
(
|
||||
model.transformer.head_size
|
||||
* (model.transformer.num_heads + 2),
|
||||
value.shape[1],
|
||||
)
|
||||
)
|
||||
elif "c_attn.bias" in final_key:
|
||||
module._parameters[param_name] = value.new_empty(
|
||||
(
|
||||
model.transformer.head_size
|
||||
* (model.transformer.num_heads + 2)
|
||||
)
|
||||
)
|
||||
|
||||
# Copy to correct slice
|
||||
if "q_attn.weight" in key:
|
||||
module._parameters[param_name][: value.shape[0]] = value
|
||||
elif "q_attn.bias" in key:
|
||||
module._parameters[param_name][: value.shape[0]] = value
|
||||
elif "kv_attn.weight" in key:
|
||||
module._parameters[param_name][
|
||||
model.transformer.head_size
|
||||
* model.transformer.num_heads :
|
||||
] = value
|
||||
elif "kv_attn.bias" in key:
|
||||
module._parameters[param_name][
|
||||
model.transformer.head_size
|
||||
* model.transformer.num_heads :
|
||||
] = value
|
||||
else:
|
||||
if current_parameter_tensor.shape != value.shape:
|
||||
raise ValueError(
|
||||
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
|
||||
)
|
||||
module._parameters[param_name] = value
|
||||
else:
|
||||
module._buffers[param_name] = value
|
||||
|
||||
del value
|
||||
|
||||
if model.lm_head.weight.device == torch.device("meta"):
|
||||
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
||||
uninitialized_parameters = []
|
||||
for n, p in model.named_parameters():
|
||||
if p.data.device == torch.device("meta"):
|
||||
uninitialized_parameters.append(n)
|
||||
if uninitialized_parameters:
|
||||
raise RuntimeError(
|
||||
f"found uninitialized parameters in model : {uninitialized_parameters}"
|
||||
)
|
||||
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
||||
return self.tokenizer.decode(
|
||||
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
|
||||
)
|
||||
|
||||
|
||||
class FlashSantacoderSharded(FlashSantacoder):
|
||||
class FlashSantacoderSharded(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -214,28 +41,22 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = GPT2Config.from_pretrained(
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.transpose = config.architectures[0].startswith("GPT2")
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
|
||||
with init_empty_weights():
|
||||
model = FlashSantacoderForCausalLM(config, self.process_group)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
transpose=config.architectures[0].startswith("GPT2"),
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
)
|
||||
|
||||
model = FlashSantacoderForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model=model.to(device),
|
||||
|
@ -247,164 +68,8 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: Optional[str],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
transpose: bool,
|
||||
):
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||
) as f:
|
||||
for key in f.keys():
|
||||
slice_ = f.get_slice(key)
|
||||
|
||||
layer_name = ".".join(key.split(".")[:4])
|
||||
|
||||
# Fused qkv
|
||||
if "q_attn.weight" in key or "kv_attn.weight" in key:
|
||||
final_key = layer_name + ".c_attn.weight"
|
||||
elif "q_attn.bias" in key or "kv_attn.bias" in key:
|
||||
final_key = layer_name + ".c_attn.bias"
|
||||
else:
|
||||
final_key = key
|
||||
|
||||
module_name, param_name = final_key.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
if isinstance(module, TensorParallelColumnLinear):
|
||||
dim = 1 if transpose and "weight" in param_name else 0
|
||||
size = slice_.get_shape()[dim]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = (
|
||||
slice_[start:stop] if dim == 0 else slice_[:, start:stop]
|
||||
)
|
||||
elif isinstance(module, TensorParallelRowLinear):
|
||||
if param_name == "weight":
|
||||
dim = 0 if transpose else 1
|
||||
size = slice_.get_shape()[dim]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = (
|
||||
slice_[start:stop]
|
||||
if dim == 0
|
||||
else slice_[:, start:stop]
|
||||
)
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
# XXX: Hack for Rowlinear to add the bias only once.
|
||||
if rank != 0:
|
||||
tensor = torch.zeros_like(tensor)
|
||||
elif isinstance(module, TensorParallelEmbedding):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif key == "lm_head.weight" and model.transformer.tp_embeddings:
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
else:
|
||||
try:
|
||||
tensor = slice_[:]
|
||||
except:
|
||||
tensor = f.get_tensor(key)
|
||||
|
||||
tensor = tensor.contiguous().to(dtype)
|
||||
|
||||
try:
|
||||
current_parameter_tensor = module._parameters[param_name]
|
||||
except KeyError:
|
||||
current_parameter_tensor = None
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
if transpose and (
|
||||
"c_fc.weight" in key
|
||||
or "c_proj.weight" in key
|
||||
or "q_attn.weight" in key
|
||||
or "kv_attn.weight" in key
|
||||
or "c_attn.weight" in key
|
||||
):
|
||||
# Tranpose as we use nn.Linear instead of Conv1D
|
||||
tensor = tensor.T
|
||||
|
||||
if current_parameter_tensor.device == torch.device("meta"):
|
||||
# Init qkv
|
||||
if "c_attn.weight" in final_key:
|
||||
module._parameters[param_name] = tensor.new_empty(
|
||||
(
|
||||
model.transformer.head_size
|
||||
* (model.transformer.num_heads + 2),
|
||||
tensor.shape[1],
|
||||
)
|
||||
)
|
||||
elif "c_attn.bias" in final_key:
|
||||
module._parameters[param_name] = tensor.new_empty(
|
||||
(
|
||||
model.transformer.head_size
|
||||
* (model.transformer.num_heads + 2)
|
||||
)
|
||||
)
|
||||
|
||||
# Copy to correct slice
|
||||
if "q_attn" in key:
|
||||
size = tensor.shape[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = tensor[start:stop]
|
||||
module._parameters[param_name][: tensor.shape[0]] = tensor
|
||||
elif "kv_attn.weight" in key:
|
||||
module._parameters[param_name][
|
||||
model.transformer.head_size
|
||||
* model.transformer.num_heads :
|
||||
] = tensor
|
||||
elif "kv_attn.bias" in key:
|
||||
module._parameters[param_name][
|
||||
model.transformer.head_size
|
||||
* model.transformer.num_heads :
|
||||
] = tensor
|
||||
elif "c_attn" in key:
|
||||
# Slice q_tensor by shard
|
||||
q_tensor = tensor[: -2 * model.transformer.head_size]
|
||||
block_size = q_tensor.shape[0] // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
q_tensor = q_tensor[start:stop]
|
||||
|
||||
module._parameters[param_name][
|
||||
: q_tensor.shape[0]
|
||||
] = q_tensor
|
||||
|
||||
# Kv tensor is copied for every shard
|
||||
kv_tensor = tensor[-2 * model.transformer.head_size :]
|
||||
module._parameters[param_name][
|
||||
q_tensor.shape[0] :
|
||||
] = kv_tensor
|
||||
else:
|
||||
if current_parameter_tensor.shape != tensor.shape:
|
||||
raise ValueError(
|
||||
f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
module._parameters[param_name] = tensor
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
if model.lm_head.weight.device == torch.device("meta"):
|
||||
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
||||
return self.tokenizer.decode(
|
||||
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
|
||||
)
|
||||
|
|
|
@ -2,41 +2,25 @@ import re
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from typing import List, Optional, Type, Tuple
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors import safe_open
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
AutoConfig,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
from transformers.models.opt.parallel_layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.models.causal_lm import CausalLMBatch
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models.opt import OPT
|
||||
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
|
||||
from text_generation_server.utils import (
|
||||
NextTokenChooser,
|
||||
StoppingCriteria,
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Int8Params
|
||||
except Exception as e:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
||||
|
||||
# CREDIT: Papers with code => https://github.com/paperswithcode/galai/blob/main/galai/utils.py
|
||||
|
||||
# we split individual characters inside special tokens like [START_DNA]
|
||||
|
@ -168,33 +152,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
|
|||
)
|
||||
|
||||
|
||||
class Galactica(OPT):
|
||||
@property
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return GalacticaCausalLMBatch
|
||||
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
||||
return self.tokenizer.decode(
|
||||
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
"""Overwrite forward to ignore position_ids"""
|
||||
|
||||
# Model Forward
|
||||
outputs = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
||||
|
||||
class GalacticaSharded(Galactica):
|
||||
class GalacticaSharded(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -224,26 +182,17 @@ class GalacticaSharded(Galactica):
|
|||
tp_parallel=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.quantize = quantize
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
|
||||
with init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
)
|
||||
|
||||
model = OPTForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
|
@ -255,127 +204,15 @@ class GalacticaSharded(Galactica):
|
|||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: Optional[str],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
parameters = dict(model.named_parameters())
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
if name == "lm_head.weight":
|
||||
continue
|
||||
@property
|
||||
def batch_type(self) -> Type[CausalLMBatch]:
|
||||
return GalacticaCausalLMBatch
|
||||
|
||||
module_name, param_name = name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
current_tensor = parameters[name]
|
||||
|
||||
slice_ = f.get_slice(name)
|
||||
|
||||
if isinstance(module, TensorParallelColumnLinear):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif isinstance(module, TensorParallelRowLinear):
|
||||
if param_name == "weight":
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[:, start:stop]
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
# XXX: Hack for Rowlinear to add the bias only once.
|
||||
if rank != 0:
|
||||
tensor = torch.zeros_like(tensor)
|
||||
elif isinstance(module, TensorParallelEmbedding):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
|
||||
if current_tensor.shape != tensor.shape:
|
||||
raise ValueError(
|
||||
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
tensor = tensor.contiguous().to(dtype)
|
||||
|
||||
if quantize == "bitsandbytes":
|
||||
if not HAS_BITS_AND_BYTES:
|
||||
raise ImportError(
|
||||
"bitsandbytes is not available on your machine either because it is not installed "
|
||||
"or you don't have a GPU.\n"
|
||||
"You can install it with `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
if (
|
||||
type(module)
|
||||
in [TensorParallelRowLinear, TensorParallelColumnLinear]
|
||||
and param_name == "weight"
|
||||
):
|
||||
tensor = Int8Params(
|
||||
tensor,
|
||||
has_fp16_weights=False,
|
||||
requires_grad=False,
|
||||
).to(device)
|
||||
state = bnb.MatmulLtState()
|
||||
state.threshold = 6.0
|
||||
state.has_fp16_weights = False
|
||||
state.memory_efficient_backward = False
|
||||
state.use_pool = True
|
||||
state.CB = tensor.CB
|
||||
state.SCB = tensor.SCB
|
||||
tensor.CB = None
|
||||
tensor.SCB = None
|
||||
|
||||
def replace_linear(state):
|
||||
def linear(input, weight, bias):
|
||||
out = bnb.matmul(
|
||||
input,
|
||||
weight,
|
||||
state=state,
|
||||
threshold=state.threshold,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
if state.CB is not None:
|
||||
# we converted 8-bit row major to turing/ampere format
|
||||
# in the first inference pass
|
||||
# we no longer need the row-major weight
|
||||
del state.CB
|
||||
weight.data = state.CxB
|
||||
|
||||
return out
|
||||
|
||||
return linear
|
||||
|
||||
module.linear = replace_linear(state)
|
||||
else:
|
||||
tensor = tensor.to(device)
|
||||
elif quantize == "gptq":
|
||||
raise NotImplementedError("`gptq` is not implemented for now")
|
||||
elif quantize is None:
|
||||
tensor = tensor.to(device)
|
||||
else:
|
||||
raise ValueError(f"Unexpected quantize `{quantize}`")
|
||||
|
||||
module._parameters[param_name] = tensor
|
||||
if name == "model.decoder.embed_tokens.weight":
|
||||
model.lm_head._parameters["weight"] = tensor
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
||||
return self.tokenizer.decode(
|
||||
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
|
@ -386,10 +223,4 @@ class GalacticaSharded(Galactica):
|
|||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
# Logits are sharded, so we need to gather them
|
||||
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
|
||||
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
|
||||
logits = torch.cat(logits, dim=2)
|
||||
|
||||
return logits, outputs.past_key_values
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
|
|
@ -1,34 +1,22 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors import safe_open
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
AutoConfig,
|
||||
)
|
||||
from transformers.models.gpt_neox.parallel_layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.models.custom_modeling.neox_modeling import (
|
||||
GPTNeoxForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Int8Params
|
||||
except Exception as e:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
||||
|
||||
class GPTNeoxSharded(CausalLM):
|
||||
def __init__(
|
||||
|
@ -58,28 +46,18 @@ class GPTNeoxSharded(CausalLM):
|
|||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
tp_parallel=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.quantize = quantize
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
|
||||
with init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
)
|
||||
|
||||
model = GPTNeoxForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
|
@ -91,161 +69,16 @@ class GPTNeoxSharded(CausalLM):
|
|||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: Optional[str],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
parameters = dict(model.named_parameters())
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
module_name, param_name = name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
current_parameter_tensor = parameters.get(name, None)
|
||||
|
||||
slice_ = f.get_slice(name)
|
||||
|
||||
if isinstance(module, TensorParallelColumnLinear):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif isinstance(module, TensorParallelRowLinear):
|
||||
if param_name == "weight":
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[:, start:stop]
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
# XXX: Hack for Rowlinear to add the bias only once.
|
||||
if rank != 0:
|
||||
tensor = torch.zeros_like(tensor)
|
||||
elif isinstance(module, TensorParallelEmbedding):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
else:
|
||||
try:
|
||||
tensor = slice_[:]
|
||||
except:
|
||||
tensor = f.get_tensor(name)
|
||||
|
||||
if (
|
||||
current_parameter_tensor is not None
|
||||
and current_parameter_tensor.shape != tensor.shape
|
||||
):
|
||||
raise ValueError(
|
||||
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
tensor = tensor.contiguous().to(dtype)
|
||||
|
||||
if quantize == "bitsandbytes":
|
||||
if not HAS_BITS_AND_BYTES:
|
||||
raise ImportError(
|
||||
"bitsandbytes is not available on your machine either because it is not installed "
|
||||
"or you don't have a GPU.\n"
|
||||
"You can install it with `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
if (
|
||||
type(module)
|
||||
in [TensorParallelRowLinear, TensorParallelColumnLinear]
|
||||
and param_name == "weight"
|
||||
):
|
||||
tensor = Int8Params(
|
||||
tensor,
|
||||
has_fp16_weights=False,
|
||||
requires_grad=False,
|
||||
).to(device)
|
||||
state = bnb.MatmulLtState()
|
||||
state.threshold = 6.0
|
||||
state.has_fp16_weights = False
|
||||
state.memory_efficient_backward = False
|
||||
state.use_pool = True
|
||||
state.CB = tensor.CB
|
||||
state.SCB = tensor.SCB
|
||||
tensor.CB = None
|
||||
tensor.SCB = None
|
||||
|
||||
def replace_linear(state):
|
||||
def linear(input, weight, bias):
|
||||
out = bnb.matmul(
|
||||
input,
|
||||
weight,
|
||||
state=state,
|
||||
threshold=state.threshold,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
if state.CB is not None:
|
||||
# we converted 8-bit row major to turing/ampere format
|
||||
# in the first inference pass
|
||||
# we no longer need the row-major weight
|
||||
del state.CB
|
||||
weight.data = state.CxB
|
||||
|
||||
return out
|
||||
|
||||
return linear
|
||||
|
||||
module.linear = replace_linear(state)
|
||||
else:
|
||||
tensor = tensor.to(device)
|
||||
elif quantize == "gptq":
|
||||
raise NotImplementedError("`gptq` is not implemented for now")
|
||||
elif quantize is None:
|
||||
tensor = tensor.to(device)
|
||||
else:
|
||||
raise ValueError(f"Unexpected quantize `{quantize}`")
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
module._parameters[param_name] = tensor
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
if self.model.gpt_neox.tp_embeddings:
|
||||
outputs = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
outputs = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
# Logits are sharded, so we need to gather them
|
||||
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
|
||||
torch.distributed.all_gather(
|
||||
logits, outputs.logits, group=self.process_group
|
||||
)
|
||||
logits = torch.cat(logits, dim=2)
|
||||
|
||||
return logits, outputs.past_key_values
|
||||
# While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
|
||||
else:
|
||||
return super(GPTNeoxSharded, self).forward(
|
||||
input_ids, attention_mask, position_ids, past_key_values
|
||||
)
|
||||
logits = outputs.logits
|
||||
return logits, outputs.past_key_values
|
||||
|
|
|
@ -1,52 +1,22 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors import safe_open
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
AutoConfig,
|
||||
)
|
||||
from transformers.models.opt.parallel_layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
|
||||
from text_generation_server.models import CausalLM
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Int8Params
|
||||
except Exception as e:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
||||
|
||||
class OPT(CausalLM):
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
"""Overwrite forward to ignore position_ids"""
|
||||
|
||||
# Model Forward
|
||||
outputs = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
||||
|
||||
class OPTSharded(OPT):
|
||||
class OPTSharded(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -73,29 +43,19 @@ class OPTSharded(OPT):
|
|||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
tp_parallel=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.quantize = quantize
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
|
||||
with init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
)
|
||||
|
||||
model = OPTForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
|
@ -107,128 +67,6 @@ class OPTSharded(OPT):
|
|||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: Optional[str],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
parameters = dict(model.named_parameters())
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
if name == "lm_head.weight":
|
||||
continue
|
||||
|
||||
module_name, param_name = name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
current_tensor = parameters[name]
|
||||
|
||||
slice_ = f.get_slice(name)
|
||||
|
||||
if isinstance(module, TensorParallelColumnLinear):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif isinstance(module, TensorParallelRowLinear):
|
||||
if param_name == "weight":
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[:, start:stop]
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
# XXX: Hack for Rowlinear to add the bias only once.
|
||||
if rank != 0:
|
||||
tensor = torch.zeros_like(tensor)
|
||||
elif isinstance(module, TensorParallelEmbedding):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
|
||||
if current_tensor.shape != tensor.shape:
|
||||
raise ValueError(
|
||||
f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
tensor = tensor.contiguous().to(dtype)
|
||||
|
||||
if quantize == "bitsandbytes":
|
||||
if not HAS_BITS_AND_BYTES:
|
||||
raise ImportError(
|
||||
"bitsandbytes is not available on your machine either because it is not installed "
|
||||
"or you don't have a GPU.\n"
|
||||
"You can install it with `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
if (
|
||||
type(module)
|
||||
in [TensorParallelRowLinear, TensorParallelColumnLinear]
|
||||
and param_name == "weight"
|
||||
):
|
||||
tensor = Int8Params(
|
||||
tensor,
|
||||
has_fp16_weights=False,
|
||||
requires_grad=False,
|
||||
).to(device)
|
||||
state = bnb.MatmulLtState()
|
||||
state.threshold = 6.0
|
||||
state.has_fp16_weights = False
|
||||
state.memory_efficient_backward = False
|
||||
state.use_pool = True
|
||||
state.CB = tensor.CB
|
||||
state.SCB = tensor.SCB
|
||||
tensor.CB = None
|
||||
tensor.SCB = None
|
||||
|
||||
def replace_linear(state):
|
||||
def linear(input, weight, bias):
|
||||
out = bnb.matmul(
|
||||
input,
|
||||
weight,
|
||||
state=state,
|
||||
threshold=state.threshold,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
if state.CB is not None:
|
||||
# we converted 8-bit row major to turing/ampere format
|
||||
# in the first inference pass
|
||||
# we no longer need the row-major weight
|
||||
del state.CB
|
||||
weight.data = state.CxB
|
||||
|
||||
return out
|
||||
|
||||
return linear
|
||||
|
||||
module.linear = replace_linear(state)
|
||||
else:
|
||||
tensor = tensor.to(device)
|
||||
elif quantize == "gptq":
|
||||
raise NotImplementedError("`gptq` is not implemented for now")
|
||||
elif quantize is None:
|
||||
tensor = tensor.to(device)
|
||||
else:
|
||||
raise ValueError(f"Unexpected quantize `{quantize}`")
|
||||
|
||||
module._parameters[param_name] = tensor
|
||||
if name == "model.decoder.embed_tokens.weight":
|
||||
model.lm_head._parameters["weight"] = tensor
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
):
|
||||
|
@ -239,9 +77,4 @@ class OPTSharded(OPT):
|
|||
use_cache=True,
|
||||
)
|
||||
|
||||
# Logits are sharded, so we need to gather them
|
||||
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
|
||||
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
|
||||
logits = torch.cat(logits, dim=2)
|
||||
|
||||
return logits, outputs.past_key_values
|
||||
return outputs.logits, outputs.past_key_values
|
||||
|
|
|
@ -3,31 +3,20 @@ import torch.distributed
|
|||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors import safe_open
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoConfig,
|
||||
)
|
||||
|
||||
from text_generation_server.models import Seq2SeqLM
|
||||
from text_generation_server.models.custom_modeling.t5_modeling import (
|
||||
T5ForConditionalGeneration,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from transformers.models.t5.parallel_layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
)
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Int8Params
|
||||
except ImportError as e:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
||||
|
||||
class T5Sharded(Seq2SeqLM):
|
||||
|
@ -46,6 +35,13 @@ class T5Sharded(Seq2SeqLM):
|
|||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
config.quantize = quantize
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
|
@ -53,33 +49,16 @@ class T5Sharded(Seq2SeqLM):
|
|||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
tp_parallel=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
tokenizer.bos_token_id = config.decoder_start_token_id
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
|
||||
with init_empty_weights():
|
||||
model = AutoModelForSeq2SeqLM.from_config(
|
||||
config, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
)
|
||||
|
||||
model = T5ForConditionalGeneration(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(Seq2SeqLM, self).__init__(
|
||||
model=model,
|
||||
|
@ -91,151 +70,6 @@ class T5Sharded(Seq2SeqLM):
|
|||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: Optional[str],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
parameters = dict(model.named_parameters())
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
module_name, param_name = name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
current_parameter_tensor = parameters.get(name, None)
|
||||
|
||||
slice_ = f.get_slice(name)
|
||||
|
||||
if isinstance(module, TensorParallelColumnLinear):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif isinstance(module, TensorParallelRowLinear):
|
||||
if param_name == "weight":
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[:, start:stop]
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
# XXX: Hack for Rowlinear to add the bias only once.
|
||||
if rank != 0:
|
||||
tensor = torch.zeros_like(tensor)
|
||||
elif isinstance(module, TensorParallelEmbedding):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif name == "lm_head.weight":
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif "relative_attention_bias.weight" in name:
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[:, start:stop]
|
||||
else:
|
||||
try:
|
||||
tensor = slice_[:]
|
||||
except:
|
||||
tensor = f.get_tensor(name)
|
||||
|
||||
if (
|
||||
current_parameter_tensor is not None
|
||||
and current_parameter_tensor.shape != tensor.shape
|
||||
):
|
||||
raise ValueError(
|
||||
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
tensor = tensor.contiguous()
|
||||
|
||||
# See: https://github.com/huggingface/transformers/blob/1fe1e3caa44617047f149bcc0c0b566343b714a7/src/transformers/models/t5/modeling_t5.py#LL316C15-L316C71
|
||||
if module_name.endswith("wo"):
|
||||
tensor = tensor.to(torch.float32)
|
||||
else:
|
||||
tensor = tensor.to(dtype)
|
||||
|
||||
if quantize == "bitsandbytes" and not module_name.endswith("wo"):
|
||||
if not HAS_BITS_AND_BYTES:
|
||||
raise ImportError(
|
||||
"bitsandbytes is not available on your machine either because it is not installed "
|
||||
"or you don't have a GPU.\n"
|
||||
"You can install it with `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
if (
|
||||
type(module)
|
||||
in [TensorParallelRowLinear, TensorParallelColumnLinear]
|
||||
and param_name == "weight"
|
||||
):
|
||||
tensor = Int8Params(
|
||||
tensor,
|
||||
has_fp16_weights=False,
|
||||
requires_grad=False,
|
||||
).to(device)
|
||||
state = bnb.MatmulLtState()
|
||||
state.threshold = 6.0
|
||||
state.has_fp16_weights = False
|
||||
state.memory_efficient_backward = False
|
||||
state.use_pool = True
|
||||
state.CB = tensor.CB
|
||||
state.SCB = tensor.SCB
|
||||
tensor.CB = None
|
||||
tensor.SCB = None
|
||||
|
||||
def replace_linear(state):
|
||||
def linear(input, weight, bias):
|
||||
out = bnb.matmul(
|
||||
input,
|
||||
weight,
|
||||
state=state,
|
||||
threshold=state.threshold,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
if state.CB is not None:
|
||||
# we converted 8-bit row major to turing/ampere format
|
||||
# in the first inference pass
|
||||
# we no longer need the row-major weight
|
||||
del state.CB
|
||||
weight.data = state.CxB
|
||||
|
||||
return out
|
||||
|
||||
return linear
|
||||
|
||||
module.linear = replace_linear(state)
|
||||
else:
|
||||
tensor = tensor.to(device)
|
||||
elif quantize == "gptq" and not module_name.endswith("wo"):
|
||||
raise NotImplementedError("`gptq` is not implemented for now")
|
||||
elif quantize is None or module_name.endswith("wo"):
|
||||
tensor = tensor.to(device)
|
||||
else:
|
||||
raise ValueError(f"Unexpected quantize `{quantize}`")
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
module._parameters[param_name] = tensor
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
|
@ -260,13 +94,8 @@ class T5Sharded(Seq2SeqLM):
|
|||
use_cache=True,
|
||||
)
|
||||
|
||||
# Logits are sharded, so we need to gather them
|
||||
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
|
||||
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
|
||||
logits = torch.cat(logits, dim=2)
|
||||
|
||||
return (
|
||||
logits,
|
||||
outputs.logits,
|
||||
outputs.encoder_last_hidden_state,
|
||||
outputs.past_key_values,
|
||||
)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from text_generation_server.utils.convert import convert_file, convert_files
|
||||
from text_generation_server.utils.dist import initialize_torch_distributed
|
||||
from text_generation_server.utils.weights import Weights
|
||||
from text_generation_server.utils.hub import (
|
||||
weight_files,
|
||||
weight_hub_files,
|
||||
|
@ -35,4 +36,5 @@ __all__ = [
|
|||
"StoppingCriteria",
|
||||
"StopSequenceCriteria",
|
||||
"FinishReason",
|
||||
"Weights",
|
||||
]
|
||||
|
|
|
@ -4,6 +4,37 @@ import torch
|
|||
from datetime import timedelta
|
||||
|
||||
|
||||
class FakeBarrier:
|
||||
def wait(self):
|
||||
pass
|
||||
|
||||
|
||||
class FakeGroup:
|
||||
def __init__(self, rank, size):
|
||||
self._rank = rank
|
||||
self._size = size
|
||||
|
||||
def allreduce(self, *args, **kwargs):
|
||||
return FakeBarrier()
|
||||
|
||||
def allgather(self, inputs, local_tensor, **kwargs):
|
||||
assert (
|
||||
len(inputs[0]) == len(local_tensor) == 1
|
||||
), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
|
||||
for input_ in inputs:
|
||||
input_[0].data = local_tensor[0].data
|
||||
return FakeBarrier()
|
||||
|
||||
def barrier(self, *args, **kwargs):
|
||||
return FakeBarrier()
|
||||
|
||||
def size(self):
|
||||
return self._size
|
||||
|
||||
def rank(self):
|
||||
return self._rank
|
||||
|
||||
|
||||
def initialize_torch_distributed():
|
||||
rank = int(os.getenv("RANK", "0"))
|
||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
|
@ -23,13 +54,18 @@ def initialize_torch_distributed():
|
|||
backend = "gloo"
|
||||
options = None
|
||||
|
||||
# Call the init process.
|
||||
torch.distributed.init_process_group(
|
||||
backend=backend,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
timeout=timedelta(seconds=60),
|
||||
pg_options=options,
|
||||
)
|
||||
if world_size == 1:
|
||||
return FakeGroup(rank, world_size), rank, world_size
|
||||
else:
|
||||
if os.getenv("DEBUG", None) == "1":
|
||||
return FakeGroup(rank, world_size), rank, world_size
|
||||
# Call the init process.
|
||||
torch.distributed.init_process_group(
|
||||
backend=backend,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
timeout=timedelta(seconds=60),
|
||||
pg_options=options,
|
||||
)
|
||||
|
||||
return torch.distributed.group.WORLD, rank, world_size
|
||||
return torch.distributed.group.WORLD, rank, world_size
|
||||
|
|
|
@ -1,176 +1,240 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from typing import Optional
|
||||
from typing import List
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
from bitsandbytes.nn import Linear8bitLt
|
||||
except ImportError as e:
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import Int8Params
|
||||
|
||||
except ImportError:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
class FastLinear(nn.Linear):
|
||||
|
||||
# Monkey patching
|
||||
@classmethod
|
||||
def load_layer_norm(cls, prefix, weights, eps):
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
with init_empty_weights():
|
||||
ln = cls(weight.shape, eps=eps)
|
||||
|
||||
ln.weight = nn.Parameter(weight)
|
||||
ln.bias = nn.Parameter(bias)
|
||||
return ln
|
||||
|
||||
|
||||
torch.nn.LayerNorm.load = load_layer_norm
|
||||
|
||||
|
||||
class FastLinear(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
weight,
|
||||
bias,
|
||||
) -> None:
|
||||
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
|
||||
self.quantized = False
|
||||
self.bnb_linear = None
|
||||
|
||||
def prepare_weights(self, quantize: Optional[str] = None):
|
||||
if quantize == "bitsandbytes":
|
||||
if not HAS_BITS_AND_BYTES:
|
||||
raise ImportError(
|
||||
"bitsandbytes is not available on your machine either because it is not installed "
|
||||
"or you don't have a GPU.\n"
|
||||
"You can install it with `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
self.quantized = True
|
||||
self.bnb_linear = Linear8bitLt(
|
||||
self.in_features,
|
||||
self.out_features,
|
||||
has_fp16_weights=False,
|
||||
threshold=6.0,
|
||||
bias=False,
|
||||
)
|
||||
# Copy data to bnb_linear
|
||||
self.bnb_linear.weight.data = self.weight.data
|
||||
if self.bias is not None:
|
||||
self.bnb_linear.bias = nn.Parameter(self.bias)
|
||||
|
||||
# Delete reference to data
|
||||
self.weight = None
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(weight)
|
||||
if bias is not None:
|
||||
self.bias = nn.Parameter(bias)
|
||||
else:
|
||||
self.bias = None
|
||||
elif quantize == "gptq":
|
||||
raise NotImplementedError("`gptq` is not implemented for now")
|
||||
elif quantize is None:
|
||||
self.weight = nn.Parameter(self.weight.T)
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
if bias:
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
raise ValueError(f"Unexpected quantize `{quantize}`")
|
||||
bias = None
|
||||
return cls(weight, bias)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if self.quantized:
|
||||
return self.bnb_linear(input)
|
||||
else:
|
||||
if self.bias is not None:
|
||||
return torch.addmm(self.bias, input, self.weight)
|
||||
return torch.matmul(input, self.weight)
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
|
||||
|
||||
class TensorParallelColumnLinear(FastLinear):
|
||||
class Linear8bitLt(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
process_group: torch.distributed.ProcessGroup,
|
||||
bias=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
weight,
|
||||
bias,
|
||||
has_fp16_weights=True,
|
||||
memory_efficient_backward=False,
|
||||
threshold=0.0,
|
||||
index=None,
|
||||
):
|
||||
self.process_group = process_group
|
||||
self.tp_world_size = process_group.size()
|
||||
assert out_features % self.tp_world_size == 0
|
||||
out_features = out_features // self.tp_world_size
|
||||
super().__init__()
|
||||
assert (
|
||||
not memory_efficient_backward
|
||||
), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
|
||||
self.state = bnb.MatmulLtState()
|
||||
self.index = index
|
||||
|
||||
super().__init__(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
# Necessary for stacked layers
|
||||
self.state.threshold = threshold
|
||||
self.state.has_fp16_weights = has_fp16_weights
|
||||
self.state.memory_efficient_backward = memory_efficient_backward
|
||||
if threshold > 0.0 and not has_fp16_weights:
|
||||
self.state.use_pool = True
|
||||
|
||||
self.weight = Int8Params(
|
||||
weight.data,
|
||||
has_fp16_weights=has_fp16_weights,
|
||||
requires_grad=has_fp16_weights,
|
||||
)
|
||||
self.weight.cuda(weight.device)
|
||||
self.bias = bias
|
||||
|
||||
def init_8bit_state(self):
|
||||
self.state.CB = self.weight.CB
|
||||
self.state.SCB = self.weight.SCB
|
||||
self.weight.CB = None
|
||||
self.weight.SCB = None
|
||||
|
||||
class TensorParallelRowLinear(FastLinear):
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
process_group: torch.distributed.ProcessGroup,
|
||||
reduce=True,
|
||||
bias=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
self.process_group = process_group
|
||||
self.tp_world_size = process_group.size()
|
||||
self.reduce = reduce
|
||||
assert in_features % self.tp_world_size == 0
|
||||
in_features = in_features // self.tp_world_size
|
||||
def forward(self, x: torch.Tensor):
|
||||
self.state.is_training = self.training
|
||||
if self.weight.CB is not None:
|
||||
self.init_8bit_state()
|
||||
|
||||
super().__init__(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
out = super(TensorParallelRowLinear, self).forward(input)
|
||||
if self.reduce:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
|
||||
|
||||
if not self.state.has_fp16_weights:
|
||||
if self.state.CB is not None and self.state.CxB is not None:
|
||||
# we converted 8-bit row major to turing/ampere format in the first inference pass
|
||||
# we no longer need the row-major weight
|
||||
del self.state.CB
|
||||
self.weight.data = self.state.CxB
|
||||
return out
|
||||
|
||||
|
||||
class TensorParallelEmbedding(nn.Embedding):
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
process_group: torch.distributed.ProcessGroup,
|
||||
reduce=True,
|
||||
padding_idx=None,
|
||||
max_norm=None,
|
||||
norm_type=2.0,
|
||||
scale_grad_by_freq=False,
|
||||
sparse=False,
|
||||
_weight=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
self.reduce = reduce
|
||||
def get_linear(weight, bias, quantize):
|
||||
if quantize is None:
|
||||
linear = FastLinear(weight, bias)
|
||||
elif quantize == "bitsandbytes":
|
||||
linear = Linear8bitLt(
|
||||
weight,
|
||||
bias,
|
||||
has_fp16_weights=False,
|
||||
threshold=6.0,
|
||||
)
|
||||
if bias is not None:
|
||||
linear.bias = nn.Parameter(bias)
|
||||
elif quantize == "gptq":
|
||||
raise NotImplementedError("Soon")
|
||||
else:
|
||||
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
|
||||
return linear
|
||||
|
||||
|
||||
class SuperLayer(nn.Module):
|
||||
def __init__(self, linear):
|
||||
super().__init__()
|
||||
self.linear = linear
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear.forward(x)
|
||||
|
||||
|
||||
class TensorParallelHead(SuperLayer):
|
||||
def __init__(self, linear, process_group):
|
||||
super().__init__(linear)
|
||||
self.process_group = process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
|
||||
self.original_num_embeddings = num_embeddings
|
||||
|
||||
assert num_embeddings % self.tp_world_size == 0
|
||||
block_size = num_embeddings // self.tp_world_size
|
||||
# inputs in `[min_id, max_id[` are handled by `self` to get embeddings
|
||||
self.min_id = self.tp_rank * block_size
|
||||
self.max_id = (self.tp_rank + 1) * block_size
|
||||
|
||||
# Additional entry that will map to zero
|
||||
# Used for masking
|
||||
self.null_idx = block_size
|
||||
|
||||
super().__init__(
|
||||
block_size,
|
||||
embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
max_norm=max_norm,
|
||||
norm_type=norm_type,
|
||||
scale_grad_by_freq=scale_grad_by_freq,
|
||||
sparse=sparse,
|
||||
_weight=_weight,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
@staticmethod
|
||||
def load(config, prefix: str, weights):
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
return TensorParallelHead(
|
||||
get_linear(weight, bias=None, quantize=config.quantize),
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
def add_null_idx(self):
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
output = super().forward(input)
|
||||
# Logits are sharded, so we need to gather them
|
||||
world_output = [
|
||||
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||
]
|
||||
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
||||
world_output = torch.cat(world_output, dim=-1)
|
||||
return world_output
|
||||
|
||||
|
||||
class TensorParallelColumnLinear(SuperLayer):
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
if bias:
|
||||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||
else:
|
||||
bias = None
|
||||
return cls(get_linear(weight, bias, config.quantize))
|
||||
|
||||
@classmethod
|
||||
def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
|
||||
w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
|
||||
weight = torch.cat(w, dim=dim)
|
||||
|
||||
if bias:
|
||||
b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
|
||||
bias = torch.cat(b, dim=0)
|
||||
else:
|
||||
bias = None
|
||||
return cls(get_linear(weight, bias, config.quantize))
|
||||
|
||||
|
||||
class TensorParallelRowLinear(SuperLayer):
|
||||
def __init__(self, linear, process_group):
|
||||
super().__init__(linear)
|
||||
self.process_group = process_group
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
return cls(
|
||||
get_linear(weight, bias, config.quantize),
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
out = super().forward(input)
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
return out
|
||||
|
||||
|
||||
class TensorParallelEmbedding(nn.Module):
|
||||
def __init__(self, prefix: str, weights, reduce=True):
|
||||
super().__init__()
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||
num_embeddings = weights.get_shape(f"{prefix}.weight")[0]
|
||||
|
||||
process_group = weights.process_group
|
||||
|
||||
world_size = process_group.size()
|
||||
rank = process_group.rank()
|
||||
|
||||
block_size = num_embeddings // world_size
|
||||
self.min_id = rank * block_size
|
||||
self.max_id = min(num_embeddings, (rank + 1) * block_size)
|
||||
self.null_idx = block_size
|
||||
self.process_group = weights.process_group
|
||||
self.reduce = reduce
|
||||
|
||||
"""Additional 0 entry used for masking"""
|
||||
self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
|
||||
self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
# default all out of bounds values to `self.null_idx` that will then be mapped to 0
|
||||
|
@ -180,7 +244,7 @@ class TensorParallelEmbedding(nn.Embedding):
|
|||
self.null_idx,
|
||||
input - self.min_id,
|
||||
)
|
||||
out = super().forward(input)
|
||||
out = torch.nn.functional.embedding(input, self.weight)
|
||||
if self.reduce:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
return out
|
||||
|
@ -232,7 +296,34 @@ try:
|
|||
from flash_attn.layers.rotary import RotaryEmbedding
|
||||
import rotary_emb
|
||||
|
||||
class PositionRotaryEmbedding(RotaryEmbedding):
|
||||
class PositionRotaryEmbedding(nn.Module):
|
||||
def __init__(self, inv_freq):
|
||||
super().__init__()
|
||||
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._cos_k_cached = None
|
||||
self._sin_k_cached = None
|
||||
|
||||
@classmethod
|
||||
def static(cls, dim, base, device):
|
||||
inv_freq = 1.0 / (
|
||||
base
|
||||
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
|
||||
)
|
||||
return cls(inv_freq)
|
||||
|
||||
@classmethod
|
||||
def load(cls, prefix, weights):
|
||||
# XXX: Always load this in float32 !
|
||||
dtype = weights.dtype
|
||||
weights.dtype = torch.float32
|
||||
inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
|
||||
weights.dtype = dtype
|
||||
return cls(inv_freq)
|
||||
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
class Weights:
|
||||
def __init__(self, filenames: List[Path], device, dtype, process_group):
|
||||
routing = {}
|
||||
for filename in filenames:
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
if k in routing:
|
||||
raise RuntimeError(
|
||||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||
)
|
||||
routing[k] = filename
|
||||
self.routing = routing
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.process_group = process_group
|
||||
self._handles = {}
|
||||
|
||||
def _get_handle(self, filename):
|
||||
if filename not in self._handles:
|
||||
f = safe_open(filename, framework="pytorch")
|
||||
self._handles[filename] = f
|
||||
|
||||
return self._handles[filename]
|
||||
|
||||
def get_filename(self, tensor_name: str) -> str:
|
||||
filename = self.routing.get(tensor_name, None)
|
||||
if filename is None:
|
||||
raise RuntimeError(f"weight {tensor_name} does not exist")
|
||||
return str(filename)
|
||||
|
||||
def _get_slice(self, tensor_name: str):
|
||||
filename = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
slice_ = f.get_slice(tensor_name)
|
||||
return slice_
|
||||
|
||||
def get_shape(self, tensor_name: str):
|
||||
return self._get_slice(tensor_name).get_shape()
|
||||
|
||||
def get_tensor(self, tensor_name: str):
|
||||
filename = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
||||
|
||||
def get_sharded(self, tensor_name: str, dim: int):
|
||||
filename = self.get_filename(tensor_name)
|
||||
world_size = self.process_group.size()
|
||||
rank = self.process_group.rank()
|
||||
|
||||
f = self._get_handle(filename)
|
||||
slice_ = f.get_slice(tensor_name)
|
||||
size = slice_.get_shape()[dim]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
|
||||
assert (
|
||||
size % world_size == 0
|
||||
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
|
||||
|
||||
if dim == 0:
|
||||
tensor = slice_[start:stop]
|
||||
elif dim == 1:
|
||||
tensor = slice_[:, start:stop]
|
||||
else:
|
||||
raise NotImplementedError("Let's make that generic when needed")
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
tensor = tensor.to(device=self.device)
|
||||
return tensor
|
Loading…
Reference in New Issue