feat(server): cleanup flash neox loading (#139)

OlivierDehaene 2023-03-26 16:37:21 +02:00 committed by GitHub
parent d6a93fe992
commit 678b2f3900
2 changed files with 53 additions and 22 deletions


@@ -450,8 +450,6 @@ class FlashNeoX(Model):
             next_batch_input_ids = next_batch_input_ids[0].view(1)
             next_batch_past_key_values = next_batch_past_key_values[0]
 
-            print(next_batch_input_ids.shape)
-
         next_batch = FlashNeoXBatch(
             batch_id=batch.batch_id,
             requests=next_batch_requests,
@@ -507,6 +505,7 @@ class FlashNeoXSharded(FlashNeoX):
             rank=self.rank,
             world_size=self.world_size,
         )
+        model.post_load_weights()
         self.model = model.eval().to(dtype)
         torch.distributed.barrier(group=self.process_group)
         super(FlashNeoX, self).__init__(
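Note on ordering: the one-off weight preparation runs right after the (possibly sharded) weights are materialized, and only then is the model switched to eval mode and cast to the serving dtype. A minimal sketch of that sequence, with a hypothetical init_sharded_model helper that is not part of this commit:

import torch


def init_sharded_model(model, dtype, process_group):
    # `model` already has its (possibly sharded) weights loaded at this point.
    # 1. run the one-off fix-ups (weight transposes, embedding padding, ...)
    model.post_load_weights()
    # 2. only then freeze eval mode and the serving dtype
    model = model.eval().to(dtype)
    # 3. keep the ranks in sync before serving starts
    torch.distributed.barrier(group=process_group)
    return model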


@@ -1,6 +1,8 @@
 import torch
 import torch.distributed
 
+from torch.nn import functional as F
+
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
@@ -24,13 +26,11 @@ class FastLinear(nn.Linear):
         dtype=None,
     ) -> None:
         super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
-        self.swap_dims = True
+
+    def transpose_weight(self):
+        self.weight = nn.Parameter(self.weight.T)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.swap_dims:
-            self.weight = nn.Parameter(self.weight.T)
-            self.swap_dims = False
-
         if self.bias is not None:
             return torch.addmm(self.bias, input, self.weight)
         return torch.matmul(input, self.weight)
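The lazy swap-on-first-forward flag goes away: nn.Linear stores its weight as (out_features, in_features), while torch.addmm wants (in_features, out_features), so the transpose is now an explicit transpose_weight() call done once after loading rather than a branch on every forward. A self-contained sketch of the same idea (the PreTransposedLinear name is hypothetical), checked against a stock nn.Linear:

import torch
from torch import nn


class PreTransposedLinear(nn.Linear):
    def transpose_weight(self):
        # nn.Linear keeps weight as (out_features, in_features); addmm wants
        # (in_features, out_features), so transpose once after loading.
        self.weight = nn.Parameter(self.weight.T)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.bias is not None:
            return torch.addmm(self.bias, input, self.weight)
        return torch.matmul(input, self.weight)


ref = nn.Linear(16, 32)
fast = PreTransposedLinear(16, 32)
fast.load_state_dict(ref.state_dict())
fast.transpose_weight()

x = torch.randn(4, 16)
print(torch.allclose(ref(x), fast(x), atol=1e-6))  # True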
@@ -120,6 +120,10 @@ class TensorParallelEmbedding(nn.Embedding):
         self.min_id = self.tp_rank * block_size
         self.max_id = (self.tp_rank + 1) * block_size
 
+        # Additional entry that will map to zero
+        # Used for masking
+        self.null_idx = block_size
+
         super().__init__(
             block_size,
             embedding_dim,
@@ -133,15 +137,19 @@ class TensorParallelEmbedding(nn.Embedding):
             dtype=dtype,
         )
 
+    def add_null_idx(self):
+        """Additional 0 entry used for masking"""
+        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # `0` if input is in the correct interval, else `1`
-        input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
-        # translate for [0, self.max_id - self.min_id[
-        input = input - self.min_id
-        # default all out of bounds values to `0`
-        input[input_mask] = 0
-
+        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
+        input = torch.where(
+            (self.min_id > input) | (input >= self.max_id),
+            self.null_idx,
+            input - self.min_id,
+        )
         out = super().forward(input)
-        out[input_mask] = 0.0
         torch.distributed.all_reduce(out, group=self.process_group)
         return out
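The rewritten forward drops the boolean-mask indexing: add_null_idx() pads the local vocabulary shard with one extra all-zero row, and torch.where routes every token id outside [min_id, max_id) to that row, so out-of-range ids contribute zeros to the subsequent all_reduce and each id is resolved by exactly one rank. A single-process sketch of the trick (made-up sizes, no process group):

import torch
from torch import nn
from torch.nn import functional as F

# pretend this rank owns vocabulary ids [min_id, max_id) of a sharded embedding
min_id, max_id = 4, 8
block_size = max_id - min_id
null_idx = block_size  # index of the extra row that maps to zeros

emb = nn.Embedding(block_size, 3)
# pad one zero row at the end, as add_null_idx() does with F.pad
weight = F.pad(emb.weight.detach(), (0, 0, 0, 1))

input = torch.tensor([2, 4, 7, 9])
local = torch.where((min_id > input) | (input >= max_id), null_idx, input - min_id)
out = F.embedding(local, weight)

print(local)  # tensor([4, 0, 3, 4]) -> ids 2 and 9 hit the zero row
print(out[0].abs().sum(), out[3].abs().sum())  # tensor(0.) tensor(0.)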
@@ -214,11 +222,9 @@ class FlashNeoxAttention(torch.nn.Module):
             hidden_size,
             process_group=process_group,
         )
-        self.swap_dims = True
 
-    # TODO: remove and swap dims when loading weights
-    def _swap_dims(self):
-        """Swap dims for the first inference to avoid an additional permute"""
+    def shuffle_qkv_dims(self):
+        """Swap dims to avoid an additional permute"""
         self.query_key_value.weight = torch.nn.Parameter(
             self.query_key_value.weight.view(
                 self.num_heads, 3, self.head_size, self.hidden_size
@@ -231,7 +237,6 @@ class FlashNeoxAttention(torch.nn.Module):
             .permute(1, 0, 2)
             .reshape(-1)
         )
-        self.swap_dims = False
 
     def forward(
         self,
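The two hunks above move the QKV re-layout from the first forward pass to load time: the checkpoint stores the fused projection interleaved per head ([num_heads, 3, head_size]), while the forward wants q/k/v grouped first so that qkv.view(-1, 3, self.num_heads, self.head_size) works without a permute. A toy sketch of the same rearrangement on a labelled fake bias (made-up sizes, not the real weights):

import torch

num_heads, head_size = 2, 3
# checkpoint layout of the fused qkv bias: per head, q then k then v;
# label each slot as 100 * qkv_index + 10 * head + position for readability
bias = torch.tensor(
    [100 * qkv + 10 * h + p for h in range(num_heads) for qkv in range(3) for p in range(head_size)]
)

# same rearrangement as shuffle_qkv_dims: [num_heads, 3, head_size] -> [3, num_heads, head_size]
shuffled = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)

print(shuffled.view(3, num_heads, head_size))
# first group holds the q slots of both heads (0..12), then k (100..112), then v (200..212)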
@@ -244,9 +249,6 @@ class FlashNeoxAttention(torch.nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        if self.swap_dims:
-            self._swap_dims()
-
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
         qkv_rot = self.rotary_emb(qkv, cos, sin)
@@ -329,7 +331,6 @@ class FlashMLP(nn.Module):
             hidden_size,
             process_group=process_group,
         )
-        self.heuristic = "auto"
         self.process_group = process_group
 
     def forward(self, hidden_states):
@@ -531,6 +532,25 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self.head_size = self.layers[0].attention.head_size
         self.num_heads = self.layers[0].attention.num_heads
 
+    def post_load_weights(self):
+        if isinstance(self.embed_in, TensorParallelEmbedding):
+            self.embed_in.add_null_idx()
+        for layer in self.layers:
+            layer: FlashNeoXLayer
+            layer.attention.shuffle_qkv_dims()
+            layer.attention.query_key_value.transpose_weight()
+            layer.attention.dense.transpose_weight()
+            layer.mlp.dense_h_to_4h.transpose_weight()
+            layer.mlp.dense_4h_to_h.transpose_weight()
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        model = super(FlashGPTNeoXModel, cls).from_pretrained(
+            pretrained_model_name_or_path, *model_args, **kwargs
+        )
+        model.post_load_weights()
+        return model
+
     def forward(
         self,
         input_ids,
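Overriding from_pretrained keeps the one-off weight preparation on the standard Hugging Face loading path, so a plain from_pretrained call also returns a ready-to-run model. A schematic sketch of the pattern on a hypothetical PreTrainedModel subclass (not the actual classes in this file; a real model also needs a forward and config plumbing):

import torch
from transformers import PretrainedConfig, PreTrainedModel


class TinyFlashModel(PreTrainedModel):
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.proj = torch.nn.Linear(4, 4, bias=False)

    def post_load_weights(self):
        # one-off weight preparation, run exactly once after the checkpoint is loaded
        self.proj.weight = torch.nn.Parameter(self.proj.weight.T)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        model = super().from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )
        model.post_load_weights()
        return model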
@@ -627,6 +647,18 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             config.hidden_size, config.vocab_size, bias=False
         )
 
+    def post_load_weights(self):
+        self.gpt_neox.post_load_weights()
+        self.embed_out.transpose_weight()
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
+            pretrained_model_name_or_path, *model_args, **kwargs
+        )
+        model.post_load_weights()
+        return model
+
     def forward(
         self,
         input_ids,