feat(server): support RefinedWeb models (#379)

OlivierDehaene 2023-05-30 18:25:19 +02:00 committed by GitHub
parent bf7f1d5434
commit b8b950b37c
12 changed files with 3235 additions and 26 deletions

View File

@@ -205,7 +205,10 @@ def event_loop():
def launcher(event_loop):
@contextlib.contextmanager
def local_launcher(
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
model_id: str,
num_shard: Optional[int] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000)
@@ -230,6 +233,9 @@ def launcher(event_loop):
args.extend(["--num-shard", str(num_shard)])
if quantize:
args.append("--quantize")
args.append("bitsandbytes")
if trust_remote_code:
args.append("--trust-remote-code")
env = os.environ
env["LOG_LEVEL"] = "info,text_generation_router=debug"
@@ -250,7 +256,10 @@ def launcher(event_loop):
@contextlib.contextmanager
def docker_launcher(
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
model_id: str,
num_shard: Optional[int] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
port = random.randint(8000, 10_000)
@@ -260,6 +269,9 @@ def launcher(event_loop):
args.extend(["--num-shard", str(num_shard)])
if quantize:
args.append("--quantize")
args.append("bitsandbytes")
if trust_remote_code:
args.append("--trust-remote-code")
client = docker.from_env()
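
For reference, a sketch of the extra arguments both fixtures end up appending for a sharded, quantized, remote-code model (anything not shown in this hunk, such as the base command or the model-id handling, is an assumption):

# Illustrative only: num_shard=2, quantize truthy, trust_remote_code=True.
extra_args = ["--num-shard", "2", "--quantize", "bitsandbytes", "--trust-remote-code"]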

View File

@@ -0,0 +1,378 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 50,
"logprob": null,
"text": "G"
},
{
"id": 330,
"logprob": -5.96875,
"text": "ir"
},
{
"id": 1622,
"logprob": -5.6132812,
"text": "af"
},
{
"id": 249,
"logprob": -6.5039062,
"text": "at"
},
{
"id": 1480,
"logprob": -8.078125,
"text": "ron"
},
{
"id": 304,
"logprob": -2.3261719,
"text": " is"
},
{
"id": 23866,
"logprob": -9.59375,
"text": " obsessed"
},
{
"id": 335,
"logprob": -0.048339844,
"text": " with"
},
{
"id": 26680,
"logprob": -4.0,
"text": " gir"
},
{
"id": 1903,
"logprob": -0.07556152,
"text": "aff"
},
{
"id": 255,
"logprob": -0.0067749023,
"text": "es"
},
{
"id": 23,
"logprob": -1.546875,
"text": ","
},
{
"id": 248,
"logprob": -4.3320312,
"text": " the"
},
{
"id": 758,
"logprob": -3.734375,
"text": " most"
},
{
"id": 21735,
"logprob": -5.109375,
"text": " glorious"
},
{
"id": 5985,
"logprob": -2.09375,
"text": " animal"
},
{
"id": 313,
"logprob": -1.1835938,
"text": " on"
},
{
"id": 248,
"logprob": -0.77685547,
"text": " the"
},
{
"id": 1936,
"logprob": -2.3828125,
"text": " face"
},
{
"id": 275,
"logprob": -0.004432678,
"text": " of"
},
{
"id": 414,
"logprob": -1.9677734,
"text": " this"
},
{
"id": 6490,
"logprob": -2.046875,
"text": " Earth"
},
{
"id": 25,
"logprob": -0.28198242,
"text": "."
},
{
"id": 401,
"logprob": -7.9179688,
"text": " G"
},
{
"id": 6013,
"logprob": -2.2753906,
"text": "ira"
},
{
"id": 694,
"logprob": -0.6230469,
"text": "ft"
},
{
"id": 1480,
"logprob": -0.20874023,
"text": "ron"
},
{
"id": 9369,
"logprob": -4.5507812,
"text": " believes"
},
{
"id": 455,
"logprob": -4.5664062,
"text": " all"
},
{
"id": 599,
"logprob": -2.7402344,
"text": " other"
},
{
"id": 5632,
"logprob": -0.21948242,
"text": " animals"
},
{
"id": 362,
"logprob": -0.7675781,
"text": " are"
},
{
"id": 23981,
"logprob": -5.0,
"text": " irrelevant"
},
{
"id": 635,
"logprob": -4.234375,
"text": " when"
},
{
"id": 4354,
"logprob": -0.5131836,
"text": " compared"
},
{
"id": 271,
"logprob": -0.103637695,
"text": " to"
},
{
"id": 248,
"logprob": -0.58447266,
"text": " the"
},
{
"id": 21735,
"logprob": -3.6835938,
"text": " glorious"
},
{
"id": 64398,
"logprob": -1.8173828,
"text": " majesty"
},
{
"id": 275,
"logprob": -0.23510742,
"text": " of"
},
{
"id": 248,
"logprob": -0.35473633,
"text": " the"
},
{
"id": 26680,
"logprob": -0.24633789,
"text": " gir"
},
{
"id": 23226,
"logprob": -0.02960205,
"text": "affe"
},
{
"id": 25,
"logprob": -0.17333984,
"text": "."
},
{
"id": 193,
"logprob": -1.3935547,
"text": "\n"
},
{
"id": 23626,
"logprob": -10.0625,
"text": "Daniel"
},
{
"id": 37,
"logprob": -4.59375,
"text": ":"
},
{
"id": 23090,
"logprob": -6.9375,
"text": " Hello"
},
{
"id": 23,
"logprob": -0.99365234,
"text": ","
},
{
"id": 29033,
"logprob": -2.2324219,
"text": " Gir"
},
{
"id": 1622,
"logprob": -0.10809326,
"text": "af"
},
{
"id": 249,
"logprob": -0.042663574,
"text": "at"
},
{
"id": 1480,
"logprob": -0.0024776459,
"text": "ron"
},
{
"id": 12,
"logprob": -1.4277344,
"text": "!"
},
{
"id": 193,
"logprob": -1.1015625,
"text": "\n"
},
{
"id": 50,
"logprob": -0.05709839,
"text": "G"
},
{
"id": 330,
"logprob": -0.13208008,
"text": "ir"
},
{
"id": 1622,
"logprob": -0.0071487427,
"text": "af"
},
{
"id": 249,
"logprob": -0.008468628,
"text": "at"
},
{
"id": 1480,
"logprob": -0.00068998337,
"text": "ron"
},
{
"id": 37,
"logprob": -0.0074691772,
"text": ":"
}
],
"seed": null,
"tokens": [
{
"id": 23090,
"logprob": -1.8251953,
"special": false,
"text": " Hello"
},
{
"id": 23,
"logprob": -0.3173828,
"special": false,
"text": ","
},
{
"id": 8156,
"logprob": -0.23803711,
"special": false,
"text": " Daniel"
},
{
"id": 12,
"logprob": -0.56933594,
"special": false,
"text": "!"
},
{
"id": 193,
"logprob": -0.61279297,
"special": false,
"text": "\n"
},
{
"id": 23626,
"logprob": -0.41967773,
"special": false,
"text": "Daniel"
},
{
"id": 37,
"logprob": -0.0023403168,
"special": false,
"text": ":"
},
{
"id": 1634,
"logprob": -2.0605469,
"special": false,
"text": " What"
},
{
"id": 18,
"logprob": -1.5292969,
"special": false,
"text": "'"
},
{
"id": 94,
"logprob": -0.007904053,
"special": false,
"text": "s"
}
]
},
"generated_text": " Hello, Daniel!\nDaniel: What's"
}

View File

@@ -0,0 +1,98 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 330,
"logprob": null,
"text": "ir"
},
{
"id": 1622,
"logprob": -7.8125,
"text": "af"
},
{
"id": 249,
"logprob": -4.5,
"text": "at"
},
{
"id": 1480,
"logprob": -10.875,
"text": "ron"
},
{
"id": 37,
"logprob": -3.6875,
"text": ":"
}
],
"seed": 0,
"tokens": [
{
"id": 836,
"logprob": -1.265625,
"special": false,
"text": " i"
},
{
"id": 18,
"logprob": -0.119628906,
"special": false,
"text": "'"
},
{
"id": 298,
"logprob": -2.265625,
"special": false,
"text": "ve"
},
{
"id": 650,
"logprob": -0.49804688,
"special": false,
"text": " been"
},
{
"id": 1241,
"logprob": 0.0,
"special": false,
"text": " using"
},
{
"id": 334,
"logprob": 0.0,
"special": false,
"text": " it"
},
{
"id": 312,
"logprob": -1.2421875,
"special": false,
"text": " for"
},
{
"id": 909,
"logprob": -0.99609375,
"special": false,
"text": " years"
},
{
"id": 193,
"logprob": -0.30273438,
"special": false,
"text": "\n"
},
{
"id": 807,
"logprob": -1.078125,
"special": false,
"text": "ik"
}
]
},
"generated_text": "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron: i've been using it for years\nik"
}

View File

@@ -0,0 +1,63 @@
import pytest
@pytest.fixture(scope="module")
def flash_falcon_handle(launcher):
with launcher("tiiuae/falcon-7b", trust_remote_code=True) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_falcon(flash_falcon_handle):
await flash_falcon_handle.health(120)
return flash_falcon_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon(flash_falcon, response_snapshot):
response = await flash_falcon.generate(
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
max_new_tokens=10,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
response = await flash_falcon.generate(
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
max_new_tokens=10,
repetition_penalty=1.2,
return_full_text=True,
stop_sequences=["test"],
temperature=0.5,
top_p=0.9,
top_k=10,
truncate=5,
typical_p=0.9,
watermark=True,
seed=0,
)
assert response.details.generated_tokens == 10
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot):
responses = await generate_load(
flash_falcon,
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
max_new_tokens=10,
n=4,
)
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses == response_snapshot

View File

@@ -10,6 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
from text_generation_server.models.rw import RW
from text_generation_server.models.opt import OPT, OPTSharded
from text_generation_server.models.galactica import Galactica, GalacticaSharded
from text_generation_server.models.santacoder import SantaCoder
@@ -30,6 +31,7 @@ try:
)
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded
from text_generation_server.models.flash_llama import (
FlashLlama,
FlashLlamaSharded,
@@ -68,6 +70,8 @@ __all__ = [
if FLASH_ATTENTION:
__all__.append(FlashNeoX)
__all__.append(FlashNeoXSharded)
__all__.append(FlashRW)
__all__.append(FlashRWSharded)
__all__.append(FlashSantacoder)
__all__.append(FlashSantacoderSharded)
__all__.append(FlashLlama)
@@ -194,6 +198,39 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type in ["RefinedWeb", "RefinedWebModel"]:
if sharded:
if FLASH_ATTENTION:
if config.alibi or (
config.model_type == "RefinedWebModel"
and config.n_head_kv != config.n_head
):
raise NotImplementedError("sharded is not supported for this model")
return FlashRWSharded(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
)
else:
if FLASH_ATTENTION and not config.alibi:
return FlashRW(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
else:
return RW(
model_id,
revision,
quantize=quantize,
trust_remote_code=trust_remote_code,
)
if model_type == "llama":
if sharded:
if FLASH_ATTENTION:
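
To make the new RefinedWeb branch above concrete, here is how two typical configurations would be routed (the head counts are illustrative assumptions, not values from this diff):

# A 7B-style multi-query checkpoint and a 40B-style grouped-KV checkpoint.
falcon_7b_like = {"model_type": "RefinedWebModel", "alibi": False, "n_head": 71, "n_head_kv": 1}
falcon_40b_like = {"model_type": "RefinedWeb", "alibi": False, "n_head": 128, "n_head_kv": 8}
# Unsharded with FLASH_ATTENTION both map to FlashRW, otherwise to RW.
# Sharded, falcon_40b_like maps to FlashRWSharded, while falcon_7b_like raises
# NotImplementedError because n_head_kv != n_head for a RefinedWebModel config.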

View File

@@ -134,20 +134,23 @@ class FlashLlamaAttention(torch.nn.Module):
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
qkv_rot = self.rotary_emb(qkv, cos, sin)
# Inplace rotary
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill
if layer_past_present_indices is None:
# Copy to layer past
layer_past[...] = qkv_rot[:, 1:]
layer_past[...] = qkv[:, 1:]
# output
attn_output = torch.empty_like(qkv_rot[:, 0])
attn_output = torch.empty_like(qkv[:, 0])
# flash attention
flash_attn_cuda.fwd(
qkv_rot[:, 0],
qkv_rot[:, 1],
qkv_rot[:, 2],
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
cu_seqlens,
cu_seqlens,
@@ -163,9 +166,9 @@ class FlashLlamaAttention(torch.nn.Module):
)
# Decode
else:
query = qkv_rot[:, 0]
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = qkv_rot[:, 1:]
layer_past[layer_past_present_indices] = qkv[:, 1:]
# output
attn_output = torch.empty_like(query)

View File

@@ -101,20 +101,23 @@ class FlashNeoxAttention(torch.nn.Module):
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
qkv_rot = self.rotary_emb(qkv, cos, sin)
# Inplace rotary
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill
if layer_past_present_indices is None:
# Copy to layer past
layer_past[...] = qkv_rot[:, 1:]
layer_past[...] = qkv[:, 1:]
# output
attn_output = torch.empty_like(qkv_rot[:, 0])
attn_output = torch.empty_like(qkv[:, 0])
# flash attention
flash_attn_cuda.fwd(
qkv_rot[:, 0],
qkv_rot[:, 1],
qkv_rot[:, 2],
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
cu_seqlens,
cu_seqlens,
@@ -130,9 +133,9 @@ class FlashNeoxAttention(torch.nn.Module):
)
# Decode
else:
query = qkv_rot[:, 0]
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = qkv_rot[:, 1:]
layer_past[layer_past_present_indices] = qkv[:, 1:]
# output
attn_output = torch.empty_like(query)
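
The llama and neox changes above lean on the same PyTorch behaviour: qkv[:, 0] and qkv[:, 1] are views into qkv, so applying the rotary kernel in place on those slices updates qkv itself and the separate qkv_rot copy can be dropped. A minimal sketch of that aliasing in plain PyTorch, with an in-place multiply standing in for the rotary op:

import torch

qkv = torch.randn(4, 3, 2, 8)            # [tokens, (q|k|v), heads, head_size]
expected = qkv[:, 0].clone() * 2.0
qkv[:, 0].mul_(2.0)                       # stand-in for self.rotary_emb(qkv[:, 0], cos, sin)
assert torch.equal(qkv[:, 0], expected)   # the update is visible through the parent tensor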

View File

@@ -0,0 +1,774 @@
import os
import torch
import torch.distributed
from torch import nn
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from typing import Optional
# Flash attention imports
import flash_attn_cuda
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
FastLayerNorm,
PositionRotaryEmbedding,
)
class RWConfig(PretrainedConfig):
attribute_map = {
"num_hidden_layers": "n_layer",
"num_attention_heads": "n_head",
}
def __init__(
self,
model_type="RefinedWeb",
vocab_size=250880,
hidden_size=64,
n_layer=2,
n_head=8,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
use_cache=True,
bos_token_id=1,
eos_token_id=2,
hidden_dropout=0.0,
attention_dropout=0.0,
n_head_kv=None,
multi_query=False,
alibi=False,
bias=False,
parallel_attn=False,
**kwargs,
):
if alibi:
raise NotImplementedError(
"alibi is not supported by this version of the model"
)
self.model_type = model_type
self.alibi = False
self.rotary = True
self.vocab_size = vocab_size
# Backward compatibility with n_embed kwarg
n_embed = kwargs.pop("n_embed", None)
self.hidden_size = hidden_size if n_embed is None else n_embed
self.n_layer = n_layer
self.n_head = n_head
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.use_cache = use_cache
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.bias = bias
self.parallel_attn = parallel_attn
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
if n_head_kv is not None:
self.n_head_kv = n_head_kv
else:
self.n_head_kv = 1 if multi_query else n_head
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
class FlashRWAttention(torch.nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=None,
reduce=True,
):
super().__init__()
self.num_heads = num_heads
self.num_heads_kv = num_heads_kv
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
self.softmax_scale = self.head_size ** (-0.5)
if process_group is None:
self.query_key_value = FastLinear(
hidden_size,
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
)
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
else:
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
process_group=process_group,
)
self.dense = TensorParallelRowLinear(
hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
)
self.num_heads = self.num_heads // process_group.size()
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
):
qkv = self.query_key_value(hidden_states)
# Split query from key_value
query, kv = qkv.split(
[self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv],
dim=1,
)
# Prepare query and key_value for indexing
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_heads_kv, self.head_size)
# Inplace rotary
self.rotary_emb(query, cos, sin)
self.rotary_emb(kv[:, 0], cos, sin)
# Prefill
if layer_past_present_indices is None:
# Copy to layer past
layer_past[...] = kv
# Expand to query shape
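# (expand returns a broadcasted view rather than a copy, so the single shared
# key/value head is reused by every query head without extra memory)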
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, 0],
kv[:, 1],
attn_output,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = kv
# Expand to query shape
kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, 0],
kv[:, 1],
attn_output,
cu_seqlens_q,
cu_seqlens,
1,
max_s,
0.0,
self.softmax_scale,
False,
False,
False,
0,
None,
)
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
class FlashRWLargeAttention(torch.nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=None,
reduce=True,
):
super().__init__()
self.hidden_size = hidden_size
self.head_size = hidden_size // num_heads
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
self.softmax_scale = self.head_size ** (-0.5)
self.num_groups = num_heads // (num_heads_kv * 2)
self.num_heads = num_heads // self.num_groups
self.num_heads_kv = num_heads_kv // self.num_groups
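# Illustrative shapes (the head counts are an assumption, not from this diff):
# with num_heads=128 and num_heads_kv=8, num_groups=8, so each group packs 16
# query heads plus one key head and one value head, and the fused qkv below is
# viewed as [tokens, num_groups, num_heads_per_group + 2, head_size].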
if process_group is None:
self.query_key_value = FastLinear(
hidden_size,
self.num_groups
* self.head_size
* (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
)
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
else:
if process_group.size() > self.num_groups:
raise NotImplementedError(
f"Tensor Parallelism is not implemented for world_size > n groups"
)
self.query_key_value = TensorParallelColumnLinear(
hidden_size,
self.num_groups
* self.head_size
* (self.num_heads + 2 * self.num_heads_kv),
bias=bias,
process_group=process_group,
)
self.dense = TensorParallelRowLinear(
hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
)
self.num_groups = self.num_groups // process_group.size()
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
# Split on group dimension
query, kv = qkv.split(
[self.num_heads, 2],
dim=2,
)
# Merge groups and heads
query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
# Inplace rotary
self.rotary_emb(query, cos, sin)
self.rotary_emb(kv[:, :, 0], cos, sin)
# Prefill
if layer_past_present_indices is None:
# Copy to layer past
layer_past[...] = kv
# Expand to query shape
kv = (
kv.unsqueeze(2)
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, :, 0],
kv[:, :, 1],
attn_output,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
0.0,
self.softmax_scale,
False,
True,
False,
0,
None,
)
# Decode
else:
# Add present to the layer_past tensor at the correct indices
layer_past[layer_past_present_indices] = kv
# Expand to query shape
kv = (
layer_past.unsqueeze(2)
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
)
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
kv[:, :, 0],
kv[:, :, 1],
attn_output,
cu_seqlens_q,
cu_seqlens,
1,
max_s,
0.0,
self.softmax_scale,
False,
False,
False,
0,
None,
)
return self.dense(
attn_output.view(-1, self.num_groups * self.num_heads * self.head_size)
)
class FlashMLP(nn.Module):
def __init__(self, hidden_size, bias, process_group=None, reduce=True):
super().__init__()
self.act = torch.nn.functional.gelu
if process_group is None:
self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias)
self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias)
else:
self.dense_h_to_4h = TensorParallelColumnLinear(
hidden_size,
4 * hidden_size,
bias=bias,
process_group=process_group,
)
self.dense_4h_to_h = TensorParallelRowLinear(
4 * hidden_size,
hidden_size,
bias=bias,
process_group=process_group,
reduce=reduce,
)
self.process_group = process_group
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dense_4h_to_h(hidden_states)
return hidden_states
class FlashRWLayer(nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
layer_norm_eps,
parallel_attn,
process_group=None,
):
super().__init__()
self.parallel_attn = parallel_attn
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.self_attention = FlashRWAttention(
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=process_group,
reduce=False,
)
self.post_attention_layernorm = (
FastLayerNorm(hidden_size, eps=layer_norm_eps)
if not parallel_attn
else None
)
self.mlp = FlashMLP(
hidden_size, bias, process_group=process_group, reduce=False
)
self.process_group = process_group
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
):
if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
attn_output = self.self_attention(
ln_hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
)
mlp_output = self.mlp(ln_hidden_states)
intermediate = mlp_output + attn_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate, residual
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attention(
hidden_states,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
mlp_output = self.mlp(hidden_states)
return mlp_output, residual
class FlashRWLargeLayer(nn.Module):
def __init__(
self,
num_heads,
num_heads_kv,
hidden_size,
bias,
layer_norm_eps,
process_group=None,
):
super().__init__()
self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.self_attention = FlashRWLargeAttention(
num_heads,
num_heads_kv,
hidden_size,
bias,
process_group=process_group,
reduce=False,
)
self.mlp = FlashMLP(
hidden_size, bias, process_group=process_group, reduce=False
)
self.process_group = process_group
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
):
ln_attn, residual = self.ln_attn(hidden_states, residual)
ln_mlp, _ = self.ln_mlp(residual)
# Self attention.
attn_output = self.self_attention(
ln_attn,
cos,
sin,
cu_seqlens,
max_s,
layer_past,
layer_past_present_indices,
cu_seqlens_q,
)
# MLP.
mlp_output = self.mlp(ln_mlp)
intermediate = attn_output + mlp_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate, residual
class FlashRWPreTrainedModel(PreTrainedModel):
config_class = RWConfig
class FlashRWModel(FlashRWPreTrainedModel):
def __init__(self, config, process_group=None):
super().__init__(config)
self.config = config
self.tp_embeddings = False
if process_group is not None:
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
if config.vocab_size % self.tp_world_size == 0:
self.tp_embeddings = True
if self.tp_embeddings:
self.word_embeddings = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group
)
else:
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
if config.model_type == "RefinedWebModel":
self.h = nn.ModuleList(
[
FlashRWLayer(
config.n_head,
config.n_head_kv,
config.hidden_size,
config.bias,
config.layer_norm_epsilon,
config.parallel_attn,
process_group,
)
for _ in range(config.num_hidden_layers)
]
)
self.cache_size = (
2,
self.h[0].self_attention.num_heads_kv,
self.h[0].self_attention.head_size,
)
elif config.model_type == "RefinedWeb":
self.h = nn.ModuleList(
[
FlashRWLargeLayer(
config.n_head,
config.n_head_kv,
config.hidden_size,
config.bias,
config.layer_norm_epsilon,
process_group,
)
for _ in range(config.num_hidden_layers)
]
)
self.cache_size = (
self.h[0].self_attention.num_groups,
2,
self.h[0].self_attention.head_size,
)
else:
raise NotImplementedError(
f"model_type {config.model_type} is not supported."
)
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.head_size = self.h[0].self_attention.head_size
def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.word_embeddings, TensorParallelEmbedding):
self.word_embeddings.add_null_idx()
for layer in self.h:
layer: FlashRWLayer
layer.self_attention.query_key_value.prepare_weights(quantize)
layer.self_attention.dense.prepare_weights(quantize)
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashRWModel, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward(
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
max_s,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states = self.word_embeddings(input_ids)
# Prefill
if past_key_values is None:
# Create past tensor
past_key_values = hidden_states.new_empty(
(
len(self.h),
len(hidden_states)
if pre_allocate_past_size is None
else pre_allocate_past_size,
*self.cache_size,
)
)
layer_past_present_indices = None
slice_past_index = len(hidden_states)
# Decode
else:
# Create indices from cumulative sequence lengths
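# e.g. cu_seqlens = [0, 3, 5] -> indices [2, 4]: the slot of the last token of
# each sequence in the flattened batch.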
layer_past_present_indices = cu_seqlens[1:] - 1
slice_past_index = None
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.h):
# We added padding that we now need to slice
layer_past_key_values = (
past_key_values[i]
if slice_past_index is None
else past_key_values[i, :slice_past_index]
)
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlens,
max_s,
layer_past_key_values,
layer_past_present_indices,
cu_seqlens_q,
)
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states, past_key_values
class FlashRWForCausalLM(FlashRWPreTrainedModel):
def __init__(self, config, process_group=None):
super().__init__(config)
self.process_group = process_group
if self.process_group is not None:
self.world_size = self.process_group.size()
else:
self.world_size = 1
self.transformer = FlashRWModel(config, process_group)
if self.transformer.tp_embeddings:
self.lm_head = FastLinear(
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
)
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self, quantize: Optional[str] = None):
self.transformer.post_load_weights(quantize)
self.lm_head.prepare_weights()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashRWForCausalLM, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
)
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
return model
def forward(
self,
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
max_s,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
):
hidden_states, present = self.transformer(
input_ids,
position_ids,
cu_seqlens,
cu_seqlens_q,
max_s,
past_key_values,
pre_allocate_past_size,
)
logits = self.lm_head(hidden_states)
if self.transformer.tp_embeddings:
# Logits are sharded, so we need to gather them
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
world_logits = torch.cat(world_logits, dim=1)
return world_logits, present
return logits, present
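
A minimal sketch of how the classes above are meant to be driven (the config values are assumptions for a tiny multi-query model; real weights still have to be loaded before forward can run, as the next file does):

from accelerate import init_empty_weights

config = RWConfig(model_type="RefinedWebModel", n_layer=2, n_head=8, n_head_kv=1, hidden_size=64)
with init_empty_weights():
    model = FlashRWForCausalLM(config)
# Weights are then copied in tensor by tensor (see FlashRW.load_weights in the
# next file) and model.post_load_weights(quantize) finalizes the linear layers.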

View File

@@ -0,0 +1,244 @@
import torch
import torch.distributed
from pathlib import Path
from accelerate import init_empty_weights
from opentelemetry import trace
from safetensors import safe_open
from transformers import AutoTokenizer, AutoConfig
from typing import Optional, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelColumnLinear,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
weight_hub_files,
LocalEntryNotFoundError,
)
tracer = trace.get_tracer(__name__)
class FlashRW(FlashCausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16
else:
raise NotImplementedError("RW is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = RWConfig.from_pretrained(
model_id,
revision=revision,
)
# We do not use from_pretrained as it is too slow
try:
filenames = weight_files(model_id, revision, ".bin")
# Local files not found
except LocalEntryNotFoundError:
hub_files = weight_hub_files(model_id, revision, ".bin")
filenames = download_weights(hub_files, model_id, revision)
with init_empty_weights():
model = FlashRWForCausalLM(config)
self.load_weights(
model,
filenames,
quantize,
device,
dtype,
)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
)
@staticmethod
def load_weights(
model: FlashRWForCausalLM,
filenames: List[Path],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
):
for filename in filenames:
state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items():
value = value.to(device if quantize is None else "cpu").to(dtype)
module_name, param_name = key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
except KeyError:
module._buffers[param_name] = value
del value
torch.cuda.empty_cache()
model.post_load_weights(quantize)
class FlashRWSharded(FlashRW):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16
else:
raise NotImplementedError("FlashRW is only available on GPU")
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
config = RWConfig.from_pretrained(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = FlashRWForCausalLM(config, self.process_group)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
quantize=quantize,
device=device,
dtype=dtype,
rank=rank,
world_size=world_size,
)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name)
current_parameter_tensor = parameters.get(name, None)
slice_ = f.get_slice(name)
if isinstance(module, TensorParallelColumnLinear):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
size = slice_.get_shape()[1]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[:, start:stop]
else:
tensor = slice_[:]
# XXX: Hack for Rowlinear to add the bias only once.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "lm_head.weight" and model.transformer.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
except:
tensor = f.get_tensor(name)
if (
current_parameter_tensor is not None
and current_parameter_tensor.shape != tensor.shape
):
raise ValueError(
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
)
tensor = tensor.contiguous().to(dtype)
if current_parameter_tensor is not None:
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
model.post_load_weights(quantize)
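
The slicing above always hands each rank a contiguous 1/world_size block of the dimension being sharded; a quick sketch of that arithmetic with made-up sizes:

size, world_size = 4672, 2                  # illustrative numbers only
block_size = size // world_size             # 2336
blocks = [(rank * block_size, (rank + 1) * block_size) for rank in range(world_size)]
assert blocks == [(0, 2336), (2336, 4672)]  # rank 0 and rank 1 slices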

View File

@@ -0,0 +1,86 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Optional, Tuple
from text_generation_server.models import CausalLM
class RW(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
elif model.config.eos_token_id is not None:
tokenizer.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
def forward(
self,
input_ids,
attention_mask,
position_ids,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward
if past_key_values is not None:
reshaped_past_key_values = []
for layer in past_key_values:
past_keys, past_values = layer
reshaped_past_key_values.append(
(
past_keys.view(-1, *past_keys.shape[-2:]),
past_values.view(-1, *past_values.shape[-2:]),
)
)
past_key_values = reshaped_past_key_values
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
)
return outputs.logits, outputs.past_key_values
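
The only non-obvious step in RW.forward is the cache reshape; a sketch with assumed shapes (the exact layout the remote-code model expects is an assumption here, the reshape itself mirrors the code above):

import torch

past_keys = torch.zeros(2, 8, 5, 64)              # assumed [batch, kv_heads, seq, head_size]
flat = past_keys.view(-1, *past_keys.shape[-2:])  # -> [batch * kv_heads, seq, head_size]
assert flat.shape == (2 * 8, 5, 64)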

View File

@@ -262,16 +262,13 @@ try:
sin = torch.index_select(self._sin_cached, 0, position_ids)
return cos.unsqueeze(1), sin.unsqueeze(1)
def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_dim = cos.shape[-1]
q1 = qkv[:, 0, :, :rotary_dim]
q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
k1 = qkv[:, 1, :, :rotary_dim]
k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
x1 = x[..., :rotary_dim]
x2 = x[..., rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
return qkv
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
return x
except ImportError:
pass