From 678b2f39000f638e0099af0d84a98d409feca428 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Sun, 26 Mar 2023 16:37:21 +0200 Subject: [PATCH] feat(server): cleanup flash neox loading (#139) --- .../models/flash_neox.py | 3 +- .../models/flash_neox_modeling.py | 72 +++++++++++++------ 2 files changed, 53 insertions(+), 22 deletions(-) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index cbaa78ca..b97f342a 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -450,8 +450,6 @@ class FlashNeoX(Model): next_batch_input_ids = next_batch_input_ids[0].view(1) next_batch_past_key_values = next_batch_past_key_values[0] - print(next_batch_input_ids.shape) - next_batch = FlashNeoXBatch( batch_id=batch.batch_id, requests=next_batch_requests, @@ -507,6 +505,7 @@ class FlashNeoXSharded(FlashNeoX): rank=self.rank, world_size=self.world_size, ) + model.post_load_weights() self.model = model.eval().to(dtype) torch.distributed.barrier(group=self.process_group) super(FlashNeoX, self).__init__( diff --git a/server/text_generation_server/models/flash_neox_modeling.py b/server/text_generation_server/models/flash_neox_modeling.py index dcfb613d..2e638d77 100644 --- a/server/text_generation_server/models/flash_neox_modeling.py +++ b/server/text_generation_server/models/flash_neox_modeling.py @@ -1,6 +1,8 @@ import torch import torch.distributed +from torch.nn import functional as F + from torch import nn from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel @@ -24,13 +26,11 @@ class FastLinear(nn.Linear): dtype=None, ) -> None: super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) - self.swap_dims = True + + def transpose_weight(self): + self.weight = nn.Parameter(self.weight.T) def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.swap_dims: - self.weight = nn.Parameter(self.weight.T) - self.swap_dims = False - if self.bias is not None: return torch.addmm(self.bias, input, self.weight) return torch.matmul(input, self.weight) @@ -120,6 +120,10 @@ class TensorParallelEmbedding(nn.Embedding): self.min_id = self.tp_rank * block_size self.max_id = (self.tp_rank + 1) * block_size + # Additional entry that will map to zero + # Used for masking + self.null_idx = block_size + super().__init__( block_size, embedding_dim, @@ -133,15 +137,19 @@ class TensorParallelEmbedding(nn.Embedding): dtype=dtype, ) + def add_null_idx(self): + """Additional 0 entry used for masking""" + self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1))) + def forward(self, input: torch.Tensor) -> torch.Tensor: - # `0` if input is in the correct interval, else `1` - input_mask = torch.logical_or(self.min_id > input, input >= self.max_id) + # default all out of bounds values to `self.null_idx` that will then be mapped to 0 # translate for [0, self.max_id - self.min_id[ - input = input - self.min_id - # default all out of bounds values to `0` - input[input_mask] = 0 + input = torch.where( + (self.min_id > input) | (input >= self.max_id), + self.null_idx, + input - self.min_id, + ) out = super().forward(input) - out[input_mask] = 0.0 torch.distributed.all_reduce(out, group=self.process_group) return out @@ -214,11 +222,9 @@ class FlashNeoxAttention(torch.nn.Module): hidden_size, process_group=process_group, ) - self.swap_dims = True - # TODO: remove and swap dims when loading weights - def _swap_dims(self): - """Swap dims for the first inference to avoid an additional permute""" + def shuffle_qkv_dims(self): + """Swap dims to avoid an additional permute""" self.query_key_value.weight = torch.nn.Parameter( self.query_key_value.weight.view( self.num_heads, 3, self.head_size, self.hidden_size @@ -231,7 +237,6 @@ class FlashNeoxAttention(torch.nn.Module): .permute(1, 0, 2) .reshape(-1) ) - self.swap_dims = False def forward( self, @@ -244,9 +249,6 @@ class FlashNeoxAttention(torch.nn.Module): layer_past_present_indices, cu_seqlens_q, ): - if self.swap_dims: - self._swap_dims() - qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) qkv_rot = self.rotary_emb(qkv, cos, sin) @@ -329,7 +331,6 @@ class FlashMLP(nn.Module): hidden_size, process_group=process_group, ) - self.heuristic = "auto" self.process_group = process_group def forward(self, hidden_states): @@ -531,6 +532,25 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): self.head_size = self.layers[0].attention.head_size self.num_heads = self.layers[0].attention.num_heads + def post_load_weights(self): + if isinstance(self.embed_in, TensorParallelEmbedding): + self.embed_in.add_null_idx() + for layer in self.layers: + layer: FlashNeoXLayer + layer.attention.shuffle_qkv_dims() + layer.attention.query_key_value.transpose_weight() + layer.attention.dense.transpose_weight() + layer.mlp.dense_h_to_4h.transpose_weight() + layer.mlp.dense_4h_to_h.transpose_weight() + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + model = super(FlashGPTNeoXModel, cls).from_pretrained( + pretrained_model_name_or_path, *model_args, **kwargs + ) + model.post_load_weights() + return model + def forward( self, input_ids, @@ -627,6 +647,18 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): config.hidden_size, config.vocab_size, bias=False ) + def post_load_weights(self): + self.gpt_neox.post_load_weights() + self.embed_out.transpose_weight() + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained( + pretrained_model_name_or_path, *model_args, **kwargs + ) + model.post_load_weights() + return model + def forward( self, input_ids,