feat(server): add support for llamav2 (#633)

Author: Nicolas Patry, 2023-07-18 18:09:53 +02:00 (committed by GitHub)
parent 3b71c38558
commit 211b211ec0
2 changed files with 58 additions and 20 deletions

Changed file 1 of 2:

@@ -39,6 +39,7 @@ from text_generation_server.utils.layers import (
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
     TensorParallelHead,
+    get_linear,
 )
@@ -59,7 +60,8 @@ class LlamaRMSNorm(nn.Module):
                 hidden_states += residual
             residual = hidden_states
-            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states.to(torch.float32)
+            variance = hidden_states.pow(2).mean(-1, keepdim=True)
             hidden_states = hidden_states * torch.rsqrt(
                 variance + self.variance_epsilon
             )
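The hunk above moves the float32 cast ahead of the variance computation, so the fallback RMSNorm path normalizes entirely in float32 instead of mixing a half-precision tensor into the rsqrt step. A minimal sketch of the resulting math, outside the server code (the function name and eps default are illustrative):

import torch

def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Upcast once, compute the RMS statistic and the normalization in float32,
    # then cast back to the original dtype before applying the learned scale.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)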
@@ -94,6 +96,27 @@ class LlamaRMSNorm(nn.Module):
             return normed_hidden_states, res
 
 
+def _load_gqa(config, prefix: str, weights):
+    w = [
+        weights.get_sharded(f"{prefix}.q_proj.weight", dim=0),
+        weights.get_sharded(f"{prefix}.k_proj.weight", dim=0),
+        weights.get_sharded(f"{prefix}.v_proj.weight", dim=0),
+    ]
+    weight = torch.cat(w, dim=0)
+    weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+    bias = None
+    assert config.hidden_size % config.num_attention_heads == 0
+    head_size = config.hidden_size // config.num_attention_heads
+    assert config.num_attention_heads % weights.process_group.size() == 0
+    num_heads = config.num_attention_heads // weights.process_group.size()
+    num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+    assert list(weight.shape) == [
+        (num_heads + 2 * num_key_value_heads) * head_size,
+        config.hidden_size,
+    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
+    return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+
+
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
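For orientation (not part of the diff): _load_gqa concatenates the sharded q/k/v projection weights into one matrix whose row count per shard is (num_heads + 2 * num_key_value_heads) * head_size, which is exactly what the assertion checks. A rough shape walk-through with Llama-2-70B-style values (hidden_size 8192, 64 attention heads, 8 KV heads, 8-way tensor parallelism); the variable names are illustrative:

# Hypothetical shape check mirroring the assertion in _load_gqa.
hidden_size = 8192
num_attention_heads = 64
num_key_value_heads = 8
world_size = 8

head_size = hidden_size // num_attention_heads            # 128
num_heads = num_attention_heads // world_size             # 8 query heads per shard
num_kv_heads = num_key_value_heads // world_size          # 1 KV head per shard

fused_rows = (num_heads + 2 * num_kv_heads) * head_size   # (8 + 2) * 128 = 1280
assert fused_rows == 1280
# so the concatenated q/k/v weight on each shard has shape (1280, 8192)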
@@ -118,6 +141,12 @@ class FlashLlamaAttention(torch.nn.Module):
                 f"and `num_shards`: {weights.process_group.size()}"
             )
         self.num_heads = self.num_heads // weights.process_group.size()
+        self.num_key_value_heads = (
+            config.num_key_value_heads // weights.process_group.size()
+        )
+        if config.num_attention_heads != config.num_key_value_heads:
+            self.query_key_value = _load_gqa(config, prefix, weights)
+        else:
             self.query_key_value = TensorParallelColumnLinear.load_multi(
                 config,
                 prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
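The new branch only takes the _load_gqa path when the checkpoint actually uses grouped-query attention; otherwise the pre-existing load_multi path is kept. A small sketch of that predicate with the published Llama 2 head counts (the Cfg dataclass only stands in for the HF config object):

from dataclasses import dataclass

@dataclass
class Cfg:
    num_attention_heads: int
    num_key_value_heads: int

configs = {
    "llama-2-7b": Cfg(num_attention_heads=32, num_key_value_heads=32),   # MHA
    "llama-2-70b": Cfg(num_attention_heads=64, num_key_value_heads=8),   # GQA
}
for name, cfg in configs.items():
    uses_gqa_loader = cfg.num_attention_heads != cfg.num_key_value_heads
    print(name, "-> _load_gqa" if uses_gqa_loader else "-> TensorParallelColumnLinear.load_multi")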
@@ -131,9 +160,10 @@ class FlashLlamaAttention(torch.nn.Module):
             weights=weights,
             bias=False,
         )
+        self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
-            0, self.num_heads, dtype=torch.int32, device=weights.device
-        )
+            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+        ).repeat_interleave(self.num_groups)
 
     def forward(
         self,
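kv_head_mapping now points each query head at the KV head it shares instead of assuming a one-to-one pairing. A standalone illustration of the repeat_interleave pattern, with made-up head counts (8 query heads over 2 KV heads):

import torch

num_heads = 8
num_key_value_heads = 2
num_groups = num_heads // num_key_value_heads  # 4 query heads per KV head

kv_head_mapping = torch.arange(0, num_key_value_heads, dtype=torch.int32).repeat_interleave(num_groups)
print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
# Before this change the mapping was simply arange(num_heads): one KV head per query head.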
@@ -148,26 +178,33 @@ class FlashLlamaAttention(torch.nn.Module):
         max_s,
     ):
         qkv = self.query_key_value(hidden_states)
-        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+        query, kv = qkv.split(
+            [
+                self.head_size * self.num_heads,
+                2 * self.head_size * self.num_key_value_heads,
+            ],
+            dim=1,
+        )
+        query = query.view(-1, self.num_heads, self.head_size)
+        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
 
-        # Inplace rotary
-        self.rotary_emb(qkv[:, 0], cos, sin)
-        self.rotary_emb(qkv[:, 1], cos, sin)
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
 
         vllm_cache_ops.reshape_and_cache(
-            qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
+            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )
 
         # output tensor
-        attn_output = torch.empty_like(qkv[:, 0])
+        attn_output = torch.empty_like(query)
 
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             attention(
-                qkv[:, 0],
-                qkv[:, 1],
-                qkv[:, 2],
+                query,
+                torch.select(kv, dim=1, index=0),
+                torch.select(kv, dim=1, index=1),
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
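The forward pass no longer reshapes qkv into three equal head groups; it splits the fused projection into a query block and a narrower K/V block. A shape walk-through using the same per-shard sizes as the _load_gqa example above (8 query heads, 1 KV head, head_size 128); the random tensor stands in for self.query_key_value(hidden_states):

import torch

num_tokens, num_heads, num_key_value_heads, head_size = 5, 8, 1, 128
qkv = torch.randn(num_tokens, (num_heads + 2 * num_key_value_heads) * head_size)  # (5, 1280)

query, kv = qkv.split(
    [head_size * num_heads, 2 * head_size * num_key_value_heads], dim=1
)
query = query.view(-1, num_heads, head_size)         # (5, 8, 128)
kv = kv.view(-1, 2, num_key_value_heads, head_size)  # (5, 2, 1, 128): K at kv[:, 0], V at kv[:, 1]
assert query.shape == (5, 8, 128) and kv.shape == (5, 2, 1, 128)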
@@ -179,7 +216,7 @@ class FlashLlamaAttention(torch.nn.Module):
             block_size = kv_cache[1].shape[3]
             vllm_attention_ops.single_query_cached_kv_attention(
                 attn_output,
-                qkv[:, 0],
+                query,
                 kv_cache[0],
                 kv_cache[1],
                 self.kv_head_mapping,
@@ -316,6 +353,7 @@ class FlashLlamaModel(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
+        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
 
     def forward(
         self,

Changed file 2 of 2:

@@ -69,7 +69,7 @@ class FlashLlama(FlashCausalLM):
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
-            num_kv_heads=model.model.num_heads,
+            num_kv_heads=model.model.num_key_value_heads,
             head_size=model.model.head_size,
             dtype=dtype,
             device=device,
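Passing num_key_value_heads (rather than num_heads) into FlashCausalLM means the paged KV cache is allocated per KV head. Back-of-the-envelope sizing with Llama-2-70B-style numbers (80 layers, 64 attention heads, 8 KV heads, head_size 128, fp16); the helper below is illustrative, not the server's allocator:

num_layers, head_size, dtype_bytes = 80, 128, 2
num_heads, num_kv_heads = 64, 8

def kv_cache_bytes_per_token(n_kv_heads: int) -> int:
    # Two tensors (K and V) per layer, one cache slot per token.
    return 2 * num_layers * n_kv_heads * head_size * dtype_bytes

mha_sized = kv_cache_bytes_per_token(num_heads)     # cache sized for all 64 heads
gqa_sized = kv_cache_bytes_per_token(num_kv_heads)  # cache sized for the 8 shared KV heads
print(mha_sized // gqa_sized)  # 8x smaller per token with GQA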