489 lines
15 KiB
Python
489 lines
15 KiB
Python
# coding=utf-8
|
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
|
# and OPT implementations in this library. It has been modified from its
|
|
# original forms to accommodate minor architectural differences compared
|
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import torch
|
|
import torch.distributed
|
|
|
|
from torch import nn
|
|
from transformers.activations import ACT2FN
|
|
from transformers.configuration_utils import PretrainedConfig
|
|
from typing import Optional, List, Tuple
|
|
|
|
# Flash attention imports
|
|
import dropout_layer_norm
|
|
|
|
# vllm imports
|
|
import vllm_cache_ops
|
|
import vllm_attention_ops
|
|
|
|
from text_generation_server.utils.flash_attn import attention
|
|
from text_generation_server.utils.layers import (
|
|
TensorParallelRowLinear,
|
|
TensorParallelColumnLinear,
|
|
TensorParallelEmbedding,
|
|
PositionRotaryEmbedding,
|
|
TensorParallelHead,
|
|
get_linear,
|
|
)
|
|
|
|
|
|
class LlamaConfig(PretrainedConfig):
    """Hyper-parameters for the flash Llama model defined in this file.

    Architecture fields are stored as attributes. The token ids, embedding
    tying flag and any extra keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_scaling=None,
        **kwargs,
    ):
        # Configs without an explicit key/value head count describe plain
        # multi-head attention: one key/value head per query head.
        kv_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = kv_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_scaling = rope_scaling

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
|
|
|
|
|
class LlamaRMSNorm(nn.Module):
    """RMS layer norm (the T5LayerNorm formulation) with an optional residual
    addition applied before normalisation.

    ``forward`` returns ``(normed, residual_out)``. ``residual_out`` is the
    sum of the input and the incoming residual (or the input alone when no
    residual is given), ready to be passed to the next norm.
    """

    def __init__(self, prefix, weights, eps=1e-6):
        """
        Args:
            prefix: checkpoint prefix; the scale is read from ``{prefix}.weight``.
            weights: loader providing ``get_tensor``.
            eps: value added to the mean square before the reciprocal sqrt.
        """
        super().__init__()

        self.weight = nn.Parameter(weights.get_tensor(f"{prefix}.weight"))
        self.variance_epsilon = eps

    def forward(self, hidden_states, residual=None):
        # Hidden sizes up to 8192 use the fused dropout_layer_norm kernel.
        # Larger ones take the eager path below (presumably because of a size
        # limit in the fused kernel — TODO confirm).
        if hidden_states.shape[-1] <= 8192:
            # faster post attention rms norm
            normed, fused_residual, *_ = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                None,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                True,  # Activate RMSNorm
            )
            # The kernel gives back no residual when none was passed in; the
            # un-normalised input then serves as the residual.
            if fused_residual is None:
                fused_residual = hidden_states
            return normed, fused_residual

        # Eager path. Note that the residual is added in place into the
        # caller's tensor.
        if residual is not None:
            hidden_states += residual
        residual = hidden_states

        # Normalise in float32 for accuracy.
        as_fp32 = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(
            as_fp32.pow(2).mean(-1, keepdim=True) + self.variance_epsilon
        )
        scaled = as_fp32 * inv_rms

        # Cast back down only when the weight is half precision.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            scaled = scaled.to(self.weight.dtype)

        return self.weight * scaled, residual
|
|
|
|
|
|
def _load_gqa(config, prefix: str, weights):
|
|
assert config.hidden_size % config.num_attention_heads == 0
|
|
assert config.num_attention_heads % weights.process_group.size() == 0
|
|
|
|
weight = weights.get_multi_weights_col(
|
|
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
|
quantize=config.quantize,
|
|
dim=0,
|
|
)
|
|
|
|
if config.quantize != "gptq":
|
|
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
|
|
|
head_size = config.hidden_size // config.num_attention_heads
|
|
num_heads = config.num_attention_heads // weights.process_group.size()
|
|
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
|
assert list(weight.shape) == [
|
|
(num_heads + 2 * num_key_value_heads) * head_size,
|
|
config.hidden_size,
|
|
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
|
|
|
return TensorParallelColumnLinear(
|
|
get_linear(weight, bias=None, quantize=config.quantize)
|
|
)
|
|
|
|
|
|
class FlashLlamaAttention(torch.nn.Module):
    """Tensor-parallel Llama self-attention.

    Each call first writes the current keys/values into the paged cache via
    ``vllm_cache_ops.reshape_and_cache``. Prefill then runs the
    flash-attention ``attention`` kernel over the current tokens. Decode runs
    vLLM's ``single_query_cached_kv_attention`` over the cache.
    """

    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        """Build the projections and rotary embedding for one attention layer.

        Args:
            prefix: checkpoint prefix of this module; ``q_proj``, ``k_proj``,
                ``v_proj`` and ``o_proj`` are loaded from under it.
            config: model config; reads ``num_attention_heads``,
                ``num_key_value_heads``, ``hidden_size`` and, if present,
                ``rope_theta``.
            weights: weight loader; reads ``device`` and ``process_group``.

        Raises:
            ValueError: if the attention or key/value head count is not
                divisible by the number of tensor-parallel shards.
        """
        super().__init__()
        world_size = weights.process_group.size()

        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

        # Some Llama-family checkpoints use a non-default rotary base and
        # expose it as ``rope_theta`` (extra config keys are kept as
        # attributes by transformers' PretrainedConfig). Fall back to the
        # classic 10000.0 when it is absent.
        rope_theta = getattr(config, "rope_theta", None)
        if rope_theta is None:
            rope_theta = 10000.0
        self.rotary_emb = PositionRotaryEmbedding.static(
            dim=self.head_size,
            device=weights.device,
            base=float(rope_theta),
        )

        self.softmax_scale = self.head_size**-0.5

        if self.num_heads % world_size != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {world_size})"
            )
        # Without this check a non-divisible count would floor to a wrong
        # per-shard value (or 0, which would make num_groups divide by zero).
        if config.num_key_value_heads % world_size != 0:
            raise ValueError(
                f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: "
                f"{config.num_key_value_heads} and `num_shards`: {world_size})"
            )
        self.num_heads = self.num_heads // world_size
        self.num_key_value_heads = config.num_key_value_heads // world_size

        # Grouped-query attention needs the dedicated loader. Plain multi-head
        # attention uses the generic fused column loader.
        if config.num_attention_heads != config.num_key_value_heads:
            self.query_key_value = _load_gqa(config, prefix, weights)
        else:
            self.query_key_value = TensorParallelColumnLinear.load_multi(
                config,
                prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
                dim=0,
                weights=weights,
                bias=False,
            )
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

        # Map every local query head to the key/value head it shares:
        # [0]*num_groups + [1]*num_groups + ...
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
    ):
        """Run attention for a batch of flattened tokens.

        Args:
            hidden_states: 2-D activations with one row per token. They are
                split on dim 1 and reshaped with ``view(-1, ...)`` below.
            cos, sin: rotary tables passed to ``self.rotary_emb``.
            cu_seqlen_prefill: cumulative sequence lengths for prefill, or
                ``None`` to take the decode path.
            kv_cache: ``(key_cache, value_cache)`` pair for this layer.
            block_tables, slots, input_lengths, max_s: paged-attention
                bookkeeping forwarded to the cache and attention kernels.

        Returns:
            The ``o_proj`` projection of the attention output.
        """
        qkv = self.query_key_value(hidden_states)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        # kv[:, 0] holds keys and kv[:, 1] holds values.
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        # Rotary is applied to queries and keys; the return value is discarded,
        # so it presumably rotates in place.
        self.rotary_emb(query, cos, sin)
        self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)

        vllm_cache_ops.reshape_and_cache(
            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
        )

        # output tensor
        attn_output = torch.empty_like(query)

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attention(
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
                attn_output,
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
            )
        # Decode
        else:
            # kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
            block_size = kv_cache[1].shape[3]
            vllm_attention_ops.single_query_cached_kv_attention(
                attn_output,
                query,
                kv_cache[0],
                kv_cache[1],
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                input_lengths,
                block_size,
                max_s,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
|
|
|
|
|
class LlamaMLP(nn.Module):
    """Gated feed-forward block computing ``down_proj(act(gate) * up)``.

    The gate and up projections are fused into one column-parallel matmul
    whose output is split back into its two halves in ``forward``.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        act_name = config.hidden_act
        if "gelu" in act_name:
            # GELU variants call torch directly. The "fast" and "pytorch_tanh"
            # flavours use the tanh approximation; every other GELU is exact.
            approximation = (
                "tanh" if act_name in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
            )
            self.act = lambda x: torch.nn.functional.gelu(x, approximate=approximation)
        else:
            self.act = ACT2FN[act_name]

        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        # Width of this shard's slice of the intermediate dimension.
        self.intermediate_size = config.intermediate_size // weights.process_group.size()

    def forward(self, hidden_states):
        fused = self.gate_up_proj(hidden_states).view(-1, 2, self.intermediate_size)
        gate = fused[:, 0]
        up = fused[:, 1]
        return self.down_proj(self.act(gate) * up)
