fix(server): llama v2 GPTQ (#648)
As per title & reported https://github.com/huggingface/text-generation-inference/issues/601#issuecomment-1641435956 https://huggingface.co/TheBloke/Llama-2-70B-chat-GPTQ/discussions/5 Test it: ``` GPTQ_BITS=4 GPTQ_GROUPSIZE=1 text-generation-launcher --model-id TheBloke/Llama-2-70B-chat-GPTQ --port 8080 --num-shard 4 --quantize gptq ``` & ``` curl 127.0.0.1:8080/generate \ -X POST \ -d '{"inputs":"hey llama","parameters":{"max_new_tokens":256}}' \ -H 'Content-Type: application/json' ```
This commit is contained in:
parent
214c06f510
commit
362883f259
|
@ -148,24 +148,27 @@ class LlamaRMSNorm(nn.Module):
|
|||
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
w = [
|
||||
weights.get_sharded(f"{prefix}.q_proj.weight", dim=0),
|
||||
weights.get_sharded(f"{prefix}.k_proj.weight", dim=0),
|
||||
weights.get_sharded(f"{prefix}.v_proj.weight", dim=0),
|
||||
]
|
||||
weight = torch.cat(w, dim=0)
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
bias = None
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||
assert list(weight.shape) == [
|
||||
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0
|
||||
)
|
||||
|
||||
if config.quantize != "gptq":
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||
assert list(weight.shape) == [
|
||||
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
|
||||
|
||||
|
||||
class FlashLlamaAttention(torch.nn.Module):
|
||||
|
|
Loading…
Reference in New Issue