feat(server): add support for llamav2 (#633)
This commit is contained in:
parent
3b71c38558
commit
211b211ec0
|
@ -39,6 +39,7 @@ from text_generation_server.utils.layers import (
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
PositionRotaryEmbedding,
|
PositionRotaryEmbedding,
|
||||||
TensorParallelHead,
|
TensorParallelHead,
|
||||||
|
get_linear,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,7 +60,8 @@ class LlamaRMSNorm(nn.Module):
|
||||||
hidden_states += residual
|
hidden_states += residual
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
hidden_states = hidden_states * torch.rsqrt(
|
hidden_states = hidden_states * torch.rsqrt(
|
||||||
variance + self.variance_epsilon
|
variance + self.variance_epsilon
|
||||||
)
|
)
|
||||||
|
@ -94,6 +96,27 @@ class LlamaRMSNorm(nn.Module):
|
||||||
return normed_hidden_states, res
|
return normed_hidden_states, res
|
||||||
|
|
||||||
|
|
||||||
|
def _load_gqa(config, prefix: str, weights):
|
||||||
|
w = [
|
||||||
|
weights.get_sharded(f"{prefix}.q_proj.weight", dim=0),
|
||||||
|
weights.get_sharded(f"{prefix}.k_proj.weight", dim=0),
|
||||||
|
weights.get_sharded(f"{prefix}.v_proj.weight", dim=0),
|
||||||
|
]
|
||||||
|
weight = torch.cat(w, dim=0)
|
||||||
|
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||||
|
bias = None
|
||||||
|
assert config.hidden_size % config.num_attention_heads == 0
|
||||||
|
head_size = config.hidden_size // config.num_attention_heads
|
||||||
|
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||||
|
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||||
|
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||||
|
assert list(weight.shape) == [
|
||||||
|
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||||
|
config.hidden_size,
|
||||||
|
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||||
|
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
|
||||||
|
|
||||||
|
|
||||||
class FlashLlamaAttention(torch.nn.Module):
|
class FlashLlamaAttention(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -118,22 +141,29 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||||
f"and `num_shards`: {weights.process_group.size()}"
|
f"and `num_shards`: {weights.process_group.size()}"
|
||||||
)
|
)
|
||||||
self.num_heads = self.num_heads // weights.process_group.size()
|
self.num_heads = self.num_heads // weights.process_group.size()
|
||||||
self.query_key_value = TensorParallelColumnLinear.load_multi(
|
self.num_key_value_heads = (
|
||||||
config,
|
config.num_key_value_heads // weights.process_group.size()
|
||||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
|
||||||
dim=0,
|
|
||||||
weights=weights,
|
|
||||||
bias=False,
|
|
||||||
)
|
)
|
||||||
|
if config.num_attention_heads != config.num_key_value_heads:
|
||||||
|
self.query_key_value = _load_gqa(config, prefix, weights)
|
||||||
|
else:
|
||||||
|
self.query_key_value = TensorParallelColumnLinear.load_multi(
|
||||||
|
config,
|
||||||
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||||
|
dim=0,
|
||||||
|
weights=weights,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
self.o_proj = TensorParallelRowLinear.load(
|
self.o_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.o_proj",
|
prefix=f"{prefix}.o_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
self.kv_head_mapping = torch.arange(
|
self.kv_head_mapping = torch.arange(
|
||||||
0, self.num_heads, dtype=torch.int32, device=weights.device
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
)
|
).repeat_interleave(self.num_groups)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -148,26 +178,33 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||||
max_s,
|
max_s,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
query, kv = qkv.split(
|
||||||
|
[
|
||||||
|
self.head_size * self.num_heads,
|
||||||
|
2 * self.head_size * self.num_key_value_heads,
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||||
|
|
||||||
# Inplace rotary
|
self.rotary_emb(query, cos, sin)
|
||||||
self.rotary_emb(qkv[:, 0], cos, sin)
|
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
self.rotary_emb(qkv[:, 1], cos, sin)
|
|
||||||
|
|
||||||
vllm_cache_ops.reshape_and_cache(
|
vllm_cache_ops.reshape_and_cache(
|
||||||
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
|
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
attn_output = torch.empty_like(qkv[:, 0])
|
attn_output = torch.empty_like(query)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
attention(
|
attention(
|
||||||
qkv[:, 0],
|
query,
|
||||||
qkv[:, 1],
|
torch.select(kv, dim=1, index=0),
|
||||||
qkv[:, 2],
|
torch.select(kv, dim=1, index=1),
|
||||||
attn_output,
|
attn_output,
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
|
@ -179,7 +216,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||||
block_size = kv_cache[1].shape[3]
|
block_size = kv_cache[1].shape[3]
|
||||||
vllm_attention_ops.single_query_cached_kv_attention(
|
vllm_attention_ops.single_query_cached_kv_attention(
|
||||||
attn_output,
|
attn_output,
|
||||||
qkv[:, 0],
|
query,
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
self.kv_head_mapping,
|
self.kv_head_mapping,
|
||||||
|
@ -316,6 +353,7 @@ class FlashLlamaModel(torch.nn.Module):
|
||||||
|
|
||||||
self.head_size = self.layers[0].self_attn.head_size
|
self.head_size = self.layers[0].self_attn.head_size
|
||||||
self.num_heads = self.layers[0].self_attn.num_heads
|
self.num_heads = self.layers[0].self_attn.num_heads
|
||||||
|
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -69,7 +69,7 @@ class FlashLlama(FlashCausalLM):
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
num_layers=len(model.model.layers),
|
num_layers=len(model.model.layers),
|
||||||
num_kv_heads=model.model.num_heads,
|
num_kv_heads=model.model.num_key_value_heads,
|
||||||
head_size=model.model.head_size,
|
head_size=model.model.head_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
|
|
Loading…
Reference in New Issue