Add GPT-2 with flash attention (#1889)

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

This change adds `FlashGPT2ForCausalLM` and wires it up. The model
itself is pretty straightforward, the main difference from other models
is that it uses trained position embeddings and that all weight matrices
are transposed compared to other models (due to the use of Conv1D in the
upstream model).


<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [x] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [x] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [x] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

@Narsil 

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
This commit is contained in:
Daniël de Kok 2024-05-15 13:31:22 +02:00 committed by GitHub
parent 92f1338b84
commit b5bc6e5c4e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 1098 additions and 1 deletions

View File

@ -9,6 +9,7 @@ The following models are optimized and can be served with TGI, which uses custom
- [BLOOM](https://huggingface.co/bigscience/bloom)
- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl)
- [Galactica](https://huggingface.co/facebook/galactica-120b)
- [GPT-2](https://huggingface.co/openai-community/gpt2)
- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
- [Llama](https://github.com/facebookresearch/llama)
- [OPT](https://huggingface.co/facebook/opt-66b)

View File

@ -0,0 +1,99 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2061,
"logprob": null,
"text": "What"
},
{
"id": 318,
"logprob": -3.1835938,
"text": " is"
},
{
"id": 2769,
"logprob": -9.171875,
"text": " deep"
},
{
"id": 4673,
"logprob": -1.6425781,
"text": " learning"
},
{
"id": 30,
"logprob": -0.7314453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -0.68603516,
"special": false,
"text": "\n"
},
{
"id": 198,
"logprob": -0.005393982,
"special": false,
"text": "\n"
},
{
"id": 29744,
"logprob": -0.31079102,
"special": false,
"text": "Deep"
},
{
"id": 4673,
"logprob": -0.08300781,
"special": false,
"text": " learning"
},
{
"id": 318,
"logprob": -0.58984375,
"special": false,
"text": " is"
},
{
"id": 257,
"logprob": -0.953125,
"special": false,
"text": " a"
},
{
"id": 649,
"logprob": -2.0957031,
"special": false,
"text": " new"
},
{
"id": 2214,
"logprob": -1.8095703,
"special": false,
"text": " field"
},
{
"id": 286,
"logprob": -1.0673828,
"special": false,
"text": " of"
},
{
"id": 2267,
"logprob": -0.9375,
"special": false,
"text": " research"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new field of research"
}

View File

@ -0,0 +1,398 @@
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2061,
"logprob": null,
"text": "What"
},
{
"id": 318,
"logprob": -3.1835938,
"text": " is"
},
{
"id": 2769,
"logprob": -9.171875,
"text": " deep"
},
{
"id": 4673,
"logprob": -1.6425781,
"text": " learning"
},
{
"id": 30,
"logprob": -0.7314453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -0.68603516,
"special": false,
"text": "\n"
},
{
"id": 198,
"logprob": -0.005672455,
"special": false,
"text": "\n"
},
{
"id": 29744,
"logprob": -0.3251953,
"special": false,
"text": "Deep"
},
{
"id": 4673,
"logprob": -0.08294678,
"special": false,
"text": " learning"
},
{
"id": 318,
"logprob": -0.5854492,
"special": false,
"text": " is"
},
{
"id": 257,
"logprob": -0.9423828,
"special": false,
"text": " a"
},
{
"id": 649,
"logprob": -2.0800781,
"special": false,
"text": " new"
},
{
"id": 2214,
"logprob": -1.8369141,
"special": false,
"text": " field"
},
{
"id": 286,
"logprob": -1.0683594,
"special": false,
"text": " of"
},
{
"id": 2267,
"logprob": -0.9711914,
"special": false,
"text": " research"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new field of research"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2061,
"logprob": null,
"text": "What"
},
{
"id": 318,
"logprob": -3.1660156,
"text": " is"
},
{
"id": 2769,
"logprob": -9.1796875,
"text": " deep"
},
{
"id": 4673,
"logprob": -1.6376953,
"text": " learning"
},
{
"id": 30,
"logprob": -0.72216797,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -0.7089844,
"special": false,
"text": "\n"
},
{
"id": 198,
"logprob": -0.0054779053,
"special": false,
"text": "\n"
},
{
"id": 29744,
"logprob": -0.3190918,
"special": false,
"text": "Deep"
},
{
"id": 4673,
"logprob": -0.08319092,
"special": false,
"text": " learning"
},
{
"id": 318,
"logprob": -0.5839844,
"special": false,
"text": " is"
},
{
"id": 257,
"logprob": -0.9506836,
"special": false,
"text": " a"
},
{
"id": 649,
"logprob": -2.0878906,
"special": false,
"text": " new"
},
{
"id": 2214,
"logprob": -1.8496094,
"special": false,
"text": " field"
},
{
"id": 286,
"logprob": -1.0673828,
"special": false,
"text": " of"
},
{
"id": 2267,
"logprob": -0.9370117,
"special": false,
"text": " research"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new field of research"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2061,
"logprob": null,
"text": "What"
},
{
"id": 318,
"logprob": -3.1660156,
"text": " is"
},
{
"id": 2769,
"logprob": -9.1796875,
"text": " deep"
},
{
"id": 4673,
"logprob": -1.6376953,
"text": " learning"
},
{
"id": 30,
"logprob": -0.72216797,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -0.7089844,
"special": false,
"text": "\n"
},
{
"id": 198,
"logprob": -0.0054779053,
"special": false,
"text": "\n"
},
{
"id": 29744,
"logprob": -0.3190918,
"special": false,
"text": "Deep"
},
{
"id": 4673,
"logprob": -0.08319092,
"special": false,
"text": " learning"
},
{
"id": 318,
"logprob": -0.5839844,
"special": false,
"text": " is"
},
{
"id": 257,
"logprob": -0.9506836,
"special": false,
"text": " a"
},
{
"id": 649,
"logprob": -2.0878906,
"special": false,
"text": " new"
},
{
"id": 2214,
"logprob": -1.8496094,
"special": false,
"text": " field"
},
{
"id": 286,
"logprob": -1.0673828,
"special": false,
"text": " of"
},
{
"id": 2267,
"logprob": -0.9370117,
"special": false,
"text": " research"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new field of research"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2061,
"logprob": null,
"text": "What"
},
{
"id": 318,
"logprob": -3.1660156,
"text": " is"
},
{
"id": 2769,
"logprob": -9.1796875,
"text": " deep"
},
{
"id": 4673,
"logprob": -1.6376953,
"text": " learning"
},
{
"id": 30,
"logprob": -0.72216797,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 198,
"logprob": -0.7089844,
"special": false,
"text": "\n"
},
{
"id": 198,
"logprob": -0.0054779053,
"special": false,
"text": "\n"
},
{
"id": 29744,
"logprob": -0.3190918,
"special": false,
"text": "Deep"
},
{
"id": 4673,
"logprob": -0.08319092,
"special": false,
"text": " learning"
},
{
"id": 318,
"logprob": -0.5839844,
"special": false,
"text": " is"
},
{
"id": 257,
"logprob": -0.9506836,
"special": false,
"text": " a"
},
{
"id": 649,
"logprob": -2.0878906,
"special": false,
"text": " new"
},
{
"id": 2214,
"logprob": -1.8496094,
"special": false,
"text": " field"
},
{
"id": 286,
"logprob": -1.0673828,
"special": false,
"text": " of"
},
{
"id": 2267,
"logprob": -0.9370117,
"special": false,
"text": " research"
}
],
"top_tokens": null
},
"generated_text": "\n\nDeep learning is a new field of research"
}
]

