feat(server): add support for llamav2 (#633)

Nicolas Patry 2023-07-18 18:09:53 +02:00 committed by GitHub
parent 3b71c38558
commit 211b211ec0
2 changed files with 58 additions and 20 deletions


@@ -39,6 +39,7 @@ from text_generation_server.utils.layers import (
TensorParallelEmbedding,
PositionRotaryEmbedding,
TensorParallelHead,
+ get_linear,
)
@@ -59,7 +60,8 @@ class LlamaRMSNorm(nn.Module):
hidden_states += residual
residual = hidden_states
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
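The RMSNorm change above casts to float32 before taking the mean of squares, so the variance is accumulated in full precision. A minimal standalone sketch of the same math (hypothetical helper, not part of this commit):

```python
import torch

def rms_norm_fp32(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Cast first so the mean of squares is computed in float32, then scale and cast back.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)
```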
@@ -94,6 +96,27 @@ class LlamaRMSNorm(nn.Module):
return normed_hidden_states, res
+ def _load_gqa(config, prefix: str, weights):
+ w = [
+ weights.get_sharded(f"{prefix}.q_proj.weight", dim=0),
+ weights.get_sharded(f"{prefix}.k_proj.weight", dim=0),
+ weights.get_sharded(f"{prefix}.v_proj.weight", dim=0),
+ ]
+ weight = torch.cat(w, dim=0)
+ weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+ bias = None
+ assert config.hidden_size % config.num_attention_heads == 0
+ head_size = config.hidden_size // config.num_attention_heads
+ assert config.num_attention_heads % weights.process_group.size() == 0
+ num_heads = config.num_attention_heads // weights.process_group.size()
+ num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+ assert list(weight.shape) == [
+ (num_heads + 2 * num_key_value_heads) * head_size,
+ config.hidden_size,
+ ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+ return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
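`_load_gqa` concatenates the sharded q/k/v projections into one matrix whose row count per shard is `(num_heads + 2 * num_key_value_heads) * head_size`. A quick numeric check with hypothetical Llama-2-70B-like values (8-way tensor parallelism assumed):

```python
# Hypothetical config values, mirroring the shape assertion in _load_gqa.
hidden_size = 8192
num_attention_heads = 64
num_key_value_heads = 8      # grouped-query attention: 8 KV heads shared by 64 query heads
world_size = 8               # assumed tensor-parallel degree

head_size = hidden_size // num_attention_heads             # 128
num_heads = num_attention_heads // world_size              # 8 query heads per shard
num_kv_heads = num_key_value_heads // world_size           # 1 KV head per shard

fused_rows = (num_heads + 2 * num_kv_heads) * head_size    # (8 + 2) * 128 = 1280
print([fused_rows, hidden_size])                           # expected fused weight shape per shard
```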
class FlashLlamaAttention(torch.nn.Module):
def __init__(
self,
@@ -118,6 +141,12 @@ class FlashLlamaAttention(torch.nn.Module):
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+ if config.num_attention_heads != config.num_key_value_heads:
+ self.query_key_value = _load_gqa(config, prefix, weights)
+ else:
self.query_key_value = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
@@ -131,9 +160,10 @@ class FlashLlamaAttention(torch.nn.Module):
weights=weights,
bias=False,
)
+ self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
- 0, self.num_heads, dtype=torch.int32, device=weights.device
- )
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
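`kv_head_mapping` now maps each query head to the KV head its group shares, instead of assuming one KV head per query head. A small illustration with hypothetical per-shard head counts:

```python
import torch

num_heads = 8                # query heads on this shard (hypothetical)
num_key_value_heads = 2      # KV heads on this shard (hypothetical)
num_groups = num_heads // num_key_value_heads   # 4 query heads share each KV head

kv_head_mapping = torch.arange(0, num_key_value_heads, dtype=torch.int32).repeat_interleave(num_groups)
print(kv_head_mapping)       # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
```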
def forward(
self,
@@ -148,26 +178,33 @@ class FlashLlamaAttention(torch.nn.Module):
max_s,
):
qkv = self.query_key_value(hidden_states)
- qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
# Inplace rotary
- self.rotary_emb(qkv[:, 0], cos, sin)
- self.rotary_emb(qkv[:, 1], cos, sin)
+ self.rotary_emb(query, cos, sin)
+ self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
vllm_cache_ops.reshape_and_cache(
- qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
+ kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
)
# output tensor
- attn_output = torch.empty_like(qkv[:, 0])
+ attn_output = torch.empty_like(query)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attention(
- qkv[:, 0],
- qkv[:, 1],
- qkv[:, 2],
+ query,
+ torch.select(kv, dim=1, index=0),
+ torch.select(kv, dim=1, index=1),
attn_output,
cu_seqlen_prefill,
max_s,
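With GQA the fused projection output is no longer three equal slices, so the forward pass splits it by width into a query part and a joint key/value part. A shape sketch under hypothetical per-shard dimensions:

```python
import torch

num_heads = 8                # hypothetical per-shard query heads
num_key_value_heads = 2      # hypothetical per-shard KV heads
head_size = 64
num_tokens = 5

# Fused q/k/v projection output: query columns first, then key and value columns.
qkv = torch.randn(num_tokens, (num_heads + 2 * num_key_value_heads) * head_size)

query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_key_value_heads],
    dim=1,
)
query = query.view(-1, num_heads, head_size)          # (5, 8, 64)
kv = kv.view(-1, 2, num_key_value_heads, head_size)   # (5, 2, 2, 64)
print(query.shape, kv.shape)
```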
@@ -179,7 +216,7 @@ class FlashLlamaAttention(torch.nn.Module):
block_size = kv_cache[1].shape[3]
vllm_attention_ops.single_query_cached_kv_attention(
attn_output,
- qkv[:, 0],
+ query,
kv_cache[0],
kv_cache[1],
self.kv_head_mapping,
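During decode, each query head attends over the cache entry of its mapped KV head; the vLLM kernel handles the paged layout. A conceptual pure-PyTorch sketch of the gather that `kv_head_mapping` implies (hypothetical shapes, not the actual kernel):

```python
import torch

num_heads, num_kv_heads, head_size, seq_len = 8, 2, 64, 16
kv_head_mapping = torch.arange(num_kv_heads).repeat_interleave(num_heads // num_kv_heads)

query = torch.randn(num_heads, head_size)
keys = torch.randn(num_kv_heads, seq_len, head_size)
values = torch.randn(num_kv_heads, seq_len, head_size)

# Broadcast each shared KV head to the query heads in its group.
k = keys[kv_head_mapping]    # (num_heads, seq_len, head_size)
v = values[kv_head_mapping]

scores = torch.einsum("hd,hsd->hs", query, k) / head_size**0.5
attn = torch.softmax(scores, dim=-1)
out = torch.einsum("hs,hsd->hd", attn, v)
print(out.shape)             # torch.Size([8, 64])
```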
@@ -316,6 +353,7 @@ class FlashLlamaModel(torch.nn.Module):
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,


@@ -69,7 +69,7 @@ class FlashLlama(FlashCausalLM):
model=model,
tokenizer=tokenizer,
num_layers=len(model.model.layers),
- num_kv_heads=model.model.num_heads,
+ num_kv_heads=model.model.num_key_value_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
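Passing `num_key_value_heads` here sizes the paged KV cache by KV heads rather than query heads, which is a large part of the memory saving GQA brings at inference time. A rough sizing sketch with hypothetical 70B-like numbers:

```python
# Hypothetical model dimensions; only the ratio of query heads to KV heads matters here.
num_heads = 64
num_key_value_heads = 8
head_size = 128
num_layers = 80
dtype_bytes = 2              # float16

kv_bytes_per_token = 2 * num_key_value_heads * head_size * dtype_bytes * num_layers   # key + value
mha_bytes_per_token = 2 * num_heads * head_size * dtype_bytes * num_layers            # if sized by query heads

print(kv_bytes_per_token // 1024, "KiB vs", mha_bytes_per_token // 1024, "KiB per token")  # 320 vs 2560
```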