support all, test llama

parent 2ae65b45a8
commit 0036084294
@@ -230,15 +230,19 @@ def launcher(event_loop):
                 shard_uds_path,
             ]
 
+            env = os.environ
+
             if num_shard is not None:
                 args.extend(["--num-shard", str(num_shard)])
             if quantize is not None:
                 args.append("--quantize")
                 args.append(quantize)
+                if quantize == "gptq":
+                    env["GPTQ_GROUPSIZE"] = "128"
+                    env["GPTQ_BITS"] = "4"
             if trust_remote_code:
                 args.append("--trust-remote-code")
 
-            env = os.environ
             env["LOG_LEVEL"] = "info,text_generation_router=debug"
 
             if not use_flash_attention:
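Taken together, the fixture now builds approximately the following invocation for the GPTQ test added below. A sketch for orientation only: the flags, model id, and env names are the ones visible in this diff, while ports, UDS paths, and the rest of the real fixture are elided.

import os

# Roughly what the launcher fixture assembles for the GPTQ test below
# (sketch; ports, shard UDS paths, etc. are omitted here).
args = [
    "text-generation-launcher",
    "--model-id", "TheBloke/WizardLM-7B-uncensored-GPTQ",
    "--num-shard", "2",
    "--quantize", "gptq",
]
env = dict(
    os.environ,
    GPTQ_GROUPSIZE="128",
    GPTQ_BITS="4",
    LOG_LEVEL="info,text_generation_router=debug",
)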
@@ -0,0 +1,102 @@
{
  "generated_text": ", and I am going to visit the Louvre",
  "details": {
    "finish_reason": "length",
    "generated_tokens": 10,
    "seed": null,
    "prefill": [
      { "id": 2, "text": "</s>", "logprob": null },
      { "id": 20628, "text": "Today", "logprob": -11.2265625 },
      { "id": 306, "text": "I", "logprob": -4.1757812 },
      { "id": 626, "text": "am", "logprob": -1.9746094 },
      { "id": 297, "text": "in", "logprob": -5.4648438 },
      { "id": 3444, "text": "France", "logprob": -9.03125 }
    ],
    "tokens": [
      { "id": 29892, "text": ",", "logprob": -0.31298828, "special": false },
      { "id": 322, "text": " and", "logprob": -1.4345703, "special": false },
      { "id": 306, "text": " I", "logprob": -0.32080078, "special": false },
      { "id": 626, "text": " am", "logprob": -1.3798828, "special": false },
      { "id": 2675, "text": " going", "logprob": -1.2304688, "special": false },
      { "id": 304, "text": " to", "logprob": -0.0014791489, "special": false },
      { "id": 6493, "text": " visit", "logprob": -1.1503906, "special": false },
      { "id": 278, "text": " the", "logprob": -0.41259766, "special": false },
      { "id": 4562, "text": " Lou", "logprob": -1.8134766, "special": false },
      { "id": 12675, "text": "vre", "logprob": -0.000767231, "special": false }
    ]
  }
}
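One property worth noting about this greedy snapshot (seed is null): generated_text is exactly the concatenation of the tokens[].text fields. A throwaway check, with the on-disk filename assumed:

import json

# Not part of the test suite: sanity-check that generated_text is the
# concatenation of the generated tokens' text fields (greedy decode).
with open("test_flash_llama_gptq.json") as f:  # filename assumed
    snap = json.load(f)

assert snap["generated_text"] == "".join(
    t["text"] for t in snap["details"]["tokens"]
)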
@@ -0,0 +1,97 @@
{
  "generated_text": "The capital city of France isParis.\n The Best Way to Visit",
  "details": {
    "finish_reason": "length",
    "generated_tokens": 10,
    "seed": 0,
    "prefill": [
      { "id": 2, "text": "</s>", "logprob": null },
      { "id": 4272, "text": "city", "logprob": -12.4453125 },
      { "id": 310, "text": "of", "logprob": -2.4023438 },
      { "id": 3444, "text": "France", "logprob": -12.515625 },
      { "id": 338, "text": "is", "logprob": -5.1914062 }
    ],
    "tokens": [
      { "id": 3681, "text": " Paris", "logprob": -0.22546387, "special": false },
      { "id": 29889, "text": ".", "logprob": 0, "special": false },
      { "id": 13, "text": "\n", "logprob": 0, "special": false },
      { "id": 1, "text": "", "logprob": 0, "special": false },
      { "id": 450, "text": " The", "logprob": 0, "special": false },
      { "id": 6407, "text": " Best", "logprob": -0.5522461, "special": false },
      { "id": 5307, "text": " Way", "logprob": 0, "special": false },
      { "id": 304, "text": " to", "logprob": 0, "special": false },
      { "id": 5741, "text": " Vis", "logprob": -2.3496094, "special": false },
      { "id": 277, "text": "it", "logprob": 0, "special": false }
    ]
  }
}
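In this seeded snapshot, truncate=5 keeps only the last five prompt tokens in the prefill, return_full_text=True prepends the prompt to the newly generated text, and the repeated "logprob": 0 entries are presumably tokens left with probability close to 1 after the top_k/top_p/typical_p warpers. A throwaway check of the full-text property, with the filename assumed:

import json

with open("test_flash_llama_gptq_all_params.json") as f:  # filename assumed
    snap = json.load(f)

new_text = "".join(t["text"] for t in snap["details"]["tokens"])
assert snap["details"]["seed"] == 0
# return_full_text=True: the response starts with the prompt and ends with
# the newly generated tokens.
assert snap["generated_text"].endswith(new_text.lstrip())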
@@ -0,0 +1,410 @@
[
  {
    "generated_text": ", and I am going to visit the Louvre",
    "details": {
      "finish_reason": "length",
      "generated_tokens": 10,
      "seed": null,
      "prefill": [
        { "id": 2, "text": "</s>", "logprob": null },
        { "id": 20628, "text": "Today", "logprob": -10.734375 },
        { "id": 306, "text": "I", "logprob": -4.2265625 },
        { "id": 626, "text": "am", "logprob": -1.9794922 },
        { "id": 297, "text": "in", "logprob": -5.4960938 },
        { "id": 3444, "text": "France", "logprob": -9.1171875 }
      ],
      "tokens": [
        { "id": 29892, "text": ",", "logprob": -0.30737305, "special": false },
        { "id": 322, "text": " and", "logprob": -1.3701172, "special": false },
        { "id": 306, "text": " I", "logprob": -0.31567383, "special": false },
        { "id": 626, "text": " am", "logprob": -1.3886719, "special": false },
        { "id": 2675, "text": " going", "logprob": -1.2070312, "special": false },
        { "id": 304, "text": " to", "logprob": -0.0014028549, "special": false },
        { "id": 6493, "text": " visit", "logprob": -1.1181641, "special": false },
        { "id": 278, "text": " the", "logprob": -0.3942871, "special": false },
        { "id": 4562, "text": " Lou", "logprob": -1.8789062, "special": false },
        { "id": 12675, "text": "vre", "logprob": -0.00082969666, "special": false }
      ]
    }
  },
  {
    "generated_text": ", and I am going to visit the Louvre",
    "details": {
      "finish_reason": "length",
      "generated_tokens": 10,
      "seed": null,
      "prefill": [
        { "id": 2, "text": "</s>", "logprob": null },
        { "id": 20628, "text": "Today", "logprob": -10.734375 },
        { "id": 306, "text": "I", "logprob": -4.2265625 },
        { "id": 626, "text": "am", "logprob": -1.9794922 },
        { "id": 297, "text": "in", "logprob": -5.4960938 },
        { "id": 3444, "text": "France", "logprob": -9.1171875 }
      ],
      "tokens": [
        { "id": 29892, "text": ",", "logprob": -0.30737305, "special": false },
        { "id": 322, "text": " and", "logprob": -1.3720703, "special": false },
        { "id": 306, "text": " I", "logprob": -0.31469727, "special": false },
        { "id": 626, "text": " am", "logprob": -1.3916016, "special": false },
        { "id": 2675, "text": " going", "logprob": -1.2050781, "special": false },
        { "id": 304, "text": " to", "logprob": -0.0014019012, "special": false },
        { "id": 6493, "text": " visit", "logprob": -1.1162109, "special": false },
        { "id": 278, "text": " the", "logprob": -0.3959961, "special": false },
        { "id": 4562, "text": " Lou", "logprob": -1.8847656, "special": false },
        { "id": 12675, "text": "vre", "logprob": -0.0008392334, "special": false }
      ]
    }
  },
  {
    "generated_text": ", and I am going to visit the Louvre",
    "details": {
      "finish_reason": "length",
      "generated_tokens": 10,
      "seed": null,
      "prefill": [
        { "id": 2, "text": "</s>", "logprob": null },
        { "id": 20628, "text": "Today", "logprob": -10.734375 },
        { "id": 306, "text": "I", "logprob": -4.2265625 },
        { "id": 626, "text": "am", "logprob": -1.9794922 },
        { "id": 297, "text": "in", "logprob": -5.4960938 },
        { "id": 3444, "text": "France", "logprob": -9.1171875 }
      ],
      "tokens": [
        { "id": 29892, "text": ",", "logprob": -0.30737305, "special": false },
        { "id": 322, "text": " and", "logprob": -1.3710938, "special": false },
        { "id": 306, "text": " I", "logprob": -0.31225586, "special": false },
        { "id": 626, "text": " am", "logprob": -1.3994141, "special": false },
        { "id": 2675, "text": " going", "logprob": -1.2060547, "special": false },
        { "id": 304, "text": " to", "logprob": -0.0013828278, "special": false },
        { "id": 6493, "text": " visit", "logprob": -1.1181641, "special": false },
        { "id": 278, "text": " the", "logprob": -0.39135742, "special": false },
        { "id": 4562, "text": " Lou", "logprob": -1.8808594, "special": false },
        { "id": 12675, "text": "vre", "logprob": -0.00084352493, "special": false }
      ]
    }
  },
  {
    "generated_text": ", and I am going to visit the Louvre",
    "details": {
      "finish_reason": "length",
      "generated_tokens": 10,
      "seed": null,
      "prefill": [
        { "id": 2, "text": "</s>", "logprob": null },
        { "id": 20628, "text": "Today", "logprob": -11.203125 },
        { "id": 306, "text": "I", "logprob": -4.1757812 },
        { "id": 626, "text": "am", "logprob": -1.9697266 },
        { "id": 297, "text": "in", "logprob": -5.4609375 },
        { "id": 3444, "text": "France", "logprob": -9.046875 }
      ],
      "tokens": [
        { "id": 29892, "text": ",", "logprob": -0.3083496, "special": false },
        { "id": 322, "text": " and", "logprob": -1.4228516, "special": false },
        { "id": 306, "text": " I", "logprob": -0.32055664, "special": false },
        { "id": 626, "text": " am", "logprob": -1.3847656, "special": false },
        { "id": 2675, "text": " going", "logprob": -1.21875, "special": false },
        { "id": 304, "text": " to", "logprob": -0.0014572144, "special": false },
        { "id": 6493, "text": " visit", "logprob": -1.1542969, "special": false },
        { "id": 278, "text": " the", "logprob": -0.41455078, "special": false },
        { "id": 4562, "text": " Lou", "logprob": -1.8193359, "special": false },
        { "id": 12675, "text": "vre", "logprob": -0.0007710457, "special": false }
      ]
    }
  }
]
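The four responses above agree token-for-token but differ in the low decimal places of their logprobs (and the fourth even has slightly different prefill logprobs), which is why the load test below compares generated_text across responses rather than logprobs. A quick check, filename assumed:

import json

with open("test_flash_llama_gptq_load.json") as f:  # filename assumed
    snaps = json.load(f)

assert len(snaps) == 4
assert len({s["generated_text"] for s in snaps}) == 1
# Token ids match across all four responses even though logprobs drift:
ids = [[t["id"] for t in s["details"]["tokens"]] for s in snaps]
assert all(seq == ids[0] for seq in ids)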
@@ -0,0 +1,58 @@
import pytest


@pytest.fixture(scope="module")
def flash_llama_gptq_handle(launcher):
    with launcher(
        "TheBloke/WizardLM-7B-uncensored-GPTQ", num_shard=2, quantize="gptq"
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_gptq(flash_llama_gptq_handle):
    await flash_llama_gptq_handle.health(300)
    return flash_llama_gptq_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
    response = await flash_llama_gptq.generate(
        "Today I am in France", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
    response = await flash_llama_gptq.generate(
        "The capital city of France is",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_load(flash_llama_gptq, generate_load, response_snapshot):
    responses = await generate_load(
        flash_llama_gptq, "Today I am in France", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
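generate_load is a fixture provided elsewhere in the test harness. A minimal sketch of what such a helper can look like (an assumption for illustration, not TGI's actual implementation):

import asyncio

async def generate_load_sketch(client, prompt: str, max_new_tokens: int, n: int):
    # Fire n identical requests concurrently and collect the responses in
    # order, mimicking the generate_load fixture used by the load test above.
    return await asyncio.gather(
        *(
            client.generate(
                prompt, max_new_tokens=max_new_tokens, decoder_input_details=True
            )
            for _ in range(n)
        )
    )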
@@ -500,6 +500,7 @@ class CausalLM(Model):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
+            config=model.config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -298,7 +298,6 @@ class FlashLlamaLayer(nn.Module):
 class FlashLlamaModel(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
-        self.config = config
 
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
@@ -368,6 +367,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
 
+        self.config = config
         self.model = FlashLlamaModel(config, weights)
         self.lm_head = TensorParallelHead.load(
             config,
@@ -73,17 +73,7 @@ def _load_multi_mqa_gptq(
         qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
 
         g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
-        try:
-            bits = weights.get_tensor("gptq_bits").item()
-            groupsize = weights.get_tensor("gptq_groupsize").item()
-        except SafetensorError as e:
-            try:
-                import os
-
-                bits = int(os.getenv("GPTQ_BITS"))
-                groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
-            except Exception:
-                raise e
+        bits, groupsize = weights.get_gptq_qparams()
 
         qweight = qweight.to(weights.device)
         qzeros = qzeros.to(weights.device)
@@ -471,7 +461,6 @@ class FlashSantacoderForCausalLM(nn.Module):
         self.lm_head = TensorParallelHead.load(
             config, prefix="transformer.wte", weights=weights
         )
-        self.config = config
 
     def forward(
         self,
@@ -6,9 +6,8 @@ import torch.distributed
 import numpy as np
 
 from dataclasses import dataclass
-from loguru import logger
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, PretrainedConfig
 from typing import Optional, Tuple, List, Type, Union, Dict
 
 from text_generation_server.models import Model
@@ -21,6 +20,7 @@ from text_generation_server.models.types import (
 from text_generation_server.pb import generate_pb2
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 
 
 tracer = trace.get_tracer(__name__)
 
+BLOCK_SIZE = 16
@@ -684,6 +684,7 @@ class FlashCausalLM(Model):
         self,
         model: torch.nn.Module,
         tokenizer: PreTrainedTokenizerBase,
+        config: PretrainedConfig,
         num_layers: int,
         num_kv_heads: int,
         head_size: int,
@@ -699,6 +700,7 @@ class FlashCausalLM(Model):
         super(FlashCausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
+            config=config,
             requires_padding=False,
             dtype=dtype,
             device=device,
@@ -68,6 +68,7 @@ class FlashLlama(FlashCausalLM):
         super(FlashLlama, self).__init__(
             model=model,
             tokenizer=tokenizer,
+            config=config,
             num_layers=len(model.model.layers),
             num_kv_heads=model.model.num_heads,
             head_size=model.model.head_size,
@@ -59,6 +59,7 @@ class FlashNeoXSharded(FlashCausalLM):
         super(FlashNeoXSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
+            config=config,
             num_layers=len(model.gpt_neox.layers),
             num_kv_heads=model.gpt_neox.num_heads,
             head_size=model.gpt_neox.head_size,
@@ -65,6 +65,7 @@ class FlashRWSharded(FlashCausalLM):
         super(FlashRWSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
+            config=config,
             num_layers=len(model.transformer.h),
             num_kv_heads=model.transformer.cache_size,
             head_size=model.transformer.head_size,
@@ -66,6 +66,7 @@ class FlashSantacoderSharded(FlashCausalLM):
         super(FlashSantacoderSharded, self).__init__(
             model=model.to(device),
             tokenizer=tokenizer,
+            config=config,
             num_layers=len(model.transformer.h),
             num_kv_heads=1,
             head_size=model.transformer.head_size,
@@ -198,6 +198,7 @@ class GalacticaSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
+            config=config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -63,6 +63,7 @@ class GPTNeoxSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
+            config=config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -3,7 +3,7 @@ import torch
 
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, PretrainedConfig
 
 from text_generation_server.models.types import Batch, GeneratedText
 from text_generation_server.pb.generate_pb2 import InfoResponse
@@ -23,6 +23,7 @@ class Model(ABC):
         self,
         model: torch.nn.Module,
         tokenizer: PreTrainedTokenizerBase,
+        config: PretrainedConfig,
         requires_padding: bool,
         dtype: torch.dtype,
         device: torch.device,
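Together with the hunk below, the base class now owns the config. A condensed sketch of the resulting constructor contract (signatures abridged from this diff; nothing here is new behavior):

import torch
from transformers import PretrainedConfig, PreTrainedTokenizerBase

class ModelSketch:
    # Abridged view of the new Model.__init__ contract: config is passed in
    # explicitly and stored once on the base class, so submodules such as
    # FlashLlamaModel no longer keep their own self.config copy.
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: PreTrainedTokenizerBase,
        config: PretrainedConfig,
        requires_padding: bool,
        dtype: torch.dtype,
        device: torch.device,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.requires_padding = requires_padding
        self.dtype = dtype
        self.device = device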
@@ -45,24 +46,41 @@ class Model(ABC):
             inspect.signature(model.forward).parameters.get("position_ids", None)
             is not None
         )
+        self.config = config
 
-        if model.config.quantize == "gptq":
+        if config.quantize == "gptq":
             # Buffers need to be persistent to avoid any bug.
             self.buffers = {}
-            max_dq_buffer_size = 0
-            for name, submodule in self.model.named_modules():
+            use_exllama_act_order = False
+            max_dq_buffer_size = 1
+            max_inner_outer_dim = 1
+            for name, submodule in model.named_modules():
                 if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
-                    max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
-
-            intermediate_size = model.config.n_inner
-            max_seq_len = 2048 # TODO: we should be able to set it
-
-            self.buffers["temp_state"] = torch.zeros((max_seq_len, intermediate_size), dtype=torch.float16, device=device)
-            self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
+                    max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
+
+                    if submodule.linear.act_order:
+                        max_inner_outer_dim = max(max_inner_outer_dim, submodule.linear.height, submodule.linear.width)
+
+                        use_exllama_act_order = True
+
+            if use_exllama_act_order:
+                # TODO: this should be set to rust side `max_total_tokens`, but TGI
+                # does not offer an API to expose this variable to python, as this variable
+                # is handled by the client but it appears the model is initialized by the server.
+                # An alternative could be to initialize the buffers during warmup.
+                max_total_tokens = 2048
+            else:
+                max_total_tokens = 1
+
+            # This temp_state buffer is required to reorder X in the act-order case.
+            self.buffers["temp_state"] = torch.zeros((max_total_tokens, max_inner_outer_dim), dtype=torch.float16, device=device)
+
+            # This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
+            self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
 
             prepare_buffers(device, self.buffers["temp_state"], self.buffers["temp_dq"])
 
             # TODO: ability to set them
             matmul_recons_thd = 8
             matmul_fused_remap = False
             matmul_no_half2 = False
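For a sense of scale of the buffers allocated above, here is back-of-the-envelope arithmetic with illustrative 7B-class Llama dimensions (4096 hidden, 11008 intermediate; these numbers are assumptions, not values from the diff):

# Illustrative sizing only; dims are assumed, not taken from this commit.
hidden, inter = 4096, 11008
max_inner_outer_dim = max(hidden, inter)

# qweight packs eight 4-bit weights per int32, so for one (in, out) linear
# numel = in // 8 * out; the code above reserves numel * 8 fp16 elements.
max_dq_buffer_size = (hidden // 8 * inter) * 8

max_total_tokens = 2048  # the act-order fallback hardcoded above
temp_state_bytes = max_total_tokens * max_inner_outer_dim * 2  # fp16
temp_dq_bytes = max_dq_buffer_size * 2                         # fp16
print(f"temp_state ~{temp_state_bytes / 2**20:.0f} MiB, temp_dq ~{temp_dq_bytes / 2**20:.0f} MiB")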
@@ -86,6 +86,7 @@ class MPTSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
+            config=config,
             requires_padding=False,
             dtype=dtype,
             device=device,
@@ -61,6 +61,7 @@ class OPTSharded(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
+            config=config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -58,6 +58,7 @@ class RW(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
+            config=model.config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -63,6 +63,7 @@ class SantaCoder(CausalLM):
         super(CausalLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
+            config=model.config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -542,6 +542,7 @@ class Seq2SeqLM(Model):
         super(Seq2SeqLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
+            config=model.config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -73,6 +73,7 @@ class T5Sharded(Seq2SeqLM):
         super(Seq2SeqLM, self).__init__(
             model=model,
             tokenizer=tokenizer,
+            config=config,
             requires_padding=True,
             dtype=dtype,
             device=device,
@@ -1,9 +1,8 @@
 from pathlib import Path
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, Tuple
 from safetensors import safe_open, SafetensorError
 import torch
 
-
 class Weights:
     def __init__(
         self,
@@ -127,17 +126,7 @@ class Weights:
             torch.testing.assert_close(w2, w[0])
             g_idx = w[0]
 
-            try:
-                bits = self.get_tensor("gptq_bits").item()
-                groupsize = self.get_tensor("gptq_groupsize").item()
-            except SafetensorError as e:
-                try:
-                    import os
-
-                    bits = int(os.getenv("GPTQ_BITS"))
-                    groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
-                except Exception:
-                    raise e
+            bits, groupsize = self.get_gptq_qparams()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@@ -149,7 +138,7 @@ class Weights:
             use_triton_kernel = False
             if self.process_group.size() > 1:
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
-                groupsize = self.get_tensor("gptq_groupsize").item()
+                _, groupsize = self.get_gptq_qparams()
 
                 if g_idx is not None:
                     if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
@@ -180,19 +169,24 @@ class Weights:
             else:
                 g_idx = None
 
-            try:
-                bits = self.get_tensor("gptq_bits").item()
-                groupsize = self.get_tensor("gptq_groupsize").item()
-            except SafetensorError as e:
-                try:
-                    import os
-
-                    bits = int(os.getenv("GPTQ_BITS"))
-                    groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
-                except Exception:
-                    raise e
+            bits, groupsize = self.get_gptq_qparams()
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)
         return weight
 
+    def get_gptq_qparams(self) -> Tuple[int, int]:
+        try:
+            bits = self.get_tensor("gptq_bits").item()
+            groupsize = self.get_tensor("gptq_groupsize").item()
+        except (SafetensorError, RuntimeError) as e:
+            try:
+                import os
+
+                bits = int(os.getenv("GPTQ_BITS"))
+                groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
+            except Exception:
+                raise e
+
+        return bits, groupsize
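get_gptq_qparams first looks for gptq_bits / gptq_groupsize tensors in the checkpoint and only then falls back to the environment variables the test launcher sets. A checkpoint can make the fallback unnecessary by shipping those scalars itself; a minimal sketch, assuming safetensors as used in this file:

import torch
from safetensors.torch import save_file

# Sketch: embed the GPTQ qparams under the tensor names read by
# get_gptq_qparams, so GPTQ_BITS/GPTQ_GROUPSIZE need not be set.
save_file(
    {
        "gptq_bits": torch.tensor(4),        # get_tensor("gptq_bits").item() -> 4
        "gptq_groupsize": torch.tensor(128), # get_tensor("gptq_groupsize").item() -> 128
    },
    "gptq_qparams.safetensors",
)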