View File

@ -0,0 +1,44 @@
import pytest
@pytest.fixture(scope="module")
def flash_gpt2_handle(launcher):
with launcher("openai-community/gpt2", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_gpt2(flash_gpt2_handle):
await flash_gpt2_handle.health(300)
return flash_gpt2_handle.client
@pytest.mark.asyncio
async def test_flash_gpt2(flash_gpt2, response_snapshot):
response = await flash_gpt2.generate(
"What is deep learning?",
max_new_tokens=10,
decoder_input_details=True,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot):
responses = await generate_load(
flash_gpt2,
"What is deep learning?",
max_new_tokens=10,
n=4,
)
generated_texts = [r.generated_text for r in responses]
assert len(generated_texts) == 4
assert all(
[text == generated_texts[0] for text in generated_texts]
), generated_texts
assert responses == response_snapshot

View File

@ -132,6 +132,7 @@ pub enum Config {
Santacoder,
Bloom,
Mpt,
Gpt2,
GptNeox,
Phi,
#[serde(rename = "phi-msft")]

View File

@ -51,6 +51,7 @@ FLASH_ATTENTION = True
try:
from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_gpt2 import FlashGPT2
from text_generation_server.models.flash_neox import FlashNeoXSharded
from text_generation_server.models.flash_llama import (
FlashLlama,
@ -83,6 +84,7 @@ except ImportError as e:
HAS_FLASH_ATTN_V2_CUDA = False
if FLASH_ATTENTION:
__all__.append(FlashGPT2)
__all__.append(FlashNeoXSharded)
__all__.append(FlashRWSharded)
__all__.append(FlashSantacoderSharded)
@ -325,7 +327,27 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "gpt2":
if FLASH_ATTENTION:
return FlashGPT2(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
else:
return CausalLM(
model_id,
revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "gpt_neox":
if FLASH_ATTENTION:
return FlashNeoXSharded(

View File

@ -0,0 +1,454 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
def load_qkv(config, prefix: str, weights, head_size, num_heads):
if config.quantize == "gptq":
return _load_qkv_gptq(
config,
prefix,
weights,
)
else:
return _load_qkv(config, prefix, weights, head_size, num_heads)
def _load_qkv_gptq(config, prefix: str, weights):
world_size = weights.process_group.size()
rank = weights.process_group.rank()
# Weights
weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize)
# Bias
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
shape = slice_.get_shape()
total_size = shape[0]
assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
single_size = total_size // 3
assert single_size % world_size == 0
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensors = []
for i in range(3):
tensor = slice_[start + i * single_size : stop + i * single_size]
tensors.append(tensor)
bias = torch.cat(tensors, dim=0)
bias = bias.to(device=weights.device)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
def _load_qkv(config, prefix: str, weights, head_size, num_heads):
"""Load QKV from a single, transposed matrix."""
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
shape = slice_.get_shape()
total_size = shape[1]
assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
world_size = weights.process_group.size()
single_size = total_size // 3
assert single_size % world_size == 0
rank = weights.process_group.rank()
# Weights
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensors = []
for i in range(3):
tensor = slice_[:, start + i * single_size : stop + i * single_size]
tensors.append(tensor)
weight = torch.cat(tensors, dim=1).T
weight = weight.to(dtype=weights.dtype)
weight = weight.to(device=weights.device)
# Bias
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
shape = slice_.get_shape()
total_size = shape[0]
single_size = total_size // 3
block_size = single_size // world_size
assert single_size % world_size == 0
start = rank * block_size
stop = (rank + 1) * block_size
b = []
for i in range(3):
tensor = slice_[start + i * single_size : stop + i * single_size]
b.append(tensor)
bias = torch.cat(b, dim=0)
bias = bias.to(dtype=weights.dtype)
bias = bias.to(device=weights.device)
assert list(bias.shape) == [
3 * num_heads * head_size
], f"{weight.shape} != {[3 * num_heads * head_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
def load_row(config, prefix: str, weights, bias: bool):
"""load_row, but with transposed weight matrices."""
if config.quantize == "gptq":
weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
if bias and weights.process_group.rank() == 0:
# Rank is only on the first rank process
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return TensorParallelRowLinear(
get_linear(weight, bias, config.quantize), process_group=weights.process_group
)
def load_col(config, prefix: str, weights, bias: bool):
"""load_col, but with transposed weight matrices."""
if config.quantize == "gptq":
weight = weights.get_multi_weights_col(
[prefix], quantize=config.quantize, dim=1
)
else:
weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
if bias:
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
else:
bias = None
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
class FlashGPT2Attention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_heads
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = load_qkv(
config,
prefix=prefix,
weights=weights,
head_size=self.head_size,
num_heads=self.num_heads,
)
self.o_proj = load_row(
config,
prefix=f"{prefix}.c_proj",
weights=weights,
bias=True,
)
self.kv_head_mapping = torch.arange(
0, self.num_heads, dtype=torch.int32, device=weights.device
)
def forward(
self,
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
query, key, value = self.query_key_value(hidden_states).split(
self.head_size * self.num_heads, dim=1
)
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# output tensor
attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
flash_attn.attention(
query,
key,
value,
attn_output,
cu_seqlen_prefill,
max_s,
self.softmax_scale,
)
# Decode
else:
paged_attention.attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
class GPT2MLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.activation_function
self.act = (
ACT2FN[act]
if "gelu" not in act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
)
self.c_fc = load_col(
config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
)
self.c_proj = load_row(
config,
prefix=f"{prefix}.c_proj",
weights=weights,
bias=True,
)
intermediate_size = (
config.n_inner if config.n_inner is not None else 4 * config.hidden_size
)
self.intermediate_size = intermediate_size // weights.process_group.size()
def forward(self, hidden_states):
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
return self.c_proj(hidden_states)
class FlashGPT2Layer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.self_attn = FlashGPT2Attention(
prefix=f"{prefix}.attn", config=config, weights=weights
)
self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
)
self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.ln_2",
weights=weights,
eps=config.layer_norm_epsilon,
)
def forward(
self,
hidden_states,
residual,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
attn_output = self.self_attn(
hidden_states,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(hidden_states)
return residual + mlp_output, residual
class FlashGPT2Model(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
self.layers = nn.ModuleList(
[
FlashGPT2Layer(
prefix=(
f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}"
),
config=config,
weights=weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = nn.LayerNorm.load(
prefix="ln_f" if not prefix else f"{prefix}.ln_f",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.gradient_checkpointing = False
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor:
hidden_states = inputs_embeds
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class FlashGPT2ForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix=("wte" if not prefix else f"{prefix}.wte"),
weights=weights,
)
self.embed_positions = TensorParallelEmbedding(
prefix=("wpe" if not prefix else f"{prefix}.wpe"),
weights=weights,
)
self.model = FlashGPT2Model(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config,
prefix="wte" if not prefix else f"{prefix}.wte",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
token_embeds = self.embed_tokens(input_ids)
position_embeds = self.embed_positions(position_ids)
inputs_embeds = token_embeds + position_embeds
hidden_states = self.model(
inputs_embeds,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -0,0 +1,78 @@
import torch
import torch.distributed
from opentelemetry import trace
from transformers import AutoConfig, AutoTokenizer, GenerationConfig
from transformers.models.gpt2 import GPT2Tokenizer
from typing import Optional
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
Weights,
)
tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import SYSTEM
class FlashGPT2(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
raise NotImplementedError("FlashGPT2 is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = AutoConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.speculator = speculator
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(filenames, device, dtype, process_group=self.process_group)
if config.quantize in ["gptq", "awq"]:
weights._set_gptq_params(model_id, revision)
prefix = ""
model = FlashGPT2ForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashGPT2, self).__init__(
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)