fix(server): llama v2 GPTQ (#648)

As per the title: fixes loading of GPTQ-quantized Llama v2 models, as reported in
https://github.com/huggingface/text-generation-inference/issues/601#issuecomment-1641435956
https://huggingface.co/TheBloke/Llama-2-70B-chat-GPTQ/discussions/5

Test it:

```
GPTQ_BITS=4 GPTQ_GROUPSIZE=1 text-generation-launcher --model-id TheBloke/Llama-2-70B-chat-GPTQ --port 8080 --num-shard 4 --quantize gptq
```
&
```
curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"hey llama","parameters":{"max_new_tokens":256}}' \
    -H 'Content-Type: application/json'
```
fxmarty authored on 2023-07-20 15:02:54 +02:00, committed via GitHub
parent 214c06f510
commit 362883f259
1 changed file with 19 additions and 16 deletions

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

```
@@ -148,24 +148,27 @@ class LlamaRMSNorm(nn.Module):
 def _load_gqa(config, prefix: str, weights):
-    w = [
-        weights.get_sharded(f"{prefix}.q_proj.weight", dim=0),
-        weights.get_sharded(f"{prefix}.k_proj.weight", dim=0),
-        weights.get_sharded(f"{prefix}.v_proj.weight", dim=0),
-    ]
-    weight = torch.cat(w, dim=0)
-    weight = weight.to(dtype=weights.dtype).to(device=weights.device)
-    bias = None
     assert config.hidden_size % config.num_attention_heads == 0
-    head_size = config.hidden_size // config.num_attention_heads
     assert config.num_attention_heads % weights.process_group.size() == 0
-    num_heads = config.num_attention_heads // weights.process_group.size()
-    num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
-        (num_heads + 2 * num_key_value_heads) * head_size,
-        config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
-    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+
+    weight = weights.get_multi_weights_col(
+        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+        quantize=config.quantize,
+        dim=0
+    )
+
+    if config.quantize != "gptq":
+        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+
+        head_size = config.hidden_size // config.num_attention_heads
+        num_heads = config.num_attention_heads // weights.process_group.size()
+        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+        assert list(weight.shape) == [
+            (num_heads + 2 * num_key_value_heads) * head_size,
+            config.hidden_size,
+        ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+
+    return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))


 class FlashLlamaAttention(torch.nn.Module):
```
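
For the non-GPTQ path, the shape assert now lives inside the `if` and still checks the fused q/k/v projection per shard. A worked example of that arithmetic, using the public Llama-2-70B config (hidden_size 8192, 64 attention heads, 8 KV heads under GQA) and the `--num-shard 4` from the test command above:

```
# Expected fused q/k/v output features per tensor-parallel shard.
hidden_size = 8192
num_attention_heads = 64      # query heads
num_key_value_heads = 8       # GQA: 8 KV heads shared by 64 query heads
shards = 4                    # --num-shard 4

head_size = hidden_size // num_attention_heads   # 128
num_heads = num_attention_heads // shards        # 16 query heads per shard
num_kv_heads = num_key_value_heads // shards     # 2 KV heads per shard

# q contributes num_heads head-slices; k and v each contribute one
# slice per KV head, hence the (num_heads + 2 * num_kv_heads) term.
out_features = (num_heads + 2 * num_kv_heads) * head_size
print([out_features, hidden_size])  # [2560, 8192], the shape the assert expects
```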