feat(server): cleanup flash neox loading (#139)
This commit is contained in:
parent
d6a93fe992
commit
678b2f3900
|
@ -450,8 +450,6 @@ class FlashNeoX(Model):
|
|||
next_batch_input_ids = next_batch_input_ids[0].view(1)
|
||||
next_batch_past_key_values = next_batch_past_key_values[0]
|
||||
|
||||
print(next_batch_input_ids.shape)
|
||||
|
||||
next_batch = FlashNeoXBatch(
|
||||
batch_id=batch.batch_id,
|
||||
requests=next_batch_requests,
|
||||
|
@ -507,6 +505,7 @@ class FlashNeoXSharded(FlashNeoX):
|
|||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
)
|
||||
model.post_load_weights()
|
||||
self.model = model.eval().to(dtype)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashNeoX, self).__init__(
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
@ -24,13 +26,11 @@ class FastLinear(nn.Linear):
|
|||
dtype=None,
|
||||
) -> None:
|
||||
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
|
||||
self.swap_dims = True
|
||||
|
||||
def transpose_weight(self):
|
||||
self.weight = nn.Parameter(self.weight.T)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if self.swap_dims:
|
||||
self.weight = nn.Parameter(self.weight.T)
|
||||
self.swap_dims = False
|
||||
|
||||
if self.bias is not None:
|
||||
return torch.addmm(self.bias, input, self.weight)
|
||||
return torch.matmul(input, self.weight)
|
||||
|
@ -120,6 +120,10 @@ class TensorParallelEmbedding(nn.Embedding):
|
|||
self.min_id = self.tp_rank * block_size
|
||||
self.max_id = (self.tp_rank + 1) * block_size
|
||||
|
||||
# Additional entry that will map to zero
|
||||
# Used for masking
|
||||
self.null_idx = block_size
|
||||
|
||||
super().__init__(
|
||||
block_size,
|
||||
embedding_dim,
|
||||
|
@ -133,15 +137,19 @@ class TensorParallelEmbedding(nn.Embedding):
|
|||
dtype=dtype,
|
||||
)
|
||||
|
||||
def add_null_idx(self):
|
||||
"""Additional 0 entry used for masking"""
|
||||
self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
# `0` if input is in the correct interval, else `1`
|
||||
input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
|
||||
# default all out of bounds values to `self.null_idx` that will then be mapped to 0
|
||||
# translate for [0, self.max_id - self.min_id[
|
||||
input = input - self.min_id
|
||||
# default all out of bounds values to `0`
|
||||
input[input_mask] = 0
|
||||
input = torch.where(
|
||||
(self.min_id > input) | (input >= self.max_id),
|
||||
self.null_idx,
|
||||
input - self.min_id,
|
||||
)
|
||||
out = super().forward(input)
|
||||
out[input_mask] = 0.0
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
return out
|
||||
|
||||
|
@ -214,11 +222,9 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
hidden_size,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.swap_dims = True
|
||||
|
||||
# TODO: remove and swap dims when loading weights
|
||||
def _swap_dims(self):
|
||||
"""Swap dims for the first inference to avoid an additional permute"""
|
||||
def shuffle_qkv_dims(self):
|
||||
"""Swap dims to avoid an additional permute"""
|
||||
self.query_key_value.weight = torch.nn.Parameter(
|
||||
self.query_key_value.weight.view(
|
||||
self.num_heads, 3, self.head_size, self.hidden_size
|
||||
|
@ -231,7 +237,6 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
.permute(1, 0, 2)
|
||||
.reshape(-1)
|
||||
)
|
||||
self.swap_dims = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -244,9 +249,6 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
layer_past_present_indices,
|
||||
cu_seqlens_q,
|
||||
):
|
||||
if self.swap_dims:
|
||||
self._swap_dims()
|
||||
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||
qkv_rot = self.rotary_emb(qkv, cos, sin)
|
||||
|
@ -329,7 +331,6 @@ class FlashMLP(nn.Module):
|
|||
hidden_size,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.heuristic = "auto"
|
||||
self.process_group = process_group
|
||||
|
||||
def forward(self, hidden_states):
|
||||
|
@ -531,6 +532,25 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||
self.head_size = self.layers[0].attention.head_size
|
||||
self.num_heads = self.layers[0].attention.num_heads
|
||||
|
||||
def post_load_weights(self):
|
||||
if isinstance(self.embed_in, TensorParallelEmbedding):
|
||||
self.embed_in.add_null_idx()
|
||||
for layer in self.layers:
|
||||
layer: FlashNeoXLayer
|
||||
layer.attention.shuffle_qkv_dims()
|
||||
layer.attention.query_key_value.transpose_weight()
|
||||
layer.attention.dense.transpose_weight()
|
||||
layer.mlp.dense_h_to_4h.transpose_weight()
|
||||
layer.mlp.dense_4h_to_h.transpose_weight()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
model = super(FlashGPTNeoXModel, cls).from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
)
|
||||
model.post_load_weights()
|
||||
return model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
|
@ -627,6 +647,18 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
|||
config.hidden_size, config.vocab_size, bias=False
|
||||
)
|
||||
|
||||
def post_load_weights(self):
|
||||
self.gpt_neox.post_load_weights()
|
||||
self.embed_out.transpose_weight()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
)
|
||||
model.post_load_weights()
|
||||
return model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
|
|
Loading…
Reference in New Issue