support all, test llama

Felix Marty 2023-07-13 15:41:57 +00:00
parent 2ae65b45a8
commit 0036084294
23 changed files with 740 additions and 53 deletions

View File

@ -230,15 +230,19 @@ def launcher(event_loop):
shard_uds_path,
]
env = os.environ
if num_shard is not None:
args.extend(["--num-shard", str(num_shard)])
if quantize is not None:
args.append("--quantize")
args.append(quantize)
if quantize == "gptq":
env["GPTQ_GROUPSIZE"] = "128"
env["GPTQ_BITS"] = "4"
if trust_remote_code:
args.append("--trust-remote-code")
env = os.environ
env["LOG_LEVEL"] = "info,text_generation_router=debug"
if not use_flash_attention:
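
The fixture above exports GPTQ_GROUPSIZE and GPTQ_BITS whenever quantize == "gptq", so the launched server can recover the quantization parameters even when the checkpoint itself carries no GPTQ metadata (see the get_gptq_qparams helper added at the end of this commit). A minimal sketch of that environment fallback, assuming the variables are set exactly as in the fixture:

import os


def gptq_params_from_env() -> tuple:
    # Hypothetical helper mirroring the fallback path later in this diff: read the
    # quantization parameters that the integration-test fixture exports.
    bits = int(os.environ["GPTQ_BITS"])            # "4" in the fixture above
    groupsize = int(os.environ["GPTQ_GROUPSIZE"])  # "128" in the fixture above
    return bits, groupsize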

View File

@ -0,0 +1,102 @@
{
"generated_text": ", and I am going to visit the Louvre",
"details": {
"finish_reason": "length",
"generated_tokens": 10,
"seed": null,
"prefill": [
{
"id": 2,
"text": "</s>",
"logprob": null
},
{
"id": 20628,
"text": "Today",
"logprob": -11.2265625
},
{
"id": 306,
"text": "I",
"logprob": -4.1757812
},
{
"id": 626,
"text": "am",
"logprob": -1.9746094
},
{
"id": 297,
"text": "in",
"logprob": -5.4648438
},
{
"id": 3444,
"text": "France",
"logprob": -9.03125
}
],
"tokens": [
{
"id": 29892,
"text": ",",
"logprob": -0.31298828,
"special": false
},
{
"id": 322,
"text": " and",
"logprob": -1.4345703,
"special": false
},
{
"id": 306,
"text": " I",
"logprob": -0.32080078,
"special": false
},
{
"id": 626,
"text": " am",
"logprob": -1.3798828,
"special": false
},
{
"id": 2675,
"text": " going",
"logprob": -1.2304688,
"special": false
},
{
"id": 304,
"text": " to",
"logprob": -0.0014791489,
"special": false
},
{
"id": 6493,
"text": " visit",
"logprob": -1.1503906,
"special": false
},
{
"id": 278,
"text": " the",
"logprob": -0.41259766,
"special": false
},
{
"id": 4562,
"text": " Lou",
"logprob": -1.8134766,
"special": false
},
{
"id": 12675,
"text": "vre",
"logprob": -0.000767231,
"special": false
}
]
}
}

View File

@ -0,0 +1,97 @@
{
"generated_text": "The capital city of France isParis.\n The Best Way to Visit",
"details": {
"finish_reason": "length",
"generated_tokens": 10,
"seed": 0,
"prefill": [
{
"id": 2,
"text": "</s>",
"logprob": null
},
{
"id": 4272,
"text": "city",
"logprob": -12.4453125
},
{
"id": 310,
"text": "of",
"logprob": -2.4023438
},
{
"id": 3444,
"text": "France",
"logprob": -12.515625
},
{
"id": 338,
"text": "is",
"logprob": -5.1914062
}
],
"tokens": [
{
"id": 3681,
"text": " Paris",
"logprob": -0.22546387,
"special": false
},
{
"id": 29889,
"text": ".",
"logprob": 0,
"special": false
},
{
"id": 13,
"text": "\n",
"logprob": 0,
"special": false
},
{
"id": 1,
"text": "",
"logprob": 0,
"special": false
},
{
"id": 450,
"text": " The",
"logprob": 0,
"special": false
},
{
"id": 6407,
"text": " Best",
"logprob": -0.5522461,
"special": false
},
{
"id": 5307,
"text": " Way",
"logprob": 0,
"special": false
},
{
"id": 304,
"text": " to",
"logprob": 0,
"special": false
},
{
"id": 5741,
"text": " Vis",
"logprob": -2.3496094,
"special": false
},
{
"id": 277,
"text": "it",
"logprob": 0,
"special": false
}
]
}
}

View File

@ -0,0 +1,410 @@
[
{
"generated_text": ", and I am going to visit the Louvre",
"details": {
"finish_reason": "length",
"generated_tokens": 10,
"seed": null,
"prefill": [
{
"id": 2,
"text": "</s>",
"logprob": null
},
{
"id": 20628,
"text": "Today",
"logprob": -10.734375
},
{
"id": 306,
"text": "I",
"logprob": -4.2265625
},
{
"id": 626,
"text": "am",
"logprob": -1.9794922
},
{
"id": 297,
"text": "in",
"logprob": -5.4960938
},
{
"id": 3444,
"text": "France",
"logprob": -9.1171875
}
],
"tokens": [
{
"id": 29892,
"text": ",",
"logprob": -0.30737305,
"special": false
},
{
"id": 322,
"text": " and",
"logprob": -1.3701172,
"special": false
},
{
"id": 306,
"text": " I",
"logprob": -0.31567383,
"special": false
},
{
"id": 626,
"text": " am",
"logprob": -1.3886719,
"special": false
},
{
"id": 2675,
"text": " going",
"logprob": -1.2070312,
"special": false
},
{
"id": 304,
"text": " to",
"logprob": -0.0014028549,
"special": false
},
{
"id": 6493,
"text": " visit",
"logprob": -1.1181641,
"special": false
},
{
"id": 278,
"text": " the",
"logprob": -0.3942871,
"special": false
},
{
"id": 4562,
"text": " Lou",
"logprob": -1.8789062,
"special": false
},
{
"id": 12675,
"text": "vre",
"logprob": -0.00082969666,
"special": false
}
]
}
},
{
"generated_text": ", and I am going to visit the Louvre",
"details": {
"finish_reason": "length",
"generated_tokens": 10,
"seed": null,
"prefill": [
{
"id": 2,
"text": "</s>",
"logprob": null
},
{
"id": 20628,
"text": "Today",
"logprob": -10.734375
},
{
"id": 306,
"text": "I",
"logprob": -4.2265625
},
{
"id": 626,
"text": "am",
"logprob": -1.9794922
},
{
"id": 297,
"text": "in",
"logprob": -5.4960938
},
{
"id": 3444,
"text": "France",
"logprob": -9.1171875
}
],
"tokens": [
{
"id": 29892,
"text": ",",
"logprob": -0.30737305,
"special": false
},
{
"id": 322,
"text": " and",
"logprob": -1.3720703,
"special": false
},
{
"id": 306,
"text": " I",
"logprob": -0.31469727,
"special": false
},
{
"id": 626,
"text": " am",
"logprob": -1.3916016,
"special": false
},
{
"id": 2675,
"text": " going",
"logprob": -1.2050781,
"special": false
},
{
"id": 304,
"text": " to",
"logprob": -0.0014019012,
"special": false
},
{
"id": 6493,
"text": " visit",
"logprob": -1.1162109,
"special": false
},
{
"id": 278,
"text": " the",
"logprob": -0.3959961,
"special": false
},
{
"id": 4562,
"text": " Lou",
"logprob": -1.8847656,
"special": false
},
{
"id": 12675,
"text": "vre",
"logprob": -0.0008392334,
"special": false
}
]
}
},
{
"generated_text": ", and I am going to visit the Louvre",
"details": {
"finish_reason": "length",
"generated_tokens": 10,
"seed": null,
"prefill": [
{
"id": 2,
"text": "</s>",
"logprob": null
},
{
"id": 20628,
"text": "Today",
"logprob": -10.734375
},
{
"id": 306,
"text": "I",
"logprob": -4.2265625
},
{
"id": 626,
"text": "am",
"logprob": -1.9794922
},
{
"id": 297,
"text": "in",
"logprob": -5.4960938
},
{
"id": 3444,
"text": "France",
"logprob": -9.1171875
}
],
"tokens": [
{
"id": 29892,
"text": ",",
"logprob": -0.30737305,
"special": false
},
{
"id": 322,
"text": " and",
"logprob": -1.3710938,
"special": false
},
{
"id": 306,
"text": " I",
"logprob": -0.31225586,
"special": false
},
{
"id": 626,
"text": " am",
"logprob": -1.3994141,
"special": false
},
{
"id": 2675,
"text": " going",
"logprob": -1.2060547,
"special": false
},
{
"id": 304,
"text": " to",
"logprob": -0.0013828278,
"special": false
},
{
"id": 6493,
"text": " visit",
"logprob": -1.1181641,
"special": false
},
{
"id": 278,
"text": " the",
"logprob": -0.39135742,
"special": false
},
{
"id": 4562,
"text": " Lou",
"logprob": -1.8808594,
"special": false
},
{
"id": 12675,
"text": "vre",
"logprob": -0.00084352493,
"special": false
}
]
}
},
{
"generated_text": ", and I am going to visit the Louvre",
"details": {
"finish_reason": "length",
"generated_tokens": 10,
"seed": null,
"prefill": [
{
"id": 2,
"text": "</s>",
"logprob": null
},
{
"id": 20628,
"text": "Today",
"logprob": -11.203125
},
{
"id": 306,
"text": "I",
"logprob": -4.1757812
},
{
"id": 626,
"text": "am",
"logprob": -1.9697266
},
{
"id": 297,
"text": "in",
"logprob": -5.4609375
},
{
"id": 3444,
"text": "France",
"logprob": -9.046875
}
],
"tokens": [
{
"id": 29892,
"text": ",",
"logprob": -0.3083496,
"special": false
},
{
"id": 322,
"text": " and",
"logprob": -1.4228516,
"special": false
},
{
"id": 306,
"text": " I",
"logprob": -0.32055664,
"special": false
},
{
"id": 626,
"text": " am",
"logprob": -1.3847656,
"special": false
},
{
"id": 2675,
"text": " going",
"logprob": -1.21875,
"special": false
},
{
"id": 304,
"text": " to",
"logprob": -0.0014572144,
"special": false
},
{
"id": 6493,
"text": " visit",
"logprob": -1.1542969,
"special": false
},
{
"id": 278,
"text": " the",
"logprob": -0.41455078,
"special": false
},
{
"id": 4562,
"text": " Lou",
"logprob": -1.8193359,
"special": false
},
{
"id": 12675,
"text": "vre",
"logprob": -0.0007710457,
"special": false
}
]
}
}
]

View File

@ -0,0 +1,58 @@
import pytest


@pytest.fixture(scope="module")
def flash_llama_gptq_handle(launcher):
    with launcher("TheBloke/WizardLM-7B-uncensored-GPTQ", num_shard=2, quantize="gptq") as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_gptq(flash_llama_gptq_handle):
    await flash_llama_gptq_handle.health(300)
    return flash_llama_gptq_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
    response = await flash_llama_gptq.generate(
        "Today I am in France", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
    response = await flash_llama_gptq.generate(
        "The capital city of France is",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_load(flash_llama_gptq, generate_load, response_snapshot):
    responses = await generate_load(flash_llama_gptq, "Today I am in France", max_new_tokens=10, n=4)

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
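
These tests drive the server through the launcher fixture and compare the full responses against the snapshots above. The same request can be reproduced by hand against a running text-generation-inference instance; a sketch using the plain HTTP route (the host and port are assumptions, the parameter names match the test calls):

import requests

# Assumes a server launched for TheBloke/WizardLM-7B-uncensored-GPTQ is listening locally.
response = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "Today I am in France",
        "parameters": {"max_new_tokens": 10, "decoder_input_details": True},
    },
    timeout=60,
)
print(response.json()["generated_text"])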

View File

@ -500,6 +500,7 @@ class CausalLM(Model):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=model.config,
requires_padding=True,
dtype=dtype,
device=device,
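
A large share of this commit is mechanical: each model class now forwards its Hugging Face config to the shared Model base class, as in the hunk above, so that Model.__init__ can branch on config.quantize when preparing the GPTQ/exllama buffers. A hypothetical subclass showing the pattern (the class name and constructor arguments are illustrative only):

from text_generation_server.models import CausalLM


class MyCausalModel(CausalLM):
    # Illustrative sketch: the relevant change is that config=model.config is now
    # passed up to the base class next to the tokenizer.
    def __init__(self, model, tokenizer, dtype, device):
        super(CausalLM, self).__init__(
            model=model,
            tokenizer=tokenizer,
            config=model.config,
            requires_padding=True,
            dtype=dtype,
            device=device,
        )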

View File

@ -298,7 +298,6 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.config = config
process_group = weights.process_group
self.tp_rank = process_group.rank()
@ -368,6 +367,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.config = config
self.model = FlashLlamaModel(config, weights)
self.lm_head = TensorParallelHead.load(
config,

View File

@ -73,17 +73,7 @@ def _load_multi_mqa_gptq(
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
try:
bits = weights.get_tensor("gptq_bits").item()
groupsize = weights.get_tensor("gptq_groupsize").item()
except SafetensorError as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
bits, groupsize = weights.get_gptq_qparams()
qweight = qweight.to(weights.device)
qzeros = qzeros.to(weights.device)
@ -471,7 +461,6 @@ class FlashSantacoderForCausalLM(nn.Module):
self.lm_head = TensorParallelHead.load(
config, prefix="transformer.wte", weights=weights
)
self.config = config
def forward(
self,

View File

@ -6,9 +6,8 @@ import torch.distributed
import numpy as np
from dataclasses import dataclass
from loguru import logger
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from transformers import PreTrainedTokenizerBase, PretrainedConfig
from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model
@ -21,6 +20,7 @@ from text_generation_server.models.types import (
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
tracer = trace.get_tracer(__name__)
BLOCK_SIZE = 16
@ -684,6 +684,7 @@ class FlashCausalLM(Model):
self,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
config: PretrainedConfig,
num_layers: int,
num_kv_heads: int,
head_size: int,
@ -699,6 +700,7 @@ class FlashCausalLM(Model):
super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=False,
dtype=dtype,
device=device,

View File

@ -68,6 +68,7 @@ class FlashLlama(FlashCausalLM):
super(FlashLlama, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
num_layers=len(model.model.layers),
num_kv_heads=model.model.num_heads,
head_size=model.model.head_size,

View File

@ -59,6 +59,7 @@ class FlashNeoXSharded(FlashCausalLM):
super(FlashNeoXSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
config=config,
num_layers=len(model.gpt_neox.layers),
num_kv_heads=model.gpt_neox.num_heads,
head_size=model.gpt_neox.head_size,

View File

@ -65,6 +65,7 @@ class FlashRWSharded(FlashCausalLM):
super(FlashRWSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
config=config,
num_layers=len(model.transformer.h),
num_kv_heads=model.transformer.cache_size,
head_size=model.transformer.head_size,

View File

@ -66,6 +66,7 @@ class FlashSantacoderSharded(FlashCausalLM):
super(FlashSantacoderSharded, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
config=config,
num_layers=len(model.transformer.h),
num_kv_heads=1,
head_size=model.transformer.head_size,

View File

@ -198,6 +198,7 @@ class GalacticaSharded(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@ -63,6 +63,7 @@ class GPTNeoxSharded(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@ -3,7 +3,7 @@ import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase
from transformers import PreTrainedTokenizerBase, PretrainedConfig
from text_generation_server.models.types import Batch, GeneratedText
from text_generation_server.pb.generate_pb2 import InfoResponse
@ -23,6 +23,7 @@ class Model(ABC):
self,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
config: PretrainedConfig,
requires_padding: bool,
dtype: torch.dtype,
device: torch.device,
@ -45,24 +46,41 @@ class Model(ABC):
inspect.signature(model.forward).parameters.get("position_ids", None)
is not None
)
self.config = config
if model.config.quantize == "gptq":
if config.quantize == "gptq":
# Buffers need to be persistent to avoid any bug.
self.buffers = {}
max_dq_buffer_size = 0
for name, submodule in self.model.named_modules():
use_exllama_act_order = False
max_dq_buffer_size = 1
max_inner_outer_dim = 1
for name, submodule in model.named_modules():
if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
intermediate_size = model.config.n_inner
max_seq_len = 2048 # TODO: we should be able to set it
self.buffers["temp_state"] = torch.zeros((max_seq_len, intermediate_size), dtype=torch.float16, device=device)
self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
if submodule.linear.act_order:
max_inner_outer_dim = max(max_inner_outer_dim, submodule.linear.height, submodule.linear.width)
use_exllama_act_order = True
if use_exllama_act_order:
# TODO: this should be set to rust side `max_total_tokens`, but TGI
# does not offer an API to expose this variable to python, as this variable
# is handled by the client but it appears the model is initialized by the server.
# An alternative could be to initialize the buffers during warmup.
max_total_tokens = 2048
else:
max_total_tokens = 1
# This temp_state buffer is required to reorder X in the act-order case.
self.buffers["temp_state"] = torch.zeros((max_total_tokens, max_inner_outer_dim), dtype=torch.float16, device=device)
# This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
prepare_buffers(device, self.buffers["temp_state"], self.buffers["temp_dq"])
# TODO: ability to set them
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
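
The buffer sizing above is the heart of the exllama integration: temp_dq must hold the largest dequantized weight (each packed int32 element of qweight expands to 8 fp16 values), while temp_state only needs room for a full sequence when an act-order checkpoint forces the input to be reordered. A standalone sketch of that sizing rule, assuming exllama-style 4-bit layers exposing qweight, act_order, height and width:

def exllama_buffer_shapes(quantized_linears, max_total_tokens=2048):
    # Sketch of the logic above; returns the shapes used for the persistent
    # temp_state and temp_dq buffers.
    use_act_order = False
    max_dq_buffer_size = 1
    max_inner_outer_dim = 1
    for linear in quantized_linears:
        # Each int32 in qweight packs eight 4-bit weights, hence the factor 8
        # when reserving space for the dequantized fp16 copy.
        max_dq_buffer_size = max(max_dq_buffer_size, linear.qweight.numel() * 8)
        if linear.act_order:
            max_inner_outer_dim = max(max_inner_outer_dim, linear.height, linear.width)
            use_act_order = True
    rows = max_total_tokens if use_act_order else 1
    return (rows, max_inner_outer_dim), (1, max_dq_buffer_size)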

View File

@ -86,6 +86,7 @@ class MPTSharded(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=False,
dtype=dtype,
device=device,

View File

@ -61,6 +61,7 @@ class OPTSharded(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@ -58,6 +58,7 @@ class RW(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=model.config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@ -63,6 +63,7 @@ class SantaCoder(CausalLM):
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=model.config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@ -542,6 +542,7 @@ class Seq2SeqLM(Model):
super(Seq2SeqLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=model.config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@ -73,6 +73,7 @@ class T5Sharded(Seq2SeqLM):
super(Seq2SeqLM, self).__init__(
model=model,
tokenizer=tokenizer,
config=config,
requires_padding=True,
dtype=dtype,
device=device,

View File

@ -1,9 +1,8 @@
from pathlib import Path
from typing import List, Dict, Optional
from typing import List, Dict, Optional, Tuple
from safetensors import safe_open, SafetensorError
import torch
class Weights:
def __init__(
self,
@ -127,17 +126,7 @@ class Weights:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
except SafetensorError as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
bits, groupsize = self.get_gptq_qparams()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
@ -149,7 +138,7 @@ class Weights:
use_triton_kernel = False
if self.process_group.size() > 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
groupsize = self.get_tensor("gptq_groupsize").item()
_, groupsize = self.get_gptq_qparams()
if g_idx is not None:
if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
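
The sharded branch above appears to fall back to the Triton kernel only when g_idx is "non-trivial", i.e. when the checkpoint was quantized with act-order and row i is not simply assigned to group i // groupsize. A hedged restatement of that check as a standalone helper:

import torch


def g_idx_is_trivial(g_idx: torch.Tensor, groupsize: int) -> bool:
    # Hypothetical helper restating the condition above: a trivial g_idx means no
    # act-order reordering was applied during quantization.
    expected = torch.tensor(
        [i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32
    )
    return torch.equal(g_idx.cpu(), expected) or bool((g_idx == 0).all())
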
@ -180,19 +169,24 @@ class Weights:
else:
g_idx = None
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
except SafetensorError as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
bits, groupsize = self.get_gptq_qparams()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_triton_kernel)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
def get_gptq_qparams(self) -> Tuple[int, int]:
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
except (SafetensorError, RuntimeError) as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
return bits, groupsize
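
get_gptq_qparams first looks for gptq_bits and gptq_groupsize tensors stored next to the weights and only then falls back to the GPTQ_BITS / GPTQ_GROUPSIZE environment variables that the integration-test fixture exports. A sketch of how a converted checkpoint could carry that metadata, assuming safetensors serialization (the file name and the per-layer tensors are placeholders):

import torch
from safetensors.torch import save_file

# Sketch: a GPTQ shard that embeds its own quantization parameters, so the server
# does not depend on GPTQ_BITS / GPTQ_GROUPSIZE being set in the environment.
tensors = {
    "gptq_bits": torch.tensor(4, dtype=torch.int32),
    "gptq_groupsize": torch.tensor(128, dtype=torch.int32),
    # ...plus the per-layer qweight / qzeros / scales / g_idx tensors...
}
save_file(tensors, "model-gptq.safetensors")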