|
|
|
|
|
|
class FlashLlamaLayer(nn.Module):
    """One decoder layer: pre-norm self-attention followed by a pre-norm MLP.

    The residual stream is not added here. Each ``LlamaRMSNorm`` call takes
    the pending residual and returns the updated one. The layer returns
    ``(mlp_output, residual)`` so the next norm can fold ``mlp_output`` in.
    """

    def __init__(self, layer_id, config, weights):
        super().__init__()
        prefix = f"model.layers.{layer_id}"
        eps = config.rms_norm_eps

        self.self_attn = FlashLlamaAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = LlamaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

        self.input_layernorm = LlamaRMSNorm(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=eps
        )
        self.post_attention_layernorm = LlamaRMSNorm(
            prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=eps
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
    ):
        attn_input, residual_pre_attn = self.input_layernorm(hidden_states, residual)

        attn_out = self.self_attn(
            attn_input, cos, sin, cu_seqlen_prefill,
            kv_cache, block_tables, slots, input_lengths, max_s,
        )

        # Adds attn_out onto the residual and normalises the sum for the MLP.
        mlp_input, residual_post_attn = self.post_attention_layernorm(
            attn_out, residual_pre_attn
        )

        return self.mlp(mlp_input), residual_post_attn
|
|
|
|
|
|
class FlashLlamaModel(torch.nn.Module):
    """Embedding, the stack of ``FlashLlamaLayer`` decoder layers, and the
    final RMS norm, sharded over ``weights.process_group``."""

    def __init__(self, config, weights):
        super().__init__()

        group = weights.process_group
        self.tp_rank = group.rank()
        self.tp_world_size = group.size()

        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.layers = nn.ModuleList(
            FlashLlamaLayer(layer_id, config, weights)
            for layer_id in range(config.num_hidden_layers)
        )
        self.norm = LlamaRMSNorm(
            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
        )

        self.gradient_checkpointing = False

        # Expose the per-shard attention geometry (read from the first layer).
        first_attn = self.layers[0].self_attn
        self.head_size = first_attn.head_size
        self.num_heads = first_attn.num_heads
        self.num_key_value_heads = first_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
    ) -> torch.Tensor:
        """Return the final-normed hidden states; ``kv_cache[i]`` feeds layer ``i``."""
        hidden_states = self.embed_tokens(input_ids)

        # Every layer's rotary embedding is built from the same config, so the
        # cos/sin tables are computed once from layer 0 and shared by all.
        rotary = self.layers[0].self_attn.rotary_emb
        cos, sin = rotary.get_cos_sin(position_ids, max_s, hidden_states.dtype)

        residual = None
        for idx, decoder_layer in enumerate(self.layers):
            hidden_states, residual = decoder_layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[idx],
                block_tables,
                slots,
                input_lengths,
                max_s,
            )

        # The final norm folds in the last pending residual.
        normed, _ = self.norm(hidden_states, residual)
        return normed
|
|
|
|
|
|
class FlashLlamaForCausalLM(torch.nn.Module):
    """``FlashLlamaModel`` followed by the tensor-parallel ``lm_head``."""

    def __init__(self, config, weights):
        super().__init__()

        self.model = FlashLlamaModel(config, weights)
        self.lm_head = TensorParallelHead.load(
            config,
            prefix="lm_head",
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute logits.

        When ``lm_head_indices`` is given, only those rows of the model's
        hidden states are projected; otherwise every row is.
        """
        hidden = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
        )
        rows = hidden if lm_head_indices is None else hidden[lm_head_indices]
        return self.lm_head(rows)
|