Add GPT-2 with flash attention (#1889)
# What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> This change adds `FlashGPT2ForCausalLM` and wires it up. The model itself is pretty straightforward, the main difference from other models is that it uses trained position embeddings and that all weight matrices are transposed compared to other models (due to the use of Conv1D in the upstream model). <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [x] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [x] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @Narsil <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
This commit is contained in:
parent
92f1338b84
commit
b5bc6e5c4e
|
@ -9,6 +9,7 @@ The following models are optimized and can be served with TGI, which uses custom
|
|||
- [BLOOM](https://huggingface.co/bigscience/bloom)
|
||||
- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl)
|
||||
- [Galactica](https://huggingface.co/facebook/galactica-120b)
|
||||
- [GPT-2](https://huggingface.co/openai-community/gpt2)
|
||||
- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
|
||||
- [Llama](https://github.com/facebookresearch/llama)
|
||||
- [OPT](https://huggingface.co/facebook/opt-66b)
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2061,
|
||||
"logprob": null,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -3.1835938,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 2769,
|
||||
"logprob": -9.171875,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -1.6425781,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -0.7314453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.68603516,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.005393982,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29744,
|
||||
"logprob": -0.31079102,
|
||||
"special": false,
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -0.08300781,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -0.58984375,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 257,
|
||||
"logprob": -0.953125,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 649,
|
||||
"logprob": -2.0957031,
|
||||
"special": false,
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 2214,
|
||||
"logprob": -1.8095703,
|
||||
"special": false,
|
||||
"text": " field"
|
||||
},
|
||||
{
|
||||
"id": 286,
|
||||
"logprob": -1.0673828,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 2267,
|
||||
"logprob": -0.9375,
|
||||
"special": false,
|
||||
"text": " research"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nDeep learning is a new field of research"
|
||||
}
|
|
@ -0,0 +1,398 @@
|
|||
[
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2061,
|
||||
"logprob": null,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -3.1835938,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 2769,
|
||||
"logprob": -9.171875,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -1.6425781,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -0.7314453,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.68603516,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.005672455,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29744,
|
||||
"logprob": -0.3251953,
|
||||
"special": false,
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -0.08294678,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -0.5854492,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 257,
|
||||
"logprob": -0.9423828,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 649,
|
||||
"logprob": -2.0800781,
|
||||
"special": false,
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 2214,
|
||||
"logprob": -1.8369141,
|
||||
"special": false,
|
||||
"text": " field"
|
||||
},
|
||||
{
|
||||
"id": 286,
|
||||
"logprob": -1.0683594,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 2267,
|
||||
"logprob": -0.9711914,
|
||||
"special": false,
|
||||
"text": " research"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nDeep learning is a new field of research"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2061,
|
||||
"logprob": null,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -3.1660156,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 2769,
|
||||
"logprob": -9.1796875,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -1.6376953,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -0.72216797,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.7089844,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.0054779053,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29744,
|
||||
"logprob": -0.3190918,
|
||||
"special": false,
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -0.08319092,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -0.5839844,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 257,
|
||||
"logprob": -0.9506836,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 649,
|
||||
"logprob": -2.0878906,
|
||||
"special": false,
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 2214,
|
||||
"logprob": -1.8496094,
|
||||
"special": false,
|
||||
"text": " field"
|
||||
},
|
||||
{
|
||||
"id": 286,
|
||||
"logprob": -1.0673828,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 2267,
|
||||
"logprob": -0.9370117,
|
||||
"special": false,
|
||||
"text": " research"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nDeep learning is a new field of research"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2061,
|
||||
"logprob": null,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -3.1660156,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 2769,
|
||||
"logprob": -9.1796875,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -1.6376953,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -0.72216797,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.7089844,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.0054779053,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29744,
|
||||
"logprob": -0.3190918,
|
||||
"special": false,
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -0.08319092,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -0.5839844,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 257,
|
||||
"logprob": -0.9506836,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 649,
|
||||
"logprob": -2.0878906,
|
||||
"special": false,
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 2214,
|
||||
"logprob": -1.8496094,
|
||||
"special": false,
|
||||
"text": " field"
|
||||
},
|
||||
{
|
||||
"id": 286,
|
||||
"logprob": -1.0673828,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 2267,
|
||||
"logprob": -0.9370117,
|
||||
"special": false,
|
||||
"text": " research"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nDeep learning is a new field of research"
|
||||
},
|
||||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 2061,
|
||||
"logprob": null,
|
||||
"text": "What"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -3.1660156,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 2769,
|
||||
"logprob": -9.1796875,
|
||||
"text": " deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -1.6376953,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 30,
|
||||
"logprob": -0.72216797,
|
||||
"text": "?"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.7089844,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 198,
|
||||
"logprob": -0.0054779053,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 29744,
|
||||
"logprob": -0.3190918,
|
||||
"special": false,
|
||||
"text": "Deep"
|
||||
},
|
||||
{
|
||||
"id": 4673,
|
||||
"logprob": -0.08319092,
|
||||
"special": false,
|
||||
"text": " learning"
|
||||
},
|
||||
{
|
||||
"id": 318,
|
||||
"logprob": -0.5839844,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 257,
|
||||
"logprob": -0.9506836,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 649,
|
||||
"logprob": -2.0878906,
|
||||
"special": false,
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 2214,
|
||||
"logprob": -1.8496094,
|
||||
"special": false,
|
||||
"text": " field"
|
||||
},
|
||||
{
|
||||
"id": 286,
|
||||
"logprob": -1.0673828,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 2267,
|
||||
"logprob": -0.9370117,
|
||||
"special": false,
|
||||
"text": " research"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "\n\nDeep learning is a new field of research"
|
||||
}
|
||||
]
|
|
@ -0,0 +1,44 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_gpt2_handle(launcher):
|
||||
with launcher("openai-community/gpt2", num_shard=2) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_gpt2(flash_gpt2_handle):
|
||||
await flash_gpt2_handle.health(300)
|
||||
return flash_gpt2_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_gpt2(flash_gpt2, response_snapshot):
|
||||
response = await flash_gpt2.generate(
|
||||
"What is deep learning?",
|
||||
max_new_tokens=10,
|
||||
decoder_input_details=True,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot):
|
||||
responses = await generate_load(
|
||||
flash_gpt2,
|
||||
"What is deep learning?",
|
||||
max_new_tokens=10,
|
||||
n=4,
|
||||
)
|
||||
|
||||
generated_texts = [r.generated_text for r in responses]
|
||||
|
||||
assert len(generated_texts) == 4
|
||||
assert all(
|
||||
[text == generated_texts[0] for text in generated_texts]
|
||||
), generated_texts
|
||||
|
||||
assert responses == response_snapshot
|
|
@ -132,6 +132,7 @@ pub enum Config {
|
|||
Santacoder,
|
||||
Bloom,
|
||||
Mpt,
|
||||
Gpt2,
|
||||
GptNeox,
|
||||
Phi,
|
||||
#[serde(rename = "phi-msft")]
|
||||
|
|
|
@ -51,6 +51,7 @@ FLASH_ATTENTION = True
|
|||
|
||||
try:
|
||||
from text_generation_server.models.flash_rw import FlashRWSharded
|
||||
from text_generation_server.models.flash_gpt2 import FlashGPT2
|
||||
from text_generation_server.models.flash_neox import FlashNeoXSharded
|
||||
from text_generation_server.models.flash_llama import (
|
||||
FlashLlama,
|
||||
|
@ -83,6 +84,7 @@ except ImportError as e:
|
|||
HAS_FLASH_ATTN_V2_CUDA = False
|
||||
|
||||
if FLASH_ATTENTION:
|
||||
__all__.append(FlashGPT2)
|
||||
__all__.append(FlashNeoXSharded)
|
||||
__all__.append(FlashRWSharded)
|
||||
__all__.append(FlashSantacoderSharded)
|
||||
|
@ -325,7 +327,27 @@ def get_model(
|
|||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
elif model_type == "gpt2":
|
||||
if FLASH_ATTENTION:
|
||||
return FlashGPT2(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif sharded:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
|
||||
else:
|
||||
return CausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif model_type == "gpt_neox":
|
||||
if FLASH_ATTENTION:
|
||||
return FlashNeoXSharded(
|
||||
|
|
|
@ -0,0 +1,454 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.utils import paged_attention, flash_attn
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
|
||||
|
||||
def load_qkv(config, prefix: str, weights, head_size, num_heads):
|
||||
if config.quantize == "gptq":
|
||||
return _load_qkv_gptq(
|
||||
config,
|
||||
prefix,
|
||||
weights,
|
||||
)
|
||||
else:
|
||||
return _load_qkv(config, prefix, weights, head_size, num_heads)
|
||||
|
||||
|
||||
def _load_qkv_gptq(config, prefix: str, weights):
|
||||
world_size = weights.process_group.size()
|
||||
rank = weights.process_group.rank()
|
||||
|
||||
# Weights
|
||||
weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize)
|
||||
|
||||
# Bias
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
|
||||
shape = slice_.get_shape()
|
||||
total_size = shape[0]
|
||||
assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
|
||||
single_size = total_size // 3
|
||||
assert single_size % world_size == 0
|
||||
block_size = single_size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensors = []
|
||||
for i in range(3):
|
||||
tensor = slice_[start + i * single_size : stop + i * single_size]
|
||||
tensors.append(tensor)
|
||||
bias = torch.cat(tensors, dim=0)
|
||||
bias = bias.to(device=weights.device)
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
|
||||
|
||||
def _load_qkv(config, prefix: str, weights, head_size, num_heads):
|
||||
"""Load QKV from a single, transposed matrix."""
|
||||
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
|
||||
shape = slice_.get_shape()
|
||||
total_size = shape[1]
|
||||
assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
|
||||
world_size = weights.process_group.size()
|
||||
single_size = total_size // 3
|
||||
assert single_size % world_size == 0
|
||||
rank = weights.process_group.rank()
|
||||
|
||||
# Weights
|
||||
block_size = single_size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensors = []
|
||||
for i in range(3):
|
||||
tensor = slice_[:, start + i * single_size : stop + i * single_size]
|
||||
tensors.append(tensor)
|
||||
weight = torch.cat(tensors, dim=1).T
|
||||
weight = weight.to(dtype=weights.dtype)
|
||||
weight = weight.to(device=weights.device)
|
||||
|
||||
# Bias
|
||||
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
|
||||
shape = slice_.get_shape()
|
||||
total_size = shape[0]
|
||||
single_size = total_size // 3
|
||||
block_size = single_size // world_size
|
||||
assert single_size % world_size == 0
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
b = []
|
||||
for i in range(3):
|
||||
tensor = slice_[start + i * single_size : stop + i * single_size]
|
||||
b.append(tensor)
|
||||
bias = torch.cat(b, dim=0)
|
||||
bias = bias.to(dtype=weights.dtype)
|
||||
bias = bias.to(device=weights.device)
|
||||
assert list(bias.shape) == [
|
||||
3 * num_heads * head_size
|
||||
], f"{weight.shape} != {[3 * num_heads * head_size]}"
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
|
||||
|
||||
def load_row(config, prefix: str, weights, bias: bool):
|
||||
"""load_row, but with transposed weight matrices."""
|
||||
|
||||
if config.quantize == "gptq":
|
||||
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
|
||||
else:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
|
||||
|
||||
if bias and weights.process_group.rank() == 0:
|
||||
# Rank is only on the first rank process
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
|
||||
return TensorParallelRowLinear(
|
||||
get_linear(weight, bias, config.quantize), process_group=weights.process_group
|
||||
)
|
||||
|
||||
|
||||
def load_col(config, prefix: str, weights, bias: bool):
|
||||
"""load_col, but with transposed weight matrices."""
|
||||
if config.quantize == "gptq":
|
||||
weight = weights.get_multi_weights_col(
|
||||
[prefix], quantize=config.quantize, dim=1
|
||||
)
|
||||
else:
|
||||
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
|
||||
|
||||
if bias:
|
||||
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
|
||||
else:
|
||||
bias = None
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
|
||||
|
||||
class FlashGPT2Attention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
self.softmax_scale = self.head_size**-0.5
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
|
||||
self.query_key_value = load_qkv(
|
||||
config,
|
||||
prefix=prefix,
|
||||
weights=weights,
|
||||
head_size=self.head_size,
|
||||
num_heads=self.num_heads,
|
||||
)
|
||||
|
||||
self.o_proj = load_row(
|
||||
config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_heads, dtype=torch.int32, device=weights.device
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
query, key, value = self.query_key_value(hidden_states).split(
|
||||
self.head_size * self.num_heads, dim=1
|
||||
)
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_heads, self.head_size)
|
||||
value = value.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
flash_attn.attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
paged_attention.attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
||||
class GPT2MLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
act = config.activation_function
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
if "gelu" not in act
|
||||
else lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate=(
|
||||
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.c_fc = load_col(
|
||||
config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
|
||||
)
|
||||
self.c_proj = load_row(
|
||||
config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
intermediate_size = (
|
||||
config.n_inner if config.n_inner is not None else 4 * config.hidden_size
|
||||
)
|
||||
|
||||
self.intermediate_size = intermediate_size // weights.process_group.size()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.c_fc(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
return self.c_proj(hidden_states)
|
||||
|
||||
|
||||
class FlashGPT2Layer(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.self_attn = FlashGPT2Attention(
|
||||
prefix=f"{prefix}.attn", config=config, weights=weights
|
||||
)
|
||||
self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
self.input_layernorm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
|
||||
)
|
||||
self.post_attention_layernorm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.ln_2",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
attn_output = self.self_attn(
|
||||
hidden_states,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states = attn_output + residual
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
|
||||
mlp_output = self.mlp(hidden_states)
|
||||
|
||||
return residual + mlp_output, residual
|
||||
|
||||
|
||||
class FlashGPT2Model(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
FlashGPT2Layer(
|
||||
prefix=(
|
||||
f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}"
|
||||
),
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm = nn.LayerNorm.load(
|
||||
prefix="ln_f" if not prefix else f"{prefix}.ln_f",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache[i],
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashGPT2ForCausalLM(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=("wte" if not prefix else f"{prefix}.wte"),
|
||||
weights=weights,
|
||||
)
|
||||
self.embed_positions = TensorParallelEmbedding(
|
||||
prefix=("wpe" if not prefix else f"{prefix}.wpe"),
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.model = FlashGPT2Model(prefix, config, weights)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="wte" if not prefix else f"{prefix}.wte",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
token_embeds = self.embed_tokens(input_ids)
|
||||
position_embeds = self.embed_positions(position_ids)
|
||||
inputs_embeds = token_embeds + position_embeds
|
||||
hidden_states = self.model(
|
||||
inputs_embeds,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
|
@ -0,0 +1,78 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
|
||||
from transformers.models.gpt2 import GPT2Tokenizer
|
||||
from typing import Optional
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
|
||||
FlashGPT2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
|
||||
class FlashGPT2(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
speculator: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
elif SYSTEM == "xpu":
|
||||
device = torch.device(f"xpu:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashGPT2 is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
config.speculator = speculator
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq"]:
|
||||
weights._set_gptq_params(model_id, revision)
|
||||
|
||||
prefix = ""
|
||||
model = FlashGPT2ForCausalLM(prefix, config, weights)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashGPT2, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_heads,
|
||||
head_size=model.model.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
Loading…
Reference in New Issue