diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 11f3766e..c957a57e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -29,6 +29,7 @@ from typing import Optional
 
 # Flash attention imports
 import flash_attn_cuda
+import dropout_layer_norm
 
 from flash_attn.layers.rotary import RotaryEmbedding
 from text_generation_server.utils.layers import (
@@ -91,12 +92,7 @@ class LlamaRMSNorm(nn.Module):
 
 
 class FlashLlamaAttention(torch.nn.Module):
-    def __init__(
-        self,
-        num_heads,
-        hidden_size,
-        process_group=None,
-    ):
+    def __init__(self, num_heads, hidden_size, process_group=None, quantize=None):
         super().__init__()
         self.num_heads = num_heads
         self.hidden_size = hidden_size
@@ -106,8 +102,12 @@ class FlashLlamaAttention(torch.nn.Module):
         self.softmax_scale = self.head_size ** (-0.5)
 
         if process_group is None:
-            self.query_key_value = FastLinear(hidden_size, 3 * hidden_size, bias=False)
-            self.o_proj = FastLinear(hidden_size, hidden_size, bias=False)
+            self.query_key_value = FastLinear(
+                hidden_size, 3 * hidden_size, bias=False, quantize=quantize
+            )
+            self.o_proj = FastLinear(
+                hidden_size, hidden_size, bias=False, quantize=quantize
+            )
         else:
             self.num_heads = self.num_heads // process_group.size()
             self.query_key_value = TensorParallelColumnLinear(
@@ -115,12 +115,14 @@ class FlashLlamaAttention(torch.nn.Module):
                 3 * hidden_size,
                 bias=False,
                 process_group=process_group,
+                quantize=quantize,
             )
             self.o_proj = TensorParallelRowLinear(
                 hidden_size,
                 hidden_size,
                 bias=False,
                 process_group=process_group,
+                quantize=quantize,
             )
 
     def forward(
         self,
@@ -194,7 +196,9 @@ class FlashLlamaAttention(torch.nn.Module):
 
 
 class LlamaMLP(nn.Module):
-    def __init__(self, act, hidden_size, intermediate_size, process_group=None):
+    def __init__(
+        self, act, hidden_size, intermediate_size, process_group=None, quantize=None
+    ):
         super().__init__()
         self.act = (
             ACT2FN[act]
@@ -210,9 +214,11 @@ class LlamaMLP(nn.Module):
         if process_group is None:
             # Fuse gate and up proj
             self.gate_up_proj = FastLinear(
-                hidden_size, 2 * intermediate_size, bias=False
+                hidden_size, 2 * intermediate_size, bias=False, quantize=quantize
+            )
+            self.down_proj = FastLinear(
+                intermediate_size, hidden_size, bias=False, quantize=quantize
             )
-            self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False)
             self.intermediate_size = intermediate_size
         else:
             # Fuse gate and up proj
@@ -221,6 +227,7 @@ class LlamaMLP(nn.Module):
                 2 * intermediate_size,
                 bias=False,
                 process_group=process_group,
+                quantize=quantize,
             )
             self.down_proj = TensorParallelRowLinear(
                 intermediate_size,
@@ -228,6 +235,7 @@ class LlamaMLP(nn.Module):
                 bias=False,
                 process_group=process_group,
                 reduce=True,
+                quantize=quantize,
             )
             self.intermediate_size = self.down_proj.in_features
 
@@ -248,11 +256,16 @@ class FlashLlamaLayer(nn.Module):
         intermediate_size,
         rms_norm_eps,
         process_group=None,
+        quantize=None,
     ):
         super().__init__()
 
-        self.self_attn = FlashLlamaAttention(num_heads, hidden_size, process_group)
-        self.mlp = LlamaMLP(act, hidden_size, intermediate_size, process_group)
+        self.self_attn = FlashLlamaAttention(
+            num_heads, hidden_size, process_group, quantize=quantize
+        )
+        self.mlp = LlamaMLP(
+            act, hidden_size, intermediate_size, process_group, quantize=quantize
+        )
 
         self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
         self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
@@ -309,6 +322,7 @@ class FlashLlamaModel(torch.nn.Module):
             self.embed_tokens = TensorParallelEmbedding(
                 config.vocab_size, config.hidden_size, process_group=process_group
             )
+            self.embed_tokens.add_null_idx()
         else:
             self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
 
@@ -321,6 +335,7 @@ class FlashLlamaModel(torch.nn.Module):
                     config.intermediate_size,
                     config.rms_norm_eps,
                     process_group,
+                    quantize=config.quantize,
                 )
                 for _ in range(config.num_hidden_layers)
             ]
@@ -332,15 +347,15 @@ class FlashLlamaModel(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
 
-    def post_load_weights(self, load_in_8bit: bool = False):
-        if isinstance(self.embed_tokens, TensorParallelEmbedding):
-            self.embed_tokens.add_null_idx()
-        for layer in self.layers:
-            layer: FlashLlamaLayer
-            layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
-            layer.self_attn.o_proj.prepare_weights(load_in_8bit)
-            layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
-            layer.mlp.down_proj.prepare_weights(load_in_8bit)
+    # def post_load_weights(self, load_in_8bit: bool = False):
+    #     if isinstance(self.embed_tokens, TensorParallelEmbedding):
+    #         self.embed_tokens.add_null_idx()
+    #     for layer in self.layers:
+    #         layer: FlashLlamaLayer
+    #         layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
+    #         layer.self_attn.o_proj.prepare_weights(load_in_8bit)
+    #         layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
+    #         layer.mlp.down_proj.prepare_weights(load_in_8bit)
 
     def forward(
         self,
@@ -429,9 +444,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         else:
             self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
 
-    def post_load_weights(self, load_in_8bit: bool = False):
-        self.model.post_load_weights(load_in_8bit)
-        self.lm_head.prepare_weights()
+    # def post_load_weights(self, load_in_8bit: bool = False):
+    #     self.model.post_load_weights(load_in_8bit)
+    #     self.lm_head.prepare_weights()
 
     def forward(
         self,
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 369e8d4f..b8ea7670 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -76,20 +76,20 @@ class FlashNeoxAttention(torch.nn.Module):
                 hidden_size, hidden_size, process_group=process_group, reduce=reduce
             )
 
-    def shuffle_qkv_dims(self):
-        """Swap dims to avoid an additional permute"""
-        self.query_key_value.weight = torch.nn.Parameter(
-            self.query_key_value.weight.view(
-                self.num_heads, 3, self.head_size, self.hidden_size
-            )
-            .permute(1, 0, 2, 3)
-            .reshape(-1, self.hidden_size)
-        )
-        self.query_key_value.bias = torch.nn.Parameter(
-            self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
-            .permute(1, 0, 2)
-            .reshape(-1)
-        )
+    # def shuffle_qkv_dims(self):
+    #     """Swap dims to avoid an additional permute"""
+    #     self.query_key_value.weight = torch.nn.Parameter(
+    #         self.query_key_value.weight.view(
+    #             self.num_heads, 3, self.head_size, self.hidden_size
+    #         )
+    #         .permute(1, 0, 2, 3)
+    #         .reshape(-1, self.hidden_size)
+    #     )
+    #     self.query_key_value.bias = torch.nn.Parameter(
+    #         self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
+    #         .permute(1, 0, 2)
+    #         .reshape(-1)
+    #     )
 
     def forward(
         self,
@@ -317,6 +317,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             self.embed_in = TensorParallelEmbedding(
                 config.vocab_size, config.hidden_size, process_group=process_group
             )
+            self.embed_in.add_null_idx()
         else:
             self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
 
@@ -345,28 +346,28 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self.head_size = self.layers[0].attention.head_size
         self.num_heads = self.layers[0].attention.num_heads
 
-    def post_load_weights(self, load_in_8bit=False):
-        if isinstance(self.embed_in, TensorParallelEmbedding):
-            self.embed_in.add_null_idx()
-        for layer in self.layers:
-            layer: FlashNeoXLayer
-            layer.attention.shuffle_qkv_dims()
-            layer.attention.query_key_value.prepare_weights(load_in_8bit)
-            layer.attention.dense.prepare_weights(load_in_8bit)
-            layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
-            layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
+    # def post_load_weights(self, load_in_8bit=False):
+    #     if isinstance(self.embed_in, TensorParallelEmbedding):
+    #         self.embed_in.add_null_idx()
+    #     for layer in self.layers:
+    #         layer: FlashNeoXLayer
+    #         layer.attention.shuffle_qkv_dims()
+    #         layer.attention.query_key_value.prepare_weights(load_in_8bit)
+    #         layer.attention.dense.prepare_weights(load_in_8bit)
+    #         layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
+    #         layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        # Pop here as we will replace the layer in our own logic and don't want from_pretrained
-        # to do it for us
-        load_in_8bit = kwargs.pop("load_in_8bit", False)
-        model = super(FlashGPTNeoXModel, cls).from_pretrained(
-            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
-        )
+    # @classmethod
+    # def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+    #     # Pop here as we will replace the layer in our own logic and don't want from_pretrained
+    #     # to do it for us
+    #     load_in_8bit = kwargs.pop("load_in_8bit", False)
+    #     model = super(FlashGPTNeoXModel, cls).from_pretrained(
+    #         pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
+    #     )
 
-        model.post_load_weights(load_in_8bit)
-        return model
+    #     model.post_load_weights(load_in_8bit)
+    #     return model
 
     def forward(
         self,
@@ -451,26 +452,30 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
                 config.hidden_size,
                 config.vocab_size // process_group.size(),
                 bias=False,
+                quantize=config.quantize,
             )
         else:
             self.embed_out = FastLinear(
-                config.hidden_size, config.vocab_size, bias=False
+                config.hidden_size,
+                config.vocab_size,
+                bias=False,
+                quantize=config.quantize,
             )
 
-    def post_load_weights(self, load_in_8bit=False):
-        self.gpt_neox.post_load_weights(load_in_8bit)
-        self.embed_out.prepare_weights()
+    # def post_load_weights(self, load_in_8bit=False):
+    #     self.gpt_neox.post_load_weights(load_in_8bit)
+    #     self.embed_out.prepare_weights()
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        # Pop here as we will replace the layer in our own logic and don't want from_pretrained
-        # to do it for us
-        load_in_8bit = kwargs.pop("load_in_8bit", False)
-        model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
-            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
-        )
-        model.post_load_weights(load_in_8bit)
-        return model
+    # @classmethod
+    # def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+    #     # Pop here as we will replace the layer in our own logic and don't want from_pretrained
+    #     # to do it for us
+    #     load_in_8bit = kwargs.pop("load_in_8bit", False)
+    #     model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
+    #         pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
+    #     )
+    #     model.post_load_weights(load_in_8bit)
+    #     return model
 
     def forward(
         self,
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 9451b01a..1fdce4c1 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -24,6 +24,7 @@ class FlashMQAttention(torch.nn.Module):
         num_heads,
         hidden_size,
         process_group=None,
+        quantize=None,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -33,15 +34,20 @@ class FlashMQAttention(torch.nn.Module):
         self.softmax_scale = self.head_size ** (-0.5)
 
         if process_group is None:
-            self.c_attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
-            self.c_proj = FastLinear(hidden_size, hidden_size)
+            self.c_attn = FastLinear(
+                hidden_size, hidden_size + 2 * self.head_size, quantize=quantize
+            )
+            self.c_proj = FastLinear(hidden_size, hidden_size, quantize=quantize)
         else:
             self.num_heads = self.num_heads // process_group.size()
-            self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
+            self.c_attn = FastLinear(
+                hidden_size, self.head_size * (self.num_heads + 2), quantize=quantize
+            )
             self.c_proj = TensorParallelRowLinear(
                 hidden_size,
                 hidden_size,
                 process_group=process_group,
+                quantize=quantize,
             )
 
     def forward(
@@ -123,7 +129,9 @@ class FlashMQAttention(torch.nn.Module):
 
 
 class MLP(nn.Module):
-    def __init__(self, act, hidden_size, intermediate_size, process_group=None):
+    def __init__(
+        self, act, hidden_size, intermediate_size, process_group=None, quantize=None
+    ):
         super().__init__()
         self.act = (
             ACT2FN[act]
@@ -137,18 +145,20 @@ class MLP(nn.Module):
         )
 
         if process_group is None:
-            self.c_fc = FastLinear(hidden_size, intermediate_size)
-            self.c_proj = FastLinear(intermediate_size, hidden_size)
+            self.c_fc = FastLinear(hidden_size, intermediate_size, quantize=quantize)
+            self.c_proj = FastLinear(intermediate_size, hidden_size, quantize=quantize)
         else:
             self.c_fc = TensorParallelColumnLinear(
                 hidden_size,
                 intermediate_size,
                 process_group=process_group,
+                quantize=quantize,
             )
             self.c_proj = TensorParallelRowLinear(
                 intermediate_size,
                 hidden_size,
                 process_group=process_group,
+                quantize=quantize,
             )
 
     def forward(self, hidden_states):
@@ -167,20 +177,20 @@ class Block(nn.Module):
         intermediate_size,
         layer_norm_eps,
         process_group=None,
+        quantize=None,
     ):
         super().__init__()
         self.ln_1 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
         self.ln_2 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
         self.attn = FlashMQAttention(
-            num_heads,
-            hidden_size,
-            process_group,
+            num_heads, hidden_size, process_group, quantize=quantize
        )
         self.mlp = MLP(
             act,
             hidden_size,
             intermediate_size,
             process_group,
+            quantize=quantize,
         )
 
     def forward(
@@ -231,12 +241,14 @@ class FlashSantacoderModel(nn.Module):
                 reduce=False,
                 process_group=process_group,
             )
+            self.wte.add_null_idx()
             self.wpe = TensorParallelEmbedding(
                 config.max_position_embeddings,
                 config.hidden_size,
                 reduce=False,
                 process_group=process_group,
             )
+            self.wpe.add_null_idx()
         else:
             self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
             self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
@@ -252,6 +264,7 @@ class FlashSantacoderModel(nn.Module):
                     else 4 * config.hidden_size,
                     config.layer_norm_epsilon,
                     process_group,
+                    quantize=config.quantize,
                 )
                 for _ in range(config.num_hidden_layers)
             ]
@@ -261,16 +274,13 @@ class FlashSantacoderModel(nn.Module):
         self.head_size = self.h[0].attn.head_size
         self.num_heads = self.h[0].attn.num_heads
 
-    def post_load_weights(self, load_in_8bit: bool = False):
-        if self.tp_embeddings:
-            self.wte.add_null_idx()
-            self.wpe.add_null_idx()
-        for layer in self.h:
-            layer: Block
-            layer.attn.c_attn.prepare_weights(load_in_8bit)
-            layer.attn.c_proj.prepare_weights(load_in_8bit)
-            layer.mlp.c_fc.prepare_weights(load_in_8bit)
-            layer.mlp.c_proj.prepare_weights(load_in_8bit)
+    # def post_load_weights(self, load_in_8bit: bool = False):
+    #     for layer in self.h:
+    #         layer: Block
+    #         layer.attn.c_attn.prepare_weights(load_in_8bit)
+    #         layer.attn.c_proj.prepare_weights(load_in_8bit)
+    #         layer.mlp.c_fc.prepare_weights(load_in_8bit)
+    #         layer.mlp.c_proj.prepare_weights(load_in_8bit)
 
     def forward(
         self,
@@ -343,13 +353,16 @@ class FlashSantacoderForCausalLM(nn.Module):
                 config.hidden_size,
                 config.vocab_size // process_group.size(),
                 bias=False,
+                quantize=None,
             )
         else:
-            self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
+            self.lm_head = FastLinear(
+                config.hidden_size, config.vocab_size, bias=False, quantize=None
+            )
 
-    def post_load_weights(self, load_in_8bit: bool = False):
-        self.transformer.post_load_weights(load_in_8bit)
-        self.lm_head.prepare_weights()
+    # def post_load_weights(self, load_in_8bit: bool = False):
+    #     self.transformer.post_load_weights(load_in_8bit)
+    #     self.lm_head.prepare_weights()
 
     def forward(
         self,
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 0b63f904..317bb59f 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -140,7 +140,7 @@ class FlashLlama(FlashCausalLM):
                 del value
 
         torch.cuda.empty_cache()
-        model.post_load_weights(quantize)
+        # model.post_load_weights(quantize)
 
 
 class FlashLlamaSharded(FlashLlama):
@@ -307,4 +307,4 @@ class FlashLlamaSharded(FlashLlama):
                     module._buffers[param_name] = tensor
 
         torch.cuda.empty_cache()
-        model.post_load_weights(quantize)
+        # model.post_load_weights(quantize)
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 168c9195..c4e60864 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -152,4 +152,4 @@ class FlashNeoXSharded(FlashNeoX):
                 else:
                     module._buffers[param_name] = tensor
 
-        model.post_load_weights(quantize)
+        # model.post_load_weights(quantize)
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 51a8998b..00cfc7b1 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -160,7 +160,7 @@ class FlashSantacoder(FlashCausalLM):
                 del value
 
         torch.cuda.empty_cache()
-        model.post_load_weights(quantize)
+        # model.post_load_weights(quantize)
 
     def decode(self, generated_ids: List[int]) -> str:
         # Do not skip special tokens as they are used for custom parsing rules of the generated text
@@ -378,4 +378,4 @@ class FlashSantacoderSharded(FlashSantacoder):
         model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
 
         torch.cuda.empty_cache()
-        model.post_load_weights(quantize)
+        # model.post_load_weights(quantize)
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 4c89e54e..316c06e0 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -1,7 +1,7 @@
 import torch
 from torch import nn
 
-import dropout_layer_norm
+import torch.nn.functional as F
 
 HAS_BITS_AND_BYTES = True
 try:
@@ -18,12 +18,11 @@ class FastLinear(nn.Linear):
         bias: bool = True,
         device=None,
         dtype=None,
+        quantize=None,
     ) -> None:
-        super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
-        self.quantized = False
+        self.quantize = quantize
         self.bnb_linear = None
 
-    def prepare_weights(self, quantize: bool = False):
         if quantize == "bitsandbytes":
             if not HAS_BITS_AND_BYTES:
                 raise ImportError(
@@ -33,6 +32,7 @@ class FastLinear(nn.Linear):
                 )
 
             self.quantized = True
+            super().__init__(in_features, out_features, bias, device, dtype)
             self.bnb_linear = Linear8bitLt(
                 self.in_features,
                 self.out_features,
@@ -51,12 +51,13 @@ class FastLinear(nn.Linear):
         elif quantize == "gptq":
             raise NotImplementedError("`gptq` is not implemented for now")
         elif quantize is None:
+            super().__init__(in_features, out_features, bias, device, dtype)
             self.weight = nn.Parameter(self.weight.T)
         else:
             raise ValueError(f"Unexpected quantize `{quantize}`")
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.quantized:
+        if self.quantize:
             return self.bnb_linear(input)
         else:
             if self.bias is not None:
@@ -73,6 +74,7 @@ class TensorParallelColumnLinear(FastLinear):
         bias=True,
         device=None,
         dtype=None,
+        quantize=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
@@ -85,6 +87,7 @@ class TensorParallelColumnLinear(FastLinear):
             bias=bias,
             device=device,
             dtype=dtype,
+            quantize=quantize,
         )
 
 
@@ -98,6 +101,7 @@ class TensorParallelRowLinear(FastLinear):
         bias=True,
         device=None,
         dtype=None,
+        quantize=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
@@ -111,6 +115,7 @@ class TensorParallelRowLinear(FastLinear):
             bias=bias,
             device=device,
             dtype=dtype,
+            quantize=quantize,
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -182,40 +187,46 @@ class TensorParallelEmbedding(nn.Embedding):
         return out
 
 
-class FastLayerNorm(nn.LayerNorm):
-    def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 8192:
-            if residual is not None:
-                hidden_states += residual
-            residual = hidden_states
+try:
+    import dropout_layer_norm
 
-            return super(FastLayerNorm, self).forward(hidden_states), residual
-        else:
-            (
-                normed_hidden_states,
-                residual,
-                *rest,
-            ) = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                self.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
-            if residual is None:
+    class FastLayerNorm(nn.LayerNorm):
+        def forward(self, hidden_states, residual=None):
+            if hidden_states.shape[-1] > 8192:
+                if residual is not None:
+                    hidden_states += residual
                 residual = hidden_states
 
-            return normed_hidden_states, residual
+                return super().forward(hidden_states), residual
+            else:
+                (
+                    normed_hidden_states,
+                    residual,
+                    *rest,
+                ) = dropout_layer_norm.dropout_add_ln_fwd(
+                    hidden_states,
+                    residual,
+                    self.weight,
+                    self.bias,
+                    None,
+                    None,
+                    None,
+                    None,
+                    0.0,
+                    self.eps,
+                    1.0,
+                    0,
+                    None,
+                    False,
+                    False,
+                )
+                if residual is None:
+                    residual = hidden_states
+
+                return normed_hidden_states, residual
+
+except ImportError:
+    pass
 
 
 try:
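
Taken together, the patch moves quantization out of the old `post_load_weights` / `prepare_weights` pass and into the layer constructors: `FastLinear` (and the tensor-parallel layers built on it) now receives `quantize` when it is created and decides there whether to build a regular linear layer or a bitsandbytes `Linear8bitLt`. The snippet below is a minimal sketch of that constructor-time dispatch, not the repository's `FastLinear`: the class name is hypothetical and the int8 branch is stubbed so the example runs on CPU without bitsandbytes. The apparent benefit of deciding at construction time is that no separate weight-preparation pass is needed after loading.

```python
# Hypothetical sketch of constructor-time quantization dispatch; not the
# FastLinear from this diff. The "bitsandbytes" branch is stubbed so the
# example runs without a GPU or the bitsandbytes package.
import torch
import torch.nn.functional as F
from torch import nn


class QuantizableLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, quantize=None):
        if quantize is None:
            # Unquantized path: behave like a plain nn.Linear.
            super().__init__(in_features, out_features, bias)
        elif quantize == "bitsandbytes":
            # The real layer would construct a bitsandbytes Linear8bitLt here,
            # at module-creation time, instead of converting the weights in a
            # separate prepare_weights() pass after loading.
            raise NotImplementedError("int8 path omitted from this sketch")
        else:
            raise ValueError(f"Unexpected quantize `{quantize}`")
        self.quantize = quantize

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # With quantize=None this is an ordinary linear projection.
        return F.linear(input, self.weight, self.bias)


if __name__ == "__main__":
    layer = QuantizableLinear(8, 16, quantize=None)
    print(layer(torch.randn(4, 8)).shape)  # torch.Size([4, 16])
```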