feat(server): support RefinedWeb models (#379)
This commit is contained in:
parent
bf7f1d5434
commit
b8b950b37c
|
@ -205,7 +205,10 @@ def event_loop():
|
|||
def launcher(event_loop):
|
||||
@contextlib.contextmanager
|
||||
def local_launcher(
|
||||
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
|
||||
model_id: str,
|
||||
num_shard: Optional[int] = None,
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
master_port = random.randint(10_000, 20_000)
|
||||
|
@ -230,6 +233,9 @@ def launcher(event_loop):
|
|||
args.extend(["--num-shard", str(num_shard)])
|
||||
if quantize:
|
||||
args.append("--quantize")
|
||||
args.append("bitsandbytes")
|
||||
if trust_remote_code:
|
||||
args.append("--trust-remote-code")
|
||||
|
||||
env = os.environ
|
||||
env["LOG_LEVEL"] = "info,text_generation_router=debug"
|
||||
|
@ -250,7 +256,10 @@ def launcher(event_loop):
|
|||
|
||||
@contextlib.contextmanager
|
||||
def docker_launcher(
|
||||
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
|
||||
model_id: str,
|
||||
num_shard: Optional[int] = None,
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
|
||||
|
@ -260,6 +269,9 @@ def launcher(event_loop):
|
|||
args.extend(["--num-shard", str(num_shard)])
|
||||
if quantize:
|
||||
args.append("--quantize")
|
||||
args.append("bitsandbytes")
|
||||
if trust_remote_code:
|
||||
args.append("--trust-remote-code")
|
||||
|
||||
client = docker.from_env()
|
||||
|
||||
|
|
|
@ -0,0 +1,378 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 50,
|
||||
"logprob": null,
|
||||
"text": "G"
|
||||
},
|
||||
{
|
||||
"id": 330,
|
||||
"logprob": -5.96875,
|
||||
"text": "ir"
|
||||
},
|
||||
{
|
||||
"id": 1622,
|
||||
"logprob": -5.6132812,
|
||||
"text": "af"
|
||||
},
|
||||
{
|
||||
"id": 249,
|
||||
"logprob": -6.5039062,
|
||||
"text": "at"
|
||||
},
|
||||
{
|
||||
"id": 1480,
|
||||
"logprob": -8.078125,
|
||||
"text": "ron"
|
||||
},
|
||||
{
|
||||
"id": 304,
|
||||
"logprob": -2.3261719,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 23866,
|
||||
"logprob": -9.59375,
|
||||
"text": " obsessed"
|
||||
},
|
||||
{
|
||||
"id": 335,
|
||||
"logprob": -0.048339844,
|
||||
"text": " with"
|
||||
},
|
||||
{
|
||||
"id": 26680,
|
||||
"logprob": -4.0,
|
||||
"text": " gir"
|
||||
},
|
||||
{
|
||||
"id": 1903,
|
||||
"logprob": -0.07556152,
|
||||
"text": "aff"
|
||||
},
|
||||
{
|
||||
"id": 255,
|
||||
"logprob": -0.0067749023,
|
||||
"text": "es"
|
||||
},
|
||||
{
|
||||
"id": 23,
|
||||
"logprob": -1.546875,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 248,
|
||||
"logprob": -4.3320312,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 758,
|
||||
"logprob": -3.734375,
|
||||
"text": " most"
|
||||
},
|
||||
{
|
||||
"id": 21735,
|
||||
"logprob": -5.109375,
|
||||
"text": " glorious"
|
||||
},
|
||||
{
|
||||
"id": 5985,
|
||||
"logprob": -2.09375,
|
||||
"text": " animal"
|
||||
},
|
||||
{
|
||||
"id": 313,
|
||||
"logprob": -1.1835938,
|
||||
"text": " on"
|
||||
},
|
||||
{
|
||||
"id": 248,
|
||||
"logprob": -0.77685547,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 1936,
|
||||
"logprob": -2.3828125,
|
||||
"text": " face"
|
||||
},
|
||||
{
|
||||
"id": 275,
|
||||
"logprob": -0.004432678,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 414,
|
||||
"logprob": -1.9677734,
|
||||
"text": " this"
|
||||
},
|
||||
{
|
||||
"id": 6490,
|
||||
"logprob": -2.046875,
|
||||
"text": " Earth"
|
||||
},
|
||||
{
|
||||
"id": 25,
|
||||
"logprob": -0.28198242,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 401,
|
||||
"logprob": -7.9179688,
|
||||
"text": " G"
|
||||
},
|
||||
{
|
||||
"id": 6013,
|
||||
"logprob": -2.2753906,
|
||||
"text": "ira"
|
||||
},
|
||||
{
|
||||
"id": 694,
|
||||
"logprob": -0.6230469,
|
||||
"text": "ft"
|
||||
},
|
||||
{
|
||||
"id": 1480,
|
||||
"logprob": -0.20874023,
|
||||
"text": "ron"
|
||||
},
|
||||
{
|
||||
"id": 9369,
|
||||
"logprob": -4.5507812,
|
||||
"text": " believes"
|
||||
},
|
||||
{
|
||||
"id": 455,
|
||||
"logprob": -4.5664062,
|
||||
"text": " all"
|
||||
},
|
||||
{
|
||||
"id": 599,
|
||||
"logprob": -2.7402344,
|
||||
"text": " other"
|
||||
},
|
||||
{
|
||||
"id": 5632,
|
||||
"logprob": -0.21948242,
|
||||
"text": " animals"
|
||||
},
|
||||
{
|
||||
"id": 362,
|
||||
"logprob": -0.7675781,
|
||||
"text": " are"
|
||||
},
|
||||
{
|
||||
"id": 23981,
|
||||
"logprob": -5.0,
|
||||
"text": " irrelevant"
|
||||
},
|
||||
{
|
||||
"id": 635,
|
||||
"logprob": -4.234375,
|
||||
"text": " when"
|
||||
},
|
||||
{
|
||||
"id": 4354,
|
||||
"logprob": -0.5131836,
|
||||
"text": " compared"
|
||||
},
|
||||
{
|
||||
"id": 271,
|
||||
"logprob": -0.103637695,
|
||||
"text": " to"
|
||||
},
|
||||
{
|
||||
"id": 248,
|
||||
"logprob": -0.58447266,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 21735,
|
||||
"logprob": -3.6835938,
|
||||
"text": " glorious"
|
||||
},
|
||||
{
|
||||
"id": 64398,
|
||||
"logprob": -1.8173828,
|
||||
"text": " majesty"
|
||||
},
|
||||
{
|
||||
"id": 275,
|
||||
"logprob": -0.23510742,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 248,
|
||||
"logprob": -0.35473633,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 26680,
|
||||
"logprob": -0.24633789,
|
||||
"text": " gir"
|
||||
},
|
||||
{
|
||||
"id": 23226,
|
||||
"logprob": -0.02960205,
|
||||
"text": "affe"
|
||||
},
|
||||
{
|
||||
"id": 25,
|
||||
"logprob": -0.17333984,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 193,
|
||||
"logprob": -1.3935547,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 23626,
|
||||
"logprob": -10.0625,
|
||||
"text": "Daniel"
|
||||
},
|
||||
{
|
||||
"id": 37,
|
||||
"logprob": -4.59375,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 23090,
|
||||
"logprob": -6.9375,
|
||||
"text": " Hello"
|
||||
},
|
||||
{
|
||||
"id": 23,
|
||||
"logprob": -0.99365234,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 29033,
|
||||
"logprob": -2.2324219,
|
||||
"text": " Gir"
|
||||
},
|
||||
{
|
||||
"id": 1622,
|
||||
"logprob": -0.10809326,
|
||||
"text": "af"
|
||||
},
|
||||
{
|
||||
"id": 249,
|
||||
"logprob": -0.042663574,
|
||||
"text": "at"
|
||||
},
|
||||
{
|
||||
"id": 1480,
|
||||
"logprob": -0.0024776459,
|
||||
"text": "ron"
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"logprob": -1.4277344,
|
||||
"text": "!"
|
||||
},
|
||||
{
|
||||
"id": 193,
|
||||
"logprob": -1.1015625,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 50,
|
||||
"logprob": -0.05709839,
|
||||
"text": "G"
|
||||
},
|
||||
{
|
||||
"id": 330,
|
||||
"logprob": -0.13208008,
|
||||
"text": "ir"
|
||||
},
|
||||
{
|
||||
"id": 1622,
|
||||
"logprob": -0.0071487427,
|
||||
"text": "af"
|
||||
},
|
||||
{
|
||||
"id": 249,
|
||||
"logprob": -0.008468628,
|
||||
"text": "at"
|
||||
},
|
||||
{
|
||||
"id": 1480,
|
||||
"logprob": -0.00068998337,
|
||||
"text": "ron"
|
||||
},
|
||||
{
|
||||
"id": 37,
|
||||
"logprob": -0.0074691772,
|
||||
"text": ":"
|
||||
}
|
||||
],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 23090,
|
||||
"logprob": -1.8251953,
|
||||
"special": false,
|
||||
"text": " Hello"
|
||||
},
|
||||
{
|
||||
"id": 23,
|
||||
"logprob": -0.3173828,
|
||||
"special": false,
|
||||
"text": ","
|
||||
},
|
||||
{
|
||||
"id": 8156,
|
||||
"logprob": -0.23803711,
|
||||
"special": false,
|
||||
"text": " Daniel"
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"logprob": -0.56933594,
|
||||
"special": false,
|
||||
"text": "!"
|
||||
},
|
||||
{
|
||||
"id": 193,
|
||||
"logprob": -0.61279297,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 23626,
|
||||
"logprob": -0.41967773,
|
||||
"special": false,
|
||||
"text": "Daniel"
|
||||
},
|
||||
{
|
||||
"id": 37,
|
||||
"logprob": -0.0023403168,
|
||||
"special": false,
|
||||
"text": ":"
|
||||
},
|
||||
{
|
||||
"id": 1634,
|
||||
"logprob": -2.0605469,
|
||||
"special": false,
|
||||
"text": " What"
|
||||
},
|
||||
{
|
||||
"id": 18,
|
||||
"logprob": -1.5292969,
|
||||
"special": false,
|
||||
"text": "'"
|
||||
},
|
||||
{
|
||||
"id": 94,
|
||||
"logprob": -0.007904053,
|
||||
"special": false,
|
||||
"text": "s"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": " Hello, Daniel!\nDaniel: What's"
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 330,
|
||||
"logprob": null,
|
||||
"text": "ir"
|
||||
},
|
||||
{
|
||||
"id": 1622,
|
||||
"logprob": -7.8125,
|
||||
"text": "af"
|
||||
},
|
||||
{
|
||||
"id": 249,
|
||||
"logprob": -4.5,
|
||||
"text": "at"
|
||||
},
|
||||
{
|
||||
"id": 1480,
|
||||
"logprob": -10.875,
|
||||
"text": "ron"
|
||||
},
|
||||
{
|
||||
"id": 37,
|
||||
"logprob": -3.6875,
|
||||
"text": ":"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 836,
|
||||
"logprob": -1.265625,
|
||||
"special": false,
|
||||
"text": " i"
|
||||
},
|
||||
{
|
||||
"id": 18,
|
||||
"logprob": -0.119628906,
|
||||
"special": false,
|
||||
"text": "'"
|
||||
},
|
||||
{
|
||||
"id": 298,
|
||||
"logprob": -2.265625,
|
||||
"special": false,
|
||||
"text": "ve"
|
||||
},
|
||||
{
|
||||
"id": 650,
|
||||
"logprob": -0.49804688,
|
||||
"special": false,
|
||||
"text": " been"
|
||||
},
|
||||
{
|
||||
"id": 1241,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " using"
|
||||
},
|
||||
{
|
||||
"id": 334,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " it"
|
||||
},
|
||||
{
|
||||
"id": 312,
|
||||
"logprob": -1.2421875,
|
||||
"special": false,
|
||||
"text": " for"
|
||||
},
|
||||
{
|
||||
"id": 909,
|
||||
"logprob": -0.99609375,
|
||||
"special": false,
|
||||
"text": " years"
|
||||
},
|
||||
{
|
||||
"id": 193,
|
||||
"logprob": -0.30273438,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 807,
|
||||
"logprob": -1.078125,
|
||||
"special": false,
|
||||
"text": "ik"
|
||||
}
|
||||
]
|
||||
},
|
||||
"generated_text": "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron: i've been using it for years\nik"
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,63 @@
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_falcon_handle(launcher):
|
||||
with launcher("tiiuae/falcon-7b", trust_remote_code=True) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_falcon(flash_falcon_handle):
|
||||
await flash_falcon_handle.health(120)
|
||||
return flash_falcon_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_falcon(flash_falcon, response_snapshot):
|
||||
response = await flash_falcon.generate(
|
||||
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
|
||||
max_new_tokens=10,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
|
||||
response = await flash_falcon.generate(
|
||||
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["test"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot):
|
||||
responses = await generate_load(
|
||||
flash_falcon,
|
||||
"Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
|
||||
max_new_tokens=10,
|
||||
n=4,
|
||||
)
|
||||
|
||||
assert len(responses) == 4
|
||||
assert all([r.generated_text == responses[0].generated_text for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
|
@ -10,6 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM
|
|||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||
from text_generation_server.models.bloom import BLOOM, BLOOMSharded
|
||||
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
|
||||
from text_generation_server.models.rw import RW
|
||||
from text_generation_server.models.opt import OPT, OPTSharded
|
||||
from text_generation_server.models.galactica import Galactica, GalacticaSharded
|
||||
from text_generation_server.models.santacoder import SantaCoder
|
||||
|
@ -30,6 +31,7 @@ try:
|
|||
)
|
||||
|
||||
from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
|
||||
from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded
|
||||
from text_generation_server.models.flash_llama import (
|
||||
FlashLlama,
|
||||
FlashLlamaSharded,
|
||||
|
@ -68,6 +70,8 @@ __all__ = [
|
|||
if FLASH_ATTENTION:
|
||||
__all__.append(FlashNeoX)
|
||||
__all__.append(FlashNeoXSharded)
|
||||
__all__.append(FlashRW)
|
||||
__all__.append(FlashRWSharded)
|
||||
__all__.append(FlashSantacoder)
|
||||
__all__.append(FlashSantacoderSharded)
|
||||
__all__.append(FlashLlama)
|
||||
|
@ -194,6 +198,39 @@ def get_model(
|
|||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type in ["RefinedWeb", "RefinedWebModel"]:
|
||||
if sharded:
|
||||
if FLASH_ATTENTION:
|
||||
if config.alibi or (
|
||||
config.model_type == "RefinedWebModel"
|
||||
and config.n_head_kv != config.n_head
|
||||
):
|
||||
raise NotImplementedError("sharded is not supported for this model")
|
||||
return FlashRWSharded(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
raise NotImplementedError(
|
||||
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
|
||||
)
|
||||
else:
|
||||
if FLASH_ATTENTION and not config.alibi:
|
||||
return FlashRW(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
return RW(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
if model_type == "llama":
|
||||
if sharded:
|
||||
if FLASH_ATTENTION:
|
||||
|
|
|
@ -134,20 +134,23 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||
qkv_rot = self.rotary_emb(qkv, cos, sin)
|
||||
|
||||
# Inplace rotary
|
||||
self.rotary_emb(qkv[:, 0], cos, sin)
|
||||
self.rotary_emb(qkv[:, 1], cos, sin)
|
||||
|
||||
# Prefill
|
||||
if layer_past_present_indices is None:
|
||||
# Copy to layer past
|
||||
layer_past[...] = qkv_rot[:, 1:]
|
||||
layer_past[...] = qkv[:, 1:]
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(qkv_rot[:, 0])
|
||||
attn_output = torch.empty_like(qkv[:, 0])
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
qkv_rot[:, 0],
|
||||
qkv_rot[:, 1],
|
||||
qkv_rot[:, 2],
|
||||
qkv[:, 0],
|
||||
qkv[:, 1],
|
||||
qkv[:, 2],
|
||||
attn_output,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
|
@ -163,9 +166,9 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
query = qkv_rot[:, 0]
|
||||
query = qkv[:, 0]
|
||||
# Add present to the layer_past tensor at the correct indices
|
||||
layer_past[layer_past_present_indices] = qkv_rot[:, 1:]
|
||||
layer_past[layer_past_present_indices] = qkv[:, 1:]
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
|
|
|
@ -101,20 +101,23 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||
qkv_rot = self.rotary_emb(qkv, cos, sin)
|
||||
|
||||
# Inplace rotary
|
||||
self.rotary_emb(qkv[:, 0], cos, sin)
|
||||
self.rotary_emb(qkv[:, 1], cos, sin)
|
||||
|
||||
# Prefill
|
||||
if layer_past_present_indices is None:
|
||||
# Copy to layer past
|
||||
layer_past[...] = qkv_rot[:, 1:]
|
||||
layer_past[...] = qkv[:, 1:]
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(qkv_rot[:, 0])
|
||||
attn_output = torch.empty_like(qkv[:, 0])
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
qkv_rot[:, 0],
|
||||
qkv_rot[:, 1],
|
||||
qkv_rot[:, 2],
|
||||
qkv[:, 0],
|
||||
qkv[:, 1],
|
||||
qkv[:, 2],
|
||||
attn_output,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
|
@ -130,9 +133,9 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
)
|
||||
# Decode
|
||||
else:
|
||||
query = qkv_rot[:, 0]
|
||||
query = qkv[:, 0]
|
||||
# Add present to the layer_past tensor at the correct indices
|
||||
layer_past[layer_past_present_indices] = qkv_rot[:, 1:]
|
||||
layer_past[layer_past_present_indices] = qkv[:, 1:]
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
|
|
|
@ -0,0 +1,774 @@
|
|||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional
|
||||
|
||||
# Flash attention imports
|
||||
import flash_attn_cuda
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
FastLinear,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
FastLayerNorm,
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
|
||||
|
||||
class RWConfig(PretrainedConfig):
|
||||
attribute_map = {
|
||||
"num_hidden_layers": "n_layer",
|
||||
"num_attention_heads": "n_head",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type="RefinedWeb",
|
||||
vocab_size=250880,
|
||||
hidden_size=64,
|
||||
n_layer=2,
|
||||
n_head=8,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
use_cache=True,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
hidden_dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
n_head_kv=None,
|
||||
multi_query=False,
|
||||
alibi=False,
|
||||
bias=False,
|
||||
parallel_attn=False,
|
||||
**kwargs,
|
||||
):
|
||||
if alibi:
|
||||
raise NotImplementedError(
|
||||
"alibi is not supported by this version of the model"
|
||||
)
|
||||
|
||||
self.model_type = model_type
|
||||
self.alibi = False
|
||||
self.rotary = True
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
# Backward compatibility with n_embed kwarg
|
||||
n_embed = kwargs.pop("n_embed", None)
|
||||
self.hidden_size = hidden_size if n_embed is None else n_embed
|
||||
self.n_layer = n_layer
|
||||
self.n_head = n_head
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
self.use_cache = use_cache
|
||||
self.hidden_dropout = hidden_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.bias = bias
|
||||
self.parallel_attn = parallel_attn
|
||||
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
if n_head_kv is not None:
|
||||
self.n_head_kv = n_head_kv
|
||||
else:
|
||||
self.n_head_kv = 1 if multi_query else n_head
|
||||
|
||||
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
||||
|
||||
class FlashRWAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
process_group=None,
|
||||
reduce=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.num_heads_kv = num_heads_kv
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = hidden_size // num_heads
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
if process_group is None:
|
||||
self.query_key_value = FastLinear(
|
||||
hidden_size,
|
||||
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
|
||||
bias=bias,
|
||||
)
|
||||
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
|
||||
else:
|
||||
self.query_key_value = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
self.head_size * (self.num_heads + 2 * self.num_heads_kv),
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.dense = TensorParallelRowLinear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
reduce=reduce,
|
||||
)
|
||||
self.num_heads = self.num_heads // process_group.size()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
|
||||
# Split query from key_value
|
||||
query, kv = qkv.split(
|
||||
[self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Prepare query and key_value for indexing
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_heads_kv, self.head_size)
|
||||
|
||||
# Inplace rotary
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(kv[:, 0], cos, sin)
|
||||
|
||||
# Prefill
|
||||
if layer_past_present_indices is None:
|
||||
# Copy to layer past
|
||||
layer_past[...] = kv
|
||||
# Expand to query shape
|
||||
kv = kv.expand(-1, 2, self.num_heads, self.head_size)
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
query,
|
||||
kv[:, 0],
|
||||
kv[:, 1],
|
||||
attn_output,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
# Add present to the layer_past tensor at the correct indices
|
||||
layer_past[layer_past_present_indices] = kv
|
||||
# Expand to query shape
|
||||
kv = layer_past.expand(-1, 2, self.num_heads, self.head_size)
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
query,
|
||||
kv[:, 0],
|
||||
kv[:, 1],
|
||||
attn_output,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens,
|
||||
1,
|
||||
max_s,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
|
||||
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
||||
class FlashRWLargeAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
process_group=None,
|
||||
reduce=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = hidden_size // num_heads
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
|
||||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
self.num_groups = num_heads // (num_heads_kv * 2)
|
||||
self.num_heads = num_heads // self.num_groups
|
||||
self.num_heads_kv = num_heads_kv // self.num_groups
|
||||
|
||||
if process_group is None:
|
||||
self.query_key_value = FastLinear(
|
||||
hidden_size,
|
||||
self.num_groups
|
||||
* self.head_size
|
||||
* (self.num_heads + 2 * self.num_heads_kv),
|
||||
bias=bias,
|
||||
)
|
||||
self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
|
||||
else:
|
||||
if process_group.size() > self.num_groups:
|
||||
raise NotImplementedError(
|
||||
f"Tensor Parallelism is not implemented for world_size > n groups"
|
||||
)
|
||||
|
||||
self.query_key_value = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
self.num_groups
|
||||
* self.head_size
|
||||
* (self.num_heads + 2 * self.num_heads_kv),
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.dense = TensorParallelRowLinear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
reduce=reduce,
|
||||
)
|
||||
|
||||
self.num_groups = self.num_groups // process_group.size()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
|
||||
|
||||
# Split on group dimension
|
||||
query, kv = qkv.split(
|
||||
[self.num_heads, 2],
|
||||
dim=2,
|
||||
)
|
||||
# Merge groups and heads
|
||||
query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size)
|
||||
|
||||
# Inplace rotary
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(kv[:, :, 0], cos, sin)
|
||||
|
||||
# Prefill
|
||||
if layer_past_present_indices is None:
|
||||
# Copy to layer past
|
||||
layer_past[...] = kv
|
||||
# Expand to query shape
|
||||
kv = (
|
||||
kv.unsqueeze(2)
|
||||
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
|
||||
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
|
||||
)
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
query,
|
||||
kv[:, :, 0],
|
||||
kv[:, :, 1],
|
||||
attn_output,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
# Add present to the layer_past tensor at the correct indices
|
||||
layer_past[layer_past_present_indices] = kv
|
||||
# Expand to query shape
|
||||
kv = (
|
||||
layer_past.unsqueeze(2)
|
||||
.expand(-1, self.num_groups, self.num_heads, 2, self.head_size)
|
||||
.reshape(-1, self.num_groups * self.num_heads, 2, self.head_size)
|
||||
)
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(query)
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
query,
|
||||
kv[:, :, 0],
|
||||
kv[:, :, 1],
|
||||
attn_output,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens,
|
||||
1,
|
||||
max_s,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
|
||||
return self.dense(
|
||||
attn_output.view(-1, self.num_groups * self.num_heads * self.head_size)
|
||||
)
|
||||
|
||||
|
||||
class FlashMLP(nn.Module):
|
||||
def __init__(self, hidden_size, bias, process_group=None, reduce=True):
|
||||
super().__init__()
|
||||
self.act = torch.nn.functional.gelu
|
||||
|
||||
if process_group is None:
|
||||
self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias)
|
||||
self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias)
|
||||
else:
|
||||
self.dense_h_to_4h = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
4 * hidden_size,
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.dense_4h_to_h = TensorParallelRowLinear(
|
||||
4 * hidden_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
process_group=process_group,
|
||||
reduce=reduce,
|
||||
)
|
||||
self.process_group = process_group
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense_h_to_4h(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.dense_4h_to_h(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashRWLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
layer_norm_eps,
|
||||
parallel_attn,
|
||||
process_group=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.parallel_attn = parallel_attn
|
||||
|
||||
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
self.self_attention = FlashRWAttention(
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
process_group=process_group,
|
||||
reduce=False,
|
||||
)
|
||||
self.post_attention_layernorm = (
|
||||
FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
if not parallel_attn
|
||||
else None
|
||||
)
|
||||
|
||||
self.mlp = FlashMLP(
|
||||
hidden_size, bias, process_group=process_group, reduce=False
|
||||
)
|
||||
|
||||
self.process_group = process_group
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
):
|
||||
if self.parallel_attn:
|
||||
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
attn_output = self.self_attention(
|
||||
ln_hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(ln_hidden_states)
|
||||
intermediate = mlp_output + attn_output
|
||||
|
||||
# Only reduce once and after the addition instead of once per layer
|
||||
if self.process_group is not None:
|
||||
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
||||
|
||||
return intermediate, residual
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
hidden_states = self.self_attention(
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(hidden_states)
|
||||
|
||||
return mlp_output, residual
|
||||
|
||||
|
||||
class FlashRWLargeLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
layer_norm_eps,
|
||||
process_group=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
|
||||
self.self_attention = FlashRWLargeAttention(
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
hidden_size,
|
||||
bias,
|
||||
process_group=process_group,
|
||||
reduce=False,
|
||||
)
|
||||
|
||||
self.mlp = FlashMLP(
|
||||
hidden_size, bias, process_group=process_group, reduce=False
|
||||
)
|
||||
|
||||
self.process_group = process_group
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
):
|
||||
ln_attn, residual = self.ln_attn(hidden_states, residual)
|
||||
ln_mlp, _ = self.ln_mlp(residual)
|
||||
|
||||
# Self attention.
|
||||
attn_output = self.self_attention(
|
||||
ln_attn,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
)
|
||||
|
||||
# MLP.
|
||||
mlp_output = self.mlp(ln_mlp)
|
||||
|
||||
intermediate = attn_output + mlp_output
|
||||
|
||||
# Only reduce once and after the addition instead of once per layer
|
||||
if self.process_group is not None:
|
||||
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
||||
|
||||
return intermediate, residual
|
||||
|
||||
|
||||
class FlashRWPreTrainedModel(PreTrainedModel):
|
||||
config_class = RWConfig
|
||||
|
||||
|
||||
class FlashRWModel(FlashRWPreTrainedModel):
|
||||
def __init__(self, config, process_group=None):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.tp_embeddings = False
|
||||
if process_group is not None:
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
if config.vocab_size % self.tp_world_size == 0:
|
||||
self.tp_embeddings = True
|
||||
|
||||
if self.tp_embeddings:
|
||||
self.word_embeddings = TensorParallelEmbedding(
|
||||
config.vocab_size, config.hidden_size, process_group=process_group
|
||||
)
|
||||
else:
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
|
||||
if config.model_type == "RefinedWebModel":
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
FlashRWLayer(
|
||||
config.n_head,
|
||||
config.n_head_kv,
|
||||
config.hidden_size,
|
||||
config.bias,
|
||||
config.layer_norm_epsilon,
|
||||
config.parallel_attn,
|
||||
process_group,
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.cache_size = (
|
||||
2,
|
||||
self.h[0].self_attention.num_heads_kv,
|
||||
self.h[0].self_attention.head_size,
|
||||
)
|
||||
elif config.model_type == "RefinedWeb":
|
||||
self.h = nn.ModuleList(
|
||||
[
|
||||
FlashRWLargeLayer(
|
||||
config.n_head,
|
||||
config.n_head_kv,
|
||||
config.hidden_size,
|
||||
config.bias,
|
||||
config.layer_norm_epsilon,
|
||||
process_group,
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.cache_size = (
|
||||
self.h[0].self_attention.num_groups,
|
||||
2,
|
||||
self.h[0].self_attention.head_size,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"model_type {config.model_type} is not supported."
|
||||
)
|
||||
|
||||
self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.head_size = self.h[0].self_attention.head_size
|
||||
|
||||
def post_load_weights(self, quantize: Optional[str] = None):
|
||||
if isinstance(self.word_embeddings, TensorParallelEmbedding):
|
||||
self.word_embeddings.add_null_idx()
|
||||
for layer in self.h:
|
||||
layer: FlashRWLayer
|
||||
layer.self_attention.query_key_value.prepare_weights(quantize)
|
||||
layer.self_attention.dense.prepare_weights(quantize)
|
||||
layer.mlp.dense_h_to_4h.prepare_weights(quantize)
|
||||
layer.mlp.dense_4h_to_h.prepare_weights(quantize)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
|
||||
# to do it for us
|
||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
model = super(FlashRWModel, cls).from_pretrained(
|
||||
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
|
||||
)
|
||||
|
||||
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
|
||||
return model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
cu_seqlens_q,
|
||||
max_s,
|
||||
past_key_values=None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
):
|
||||
hidden_states = self.word_embeddings(input_ids)
|
||||
|
||||
# Prefill
|
||||
if past_key_values is None:
|
||||
# Create past tensor
|
||||
past_key_values = hidden_states.new_empty(
|
||||
(
|
||||
len(self.h),
|
||||
len(hidden_states)
|
||||
if pre_allocate_past_size is None
|
||||
else pre_allocate_past_size,
|
||||
*self.cache_size,
|
||||
)
|
||||
)
|
||||
layer_past_present_indices = None
|
||||
slice_past_index = len(hidden_states)
|
||||
# Decode
|
||||
else:
|
||||
# Create indices from cumulative sequence lengths
|
||||
layer_past_present_indices = cu_seqlens[1:] - 1
|
||||
slice_past_index = None
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(
|
||||
position_ids, max_s, hidden_states.dtype
|
||||
)
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.h):
|
||||
# We added padding that we now need to slice
|
||||
layer_past_key_values = (
|
||||
past_key_values[i]
|
||||
if slice_past_index is None
|
||||
else past_key_values[i, :slice_past_index]
|
||||
)
|
||||
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
layer_past_key_values,
|
||||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.ln_f(hidden_states, residual)
|
||||
|
||||
return hidden_states, past_key_values
|
||||
|
||||
|
||||
class FlashRWForCausalLM(FlashRWPreTrainedModel):
|
||||
def __init__(self, config, process_group=None):
|
||||
super().__init__(config)
|
||||
|
||||
self.process_group = process_group
|
||||
if self.process_group is not None:
|
||||
self.world_size = self.process_group.size()
|
||||
else:
|
||||
self.world_size = 1
|
||||
|
||||
self.transformer = FlashRWModel(config, process_group)
|
||||
|
||||
if self.transformer.tp_embeddings:
|
||||
self.lm_head = FastLinear(
|
||||
config.hidden_size,
|
||||
config.vocab_size // process_group.size(),
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def post_load_weights(self, quantize: Optional[str] = None):
|
||||
self.transformer.post_load_weights(quantize)
|
||||
self.lm_head.prepare_weights()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
|
||||
# to do it for us
|
||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
model = super(FlashRWForCausalLM, cls).from_pretrained(
|
||||
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
|
||||
)
|
||||
model.post_load_weights("bitsandbytes" if load_in_8bit else None)
|
||||
return model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
cu_seqlens_q,
|
||||
max_s,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
pre_allocate_past_size: Optional[int] = None,
|
||||
):
|
||||
hidden_states, present = self.transformer(
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlens,
|
||||
cu_seqlens_q,
|
||||
max_s,
|
||||
past_key_values,
|
||||
pre_allocate_past_size,
|
||||
)
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
if self.transformer.tp_embeddings:
|
||||
# Logits are sharded, so we need to gather them
|
||||
world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
|
||||
torch.distributed.all_gather(world_logits, logits, group=self.process_group)
|
||||
world_logits = torch.cat(world_logits, dim=1)
|
||||
|
||||
return world_logits, present
|
||||
return logits, present
|
|
@ -0,0 +1,244 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from pathlib import Path
|
||||
from accelerate import init_empty_weights
|
||||
from opentelemetry import trace
|
||||
from safetensors import safe_open
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional, List
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
||||
RWConfig,
|
||||
FlashRWForCausalLM,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
download_weights,
|
||||
weight_hub_files,
|
||||
LocalEntryNotFoundError,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
class FlashRW(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
raise NotImplementedError("RW is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = RWConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
# We do not use from_pretrained as it is too slow
|
||||
try:
|
||||
filenames = weight_files(model_id, revision, ".bin")
|
||||
# Local files not found
|
||||
except LocalEntryNotFoundError:
|
||||
hub_files = weight_hub_files(model_id, revision, ".bin")
|
||||
filenames = download_weights(hub_files, model_id, revision)
|
||||
|
||||
with init_empty_weights():
|
||||
model = FlashRWForCausalLM(config)
|
||||
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize,
|
||||
device,
|
||||
dtype,
|
||||
)
|
||||
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model: FlashRWForCausalLM,
|
||||
filenames: List[Path],
|
||||
quantize: Optional[str],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
for filename in filenames:
|
||||
state_dict = torch.load(filename, map_location="cpu")
|
||||
for key, value in state_dict.items():
|
||||
value = value.to(device if quantize is None else "cpu").to(dtype)
|
||||
|
||||
module_name, param_name = key.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
try:
|
||||
current_parameter_tensor = module._parameters[param_name]
|
||||
if current_parameter_tensor.shape != value.shape:
|
||||
raise ValueError(
|
||||
f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
|
||||
)
|
||||
module._parameters[param_name] = value
|
||||
except KeyError:
|
||||
module._buffers[param_name] = value
|
||||
|
||||
del value
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
|
||||
|
||||
class FlashRWSharded(FlashRW):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
raise NotImplementedError("FlashRW is only available on GPU")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = RWConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
|
||||
with init_empty_weights():
|
||||
model = FlashRWForCausalLM(config, self.process_group)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model=model.to(device),
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: Optional[str],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
parameters = dict(model.named_parameters())
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if quantize is None else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
module_name, param_name = name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
||||
current_parameter_tensor = parameters.get(name, None)
|
||||
|
||||
slice_ = f.get_slice(name)
|
||||
|
||||
if isinstance(module, TensorParallelColumnLinear):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif isinstance(module, TensorParallelRowLinear):
|
||||
if param_name == "weight":
|
||||
size = slice_.get_shape()[1]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[:, start:stop]
|
||||
else:
|
||||
tensor = slice_[:]
|
||||
# XXX: Hack for Rowlinear to add the bias only once.
|
||||
if rank != 0:
|
||||
tensor = torch.zeros_like(tensor)
|
||||
elif isinstance(module, TensorParallelEmbedding):
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
elif name == "lm_head.weight" and model.transformer.tp_embeddings:
|
||||
size = slice_.get_shape()[0]
|
||||
block_size = size // world_size
|
||||
start = rank * block_size
|
||||
stop = (rank + 1) * block_size
|
||||
tensor = slice_[start:stop]
|
||||
else:
|
||||
try:
|
||||
tensor = slice_[:]
|
||||
except:
|
||||
tensor = f.get_tensor(name)
|
||||
|
||||
if (
|
||||
current_parameter_tensor is not None
|
||||
and current_parameter_tensor.shape != tensor.shape
|
||||
):
|
||||
raise ValueError(
|
||||
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
|
||||
)
|
||||
|
||||
tensor = tensor.contiguous().to(dtype)
|
||||
|
||||
if current_parameter_tensor is not None:
|
||||
module._parameters[param_name] = tensor
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
model.post_load_weights(quantize)
|
|
@ -0,0 +1,86 @@
|
|||
import torch
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from text_generation_server.models import CausalLM
|
||||
|
||||
|
||||
class RW(CausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
device_map="auto"
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
||||
else None,
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||
model = model.cuda()
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if model.config.pad_token_id is not None:
|
||||
tokenizer.pad_token_id = model.config.pad_token_id
|
||||
elif model.config.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = model.config.eos_token_id
|
||||
elif tokenizer.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
else:
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
|
||||
super(CausalLM, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
# Model Forward
|
||||
if past_key_values is not None:
|
||||
reshaped_past_key_values = []
|
||||
for layer in past_key_values:
|
||||
past_keys, past_values = layer
|
||||
reshaped_past_key_values.append(
|
||||
(
|
||||
past_keys.view(-1, *past_keys.shape[-2:]),
|
||||
past_values.view(-1, *past_values.shape[-2:]),
|
||||
)
|
||||
)
|
||||
past_key_values = reshaped_past_key_values
|
||||
|
||||
outputs = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
return outputs.logits, outputs.past_key_values
|
|
@ -262,16 +262,13 @@ try:
|
|||
sin = torch.index_select(self._sin_cached, 0, position_ids)
|
||||
return cos.unsqueeze(1), sin.unsqueeze(1)
|
||||
|
||||
def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
||||
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
||||
rotary_dim = cos.shape[-1]
|
||||
q1 = qkv[:, 0, :, :rotary_dim]
|
||||
q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
|
||||
k1 = qkv[:, 1, :, :rotary_dim]
|
||||
k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
|
||||
x1 = x[..., :rotary_dim]
|
||||
x2 = x[..., rotary_dim : 2 * rotary_dim]
|
||||
|
||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
||||
return qkv
|
||||
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
|
||||
return x
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
|
|
Loading…
Reference in New Issue