Nicolas Patry 2023-05-15 16:43:32 +02:00
parent edc9ce9beb
commit 7ccb8eefdc
7 changed files with 181 additions and 137 deletions

View File

@@ -29,6 +29,7 @@ from typing import Optional
# Flash attention imports
import flash_attn_cuda
import dropout_layer_norm
from flash_attn.layers.rotary import RotaryEmbedding
from text_generation_server.utils.layers import (
@@ -91,12 +92,7 @@ class LlamaRMSNorm(nn.Module):
class FlashLlamaAttention(torch.nn.Module):
def __init__(
self,
num_heads,
hidden_size,
process_group=None,
):
def __init__(self, num_heads, hidden_size, process_group=None, quantize=None):
super().__init__()
self.num_heads = num_heads
self.hidden_size = hidden_size
@@ -106,8 +102,12 @@ class FlashLlamaAttention(torch.nn.Module):
self.softmax_scale = self.head_size ** (-0.5)
if process_group is None:
self.query_key_value = FastLinear(hidden_size, 3 * hidden_size, bias=False)
self.o_proj = FastLinear(hidden_size, hidden_size, bias=False)
self.query_key_value = FastLinear(
hidden_size, 3 * hidden_size, bias=False, quantize=quantize
)
self.o_proj = FastLinear(
hidden_size, hidden_size, bias=False, quantize=quantize
)
else:
self.num_heads = self.num_heads // process_group.size()
self.query_key_value = TensorParallelColumnLinear(
@@ -115,12 +115,14 @@
3 * hidden_size,
bias=False,
process_group=process_group,
quantize=quantize,
)
self.o_proj = TensorParallelRowLinear(
hidden_size,
hidden_size,
bias=False,
process_group=process_group,
quantize=quantize,
)
def forward(
@@ -194,7 +196,9 @@ class FlashLlamaAttention(torch.nn.Module):
class LlamaMLP(nn.Module):
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
def __init__(
self, act, hidden_size, intermediate_size, process_group=None, quantize=None
):
super().__init__()
self.act = (
ACT2FN[act]
@@ -210,9 +214,11 @@ class LlamaMLP(nn.Module):
if process_group is None:
# Fuse gate and up proj
self.gate_up_proj = FastLinear(
hidden_size, 2 * intermediate_size, bias=False
hidden_size, 2 * intermediate_size, bias=False, quantize=quantize
)
self.down_proj = FastLinear(
intermediate_size, hidden_size, bias=False, quantize=quantize
)
self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False)
self.intermediate_size = intermediate_size
else:
# Fuse gate and up proj
@@ -221,6 +227,7 @@ class LlamaMLP(nn.Module):
2 * intermediate_size,
bias=False,
process_group=process_group,
quantize=quantize,
)
self.down_proj = TensorParallelRowLinear(
intermediate_size,
@@ -228,6 +235,7 @@ class LlamaMLP(nn.Module):
bias=False,
process_group=process_group,
reduce=True,
quantize=quantize,
)
self.intermediate_size = self.down_proj.in_features
@@ -248,11 +256,16 @@ class FlashLlamaLayer(nn.Module):
intermediate_size,
rms_norm_eps,
process_group=None,
quantize=None,
):
super().__init__()
self.self_attn = FlashLlamaAttention(num_heads, hidden_size, process_group)
self.mlp = LlamaMLP(act, hidden_size, intermediate_size, process_group)
self.self_attn = FlashLlamaAttention(
num_heads, hidden_size, process_group, quantize=quantize
)
self.mlp = LlamaMLP(
act, hidden_size, intermediate_size, process_group, quantize=quantize
)
self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
@@ -309,6 +322,7 @@ class FlashLlamaModel(torch.nn.Module):
self.embed_tokens = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group
)
self.embed_tokens.add_null_idx()
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
@@ -321,6 +335,7 @@ class FlashLlamaModel(torch.nn.Module):
config.intermediate_size,
config.rms_norm_eps,
process_group,
quantize=config.quantize,
)
for _ in range(config.num_hidden_layers)
]
@@ -332,15 +347,15 @@ class FlashLlamaModel(torch.nn.Module):
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
def post_load_weights(self, load_in_8bit: bool = False):
if isinstance(self.embed_tokens, TensorParallelEmbedding):
self.embed_tokens.add_null_idx()
for layer in self.layers:
layer: FlashLlamaLayer
layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
layer.self_attn.o_proj.prepare_weights(load_in_8bit)
layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
layer.mlp.down_proj.prepare_weights(load_in_8bit)
# def post_load_weights(self, load_in_8bit: bool = False):
# if isinstance(self.embed_tokens, TensorParallelEmbedding):
# self.embed_tokens.add_null_idx()
# for layer in self.layers:
# layer: FlashLlamaLayer
# layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
# layer.self_attn.o_proj.prepare_weights(load_in_8bit)
# layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
# layer.mlp.down_proj.prepare_weights(load_in_8bit)
def forward(
self,
@@ -429,9 +444,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self, load_in_8bit: bool = False):
self.model.post_load_weights(load_in_8bit)
self.lm_head.prepare_weights()
# def post_load_weights(self, load_in_8bit: bool = False):
# self.model.post_load_weights(load_in_8bit)
# self.lm_head.prepare_weights()
def forward(
self,

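The pattern above repeats through this file: every constructor gains a quantize keyword and simply forwards it to the linear layers it builds, so the quantization decision travels with the model config instead of being applied afterwards by post_load_weights. A self-contained toy sketch of that threading (illustrative class names, not the TGI ones):

import torch
from torch import nn


class SketchLinear(nn.Linear):
    # Stand-in for FastLinear: the layer just remembers how it should be quantized.
    def __init__(self, in_features, out_features, bias=True, quantize=None):
        super().__init__(in_features, out_features, bias=bias)
        self.quantize = quantize  # e.g. None or "bitsandbytes"


class SketchAttention(nn.Module):
    def __init__(self, num_heads, hidden_size, quantize=None):
        super().__init__()
        # the flag is forwarded, never interpreted at this level
        self.query_key_value = SketchLinear(hidden_size, 3 * hidden_size, bias=False, quantize=quantize)
        self.o_proj = SketchLinear(hidden_size, hidden_size, bias=False, quantize=quantize)


class SketchLayer(nn.Module):
    def __init__(self, num_heads, hidden_size, quantize=None):
        super().__init__()
        self.self_attn = SketchAttention(num_heads, hidden_size, quantize=quantize)


layer = SketchLayer(num_heads=4, hidden_size=64, quantize="bitsandbytes")
print(layer.self_attn.o_proj.quantize)  # bitsandbytes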
View File

@@ -76,20 +76,20 @@ class FlashNeoxAttention(torch.nn.Module):
hidden_size, hidden_size, process_group=process_group, reduce=reduce
)
def shuffle_qkv_dims(self):
"""Swap dims to avoid an additional permute"""
self.query_key_value.weight = torch.nn.Parameter(
self.query_key_value.weight.view(
self.num_heads, 3, self.head_size, self.hidden_size
)
.permute(1, 0, 2, 3)
.reshape(-1, self.hidden_size)
)
self.query_key_value.bias = torch.nn.Parameter(
self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
.permute(1, 0, 2)
.reshape(-1)
)
# def shuffle_qkv_dims(self):
# """Swap dims to avoid an additional permute"""
# self.query_key_value.weight = torch.nn.Parameter(
# self.query_key_value.weight.view(
# self.num_heads, 3, self.head_size, self.hidden_size
# )
# .permute(1, 0, 2, 3)
# .reshape(-1, self.hidden_size)
# )
# self.query_key_value.bias = torch.nn.Parameter(
# self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
# .permute(1, 0, 2)
# .reshape(-1)
# )
def forward(
self,
@@ -317,6 +317,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
self.embed_in = TensorParallelEmbedding(
config.vocab_size, config.hidden_size, process_group=process_group
)
self.embed_in.add_null_idx()
else:
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
@@ -345,28 +346,28 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
self.head_size = self.layers[0].attention.head_size
self.num_heads = self.layers[0].attention.num_heads
def post_load_weights(self, load_in_8bit=False):
if isinstance(self.embed_in, TensorParallelEmbedding):
self.embed_in.add_null_idx()
for layer in self.layers:
layer: FlashNeoXLayer
layer.attention.shuffle_qkv_dims()
layer.attention.query_key_value.prepare_weights(load_in_8bit)
layer.attention.dense.prepare_weights(load_in_8bit)
layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
# def post_load_weights(self, load_in_8bit=False):
# if isinstance(self.embed_in, TensorParallelEmbedding):
# self.embed_in.add_null_idx()
# for layer in self.layers:
# layer: FlashNeoXLayer
# layer.attention.shuffle_qkv_dims()
# layer.attention.query_key_value.prepare_weights(load_in_8bit)
# layer.attention.dense.prepare_weights(load_in_8bit)
# layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
# layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashGPTNeoXModel, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
)
# @classmethod
# def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# # Pop here as we will replace the layer in our own logic and don't want from_pretrained
# # to do it for us
# load_in_8bit = kwargs.pop("load_in_8bit", False)
# model = super(FlashGPTNeoXModel, cls).from_pretrained(
# pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
# )
model.post_load_weights(load_in_8bit)
return model
# model.post_load_weights(load_in_8bit)
# return model
def forward(
self,
@@ -451,26 +452,30 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
quantize=config.quantize,
)
else:
self.embed_out = FastLinear(
config.hidden_size, config.vocab_size, bias=False
config.hidden_size,
config.vocab_size,
bias=False,
quantize=config.quantize,
)
def post_load_weights(self, load_in_8bit=False):
self.gpt_neox.post_load_weights(load_in_8bit)
self.embed_out.prepare_weights()
# def post_load_weights(self, load_in_8bit=False):
# self.gpt_neox.post_load_weights(load_in_8bit)
# self.embed_out.prepare_weights()
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
)
model.post_load_weights(load_in_8bit)
return model
# @classmethod
# def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# # Pop here as we will replace the layer in our own logic and don't want from_pretrained
# # to do it for us
# load_in_8bit = kwargs.pop("load_in_8bit", False)
# model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
# pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
# )
# model.post_load_weights(load_in_8bit)
# return model
def forward(
self,

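shuffle_qkv_dims, commented out above, reordered the fused qkv weight from the checkpoint's per-head [q, k, v] interleaving into contiguous [all q | all k | all v] blocks, so the forward pass could slice it without an extra permute. A small self-contained illustration of that reordering with toy sizes (not the real model dimensions):

import torch

num_heads, head_size, hidden_size = 2, 3, 4

# Fake fused qkv weight in the checkpoint layout: for each head h, three
# consecutive blocks of head_size rows: q_h, k_h, v_h.
blocks = []
for h in range(num_heads):
    for i, proj in enumerate("qkv"):
        # encode (projection, head) in the values so we can see where rows land
        blocks.append(torch.full((head_size, hidden_size), 10 * i + h, dtype=torch.float32))
weight = torch.cat(blocks)  # shape (3 * num_heads * head_size, hidden_size)

# The same view/permute/reshape as shuffle_qkv_dims: rows become [all q | all k | all v]
shuffled = (
    weight.view(num_heads, 3, head_size, hidden_size)
    .permute(1, 0, 2, 3)
    .reshape(-1, hidden_size)
)
print(shuffled[:, 0].view(3, num_heads, head_size))
# -> q rows for heads 0..1, then k rows, then v rows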
View File

@@ -24,6 +24,7 @@ class FlashMQAttention(torch.nn.Module):
num_heads,
hidden_size,
process_group=None,
quantize=None,
):
super().__init__()
self.num_heads = num_heads
@@ -33,15 +34,20 @@ class FlashMQAttention(torch.nn.Module):
self.softmax_scale = self.head_size ** (-0.5)
if process_group is None:
self.c_attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
self.c_proj = FastLinear(hidden_size, hidden_size)
self.c_attn = FastLinear(
hidden_size, hidden_size + 2 * self.head_size, quantize=quantize
)
self.c_proj = FastLinear(hidden_size, hidden_size, quantize=quantize)
else:
self.num_heads = self.num_heads // process_group.size()
self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
self.c_attn = FastLinear(
hidden_size, self.head_size * (self.num_heads + 2), quantize=quantize
)
self.c_proj = TensorParallelRowLinear(
hidden_size,
hidden_size,
process_group=process_group,
quantize=quantize,
)
def forward(
@@ -123,7 +129,9 @@ class FlashMQAttention(torch.nn.Module):
class MLP(nn.Module):
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
def __init__(
self, act, hidden_size, intermediate_size, process_group=None, quantize=None
):
super().__init__()
self.act = (
ACT2FN[act]
@@ -137,18 +145,20 @@ class MLP(nn.Module):
)
if process_group is None:
self.c_fc = FastLinear(hidden_size, intermediate_size)
self.c_proj = FastLinear(intermediate_size, hidden_size)
self.c_fc = FastLinear(hidden_size, intermediate_size, quantize=quantize)
self.c_proj = FastLinear(intermediate_size, hidden_size, quantize=quantize)
else:
self.c_fc = TensorParallelColumnLinear(
hidden_size,
intermediate_size,
process_group=process_group,
quantize=quantize,
)
self.c_proj = TensorParallelRowLinear(
intermediate_size,
hidden_size,
process_group=process_group,
quantize=quantize,
)
def forward(self, hidden_states):
@@ -167,20 +177,20 @@ class Block(nn.Module):
intermediate_size,
layer_norm_eps,
process_group=None,
quantize=None,
):
super().__init__()
self.ln_1 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.ln_2 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.attn = FlashMQAttention(
num_heads,
hidden_size,
process_group,
num_heads, hidden_size, process_group, quantize=quantize
)
self.mlp = MLP(
act,
hidden_size,
intermediate_size,
process_group,
quantize=quantize,
)
def forward(
@@ -231,12 +241,14 @@ class FlashSantacoderModel(nn.Module):
reduce=False,
process_group=process_group,
)
self.wte.add_null_idx()
self.wpe = TensorParallelEmbedding(
config.max_position_embeddings,
config.hidden_size,
reduce=False,
process_group=process_group,
)
self.wpe.add_null_idx()
else:
self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
@@ -252,6 +264,7 @@ class FlashSantacoderModel(nn.Module):
else 4 * config.hidden_size,
config.layer_norm_epsilon,
process_group,
quantize=config.quantize,
)
for _ in range(config.num_hidden_layers)
]
@@ -261,16 +274,13 @@ class FlashSantacoderModel(nn.Module):
self.head_size = self.h[0].attn.head_size
self.num_heads = self.h[0].attn.num_heads
def post_load_weights(self, load_in_8bit: bool = False):
if self.tp_embeddings:
self.wte.add_null_idx()
self.wpe.add_null_idx()
for layer in self.h:
layer: Block
layer.attn.c_attn.prepare_weights(load_in_8bit)
layer.attn.c_proj.prepare_weights(load_in_8bit)
layer.mlp.c_fc.prepare_weights(load_in_8bit)
layer.mlp.c_proj.prepare_weights(load_in_8bit)
# def post_load_weights(self, load_in_8bit: bool = False):
# for layer in self.h:
# layer: Block
# layer.attn.c_attn.prepare_weights(load_in_8bit)
# layer.attn.c_proj.prepare_weights(load_in_8bit)
# layer.mlp.c_fc.prepare_weights(load_in_8bit)
# layer.mlp.c_proj.prepare_weights(load_in_8bit)
def forward(
self,
@@ -343,13 +353,16 @@ class FlashSantacoderForCausalLM(nn.Module):
config.hidden_size,
config.vocab_size // process_group.size(),
bias=False,
quantize=None,
)
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
self.lm_head = FastLinear(
config.hidden_size, config.vocab_size, bias=False, quantize=None
)
def post_load_weights(self, load_in_8bit: bool = False):
self.transformer.post_load_weights(load_in_8bit)
self.lm_head.prepare_weights()
# def post_load_weights(self, load_in_8bit: bool = False):
# self.transformer.post_load_weights(load_in_8bit)
# self.lm_head.prepare_weights()
def forward(
self,

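For context on the c_attn widths above: FlashMQAttention implements multi-query attention, so the fused projection produces num_heads query heads but only one shared key head and one value head, i.e. hidden_size + 2 * head_size output features; after sharding, the query heads are split across ranks while the shared K/V stay whole, giving head_size * (num_heads + 2) with the already-divided num_heads. A quick arithmetic check with toy numbers (not the real Santacoder config; the K/V-replication reading is inferred from the shapes):

hidden_size, num_heads, world_size = 2048, 16, 4
head_size = hidden_size // num_heads                 # 128

q_width = num_heads * head_size                      # 2048, one slice per query head
kv_width = 2 * head_size                             # 256, a single shared K and a single shared V head
assert hidden_size + 2 * head_size == q_width + kv_width            # unsharded c_attn width

num_heads_per_rank = num_heads // world_size         # 4 query heads per rank
assert head_size * (num_heads_per_rank + 2) == q_width // world_size + kv_width  # sharded c_attn width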
View File

@@ -140,7 +140,7 @@ class FlashLlama(FlashCausalLM):
del value
torch.cuda.empty_cache()
model.post_load_weights(quantize)
# model.post_load_weights(quantize)
class FlashLlamaSharded(FlashLlama):
@@ -307,4 +307,4 @@ class FlashLlamaSharded(FlashLlama):
module._buffers[param_name] = tensor
torch.cuda.empty_cache()
model.post_load_weights(quantize)
# model.post_load_weights(quantize)

View File

@@ -152,4 +152,4 @@ class FlashNeoXSharded(FlashNeoX):
else:
module._buffers[param_name] = tensor
model.post_load_weights(quantize)
# model.post_load_weights(quantize)

View File

@@ -160,7 +160,7 @@ class FlashSantacoder(FlashCausalLM):
del value
torch.cuda.empty_cache()
model.post_load_weights(quantize)
# model.post_load_weights(quantize)
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
@@ -378,4 +378,4 @@ class FlashSantacoderSharded(FlashSantacoder):
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
torch.cuda.empty_cache()
model.post_load_weights(quantize)
# model.post_load_weights(quantize)

View File

@@ -1,7 +1,7 @@
import torch
from torch import nn
import dropout_layer_norm
import torch.nn.functional as F
HAS_BITS_AND_BYTES = True
try:
@@ -18,12 +18,11 @@ class FastLinear(nn.Linear):
bias: bool = True,
device=None,
dtype=None,
quantize=None,
) -> None:
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
self.quantized = False
self.quantize = quantize
self.bnb_linear = None
def prepare_weights(self, quantize: bool = False):
if quantize == "bitsandbytes":
if not HAS_BITS_AND_BYTES:
raise ImportError(
@@ -33,6 +32,7 @@ class FastLinear(nn.Linear):
)
self.quantized = True
super().__init__(in_features, out_features, bias, device, dtype)
self.bnb_linear = Linear8bitLt(
self.in_features,
self.out_features,
@@ -51,12 +51,13 @@
elif quantize == "gptq":
raise NotImplementedError("`gptq` is not implemented for now")
elif quantize is None:
super().__init__(in_features, out_features, bias, device, dtype)
self.weight = nn.Parameter(self.weight.T)
else:
raise ValueError(f"Unexpected quantize `{quantize}`")
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.quantized:
if self.quantize:
return self.bnb_linear(input)
else:
if self.bias is not None:
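The hunks above move the quantization decision into FastLinear's constructor: the quantize string is stored on the layer and, for "bitsandbytes", a Linear8bitLt module is built immediately, which is why the separate prepare_weights / post_load_weights pass is being removed throughout this commit. A simplified, self-contained sketch of that construction-time branching (not the actual FastLinear; the Linear8bitLt keyword values below are an assumption):

import torch
from torch import nn

try:
    from bitsandbytes.nn import Linear8bitLt

    HAS_BITS_AND_BYTES = True
except ImportError:
    HAS_BITS_AND_BYTES = False


class QuantizableLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, quantize=None):
        super().__init__(in_features, out_features, bias=bias)
        self.quantize = quantize
        self.bnb_linear = None
        if quantize == "bitsandbytes":
            if not HAS_BITS_AND_BYTES:
                raise ImportError("bitsandbytes is not installed")
            # assumed keyword values; the real layer also copies its weights in here
            self.bnb_linear = Linear8bitLt(
                in_features, out_features, bias=bias, has_fp16_weights=False, threshold=6.0
            )
        elif quantize == "gptq":
            raise NotImplementedError("`gptq` is not implemented for now")
        elif quantize is not None:
            raise ValueError(f"Unexpected quantize `{quantize}`")
        # (The real FastLinear also pre-transposes the weight for addmm; omitted here.)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.quantize:
            return self.bnb_linear(input)
        return super().forward(input)


layer = QuantizableLinear(16, 32, quantize=None)
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 32])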
@@ -73,6 +74,7 @@ class TensorParallelColumnLinear(FastLinear):
bias=True,
device=None,
dtype=None,
quantize=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
@@ -85,6 +87,7 @@ class TensorParallelColumnLinear(FastLinear):
bias=bias,
device=device,
dtype=dtype,
quantize=quantize,
)
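TensorParallelColumnLinear and TensorParallelRowLinear follow the usual Megatron-style split: the column variant keeps a slice of the output features on each rank, the row variant keeps the matching slice of the input features and the partial results are summed across ranks (the reduce=True seen earlier). A conceptual check without torch.distributed, simulating two ranks with plain tensors:

import torch

torch.manual_seed(0)
x = torch.randn(2, 8)
w1 = torch.randn(8, 16)  # column parallel: split along the output dimension
w2 = torch.randn(16, 8)  # row parallel: split along the input dimension

full = (x @ w1) @ w2

w1_shards = w1.chunk(2, dim=1)  # each "rank" holds 8 output columns
w2_shards = w2.chunk(2, dim=0)  # each "rank" holds the matching 8 input rows
partials = [(x @ a) @ b for a, b in zip(w1_shards, w2_shards)]

# summing the partial outputs (the all-reduce) recovers the unsharded result
assert torch.allclose(full, sum(partials), atol=1e-5)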
@@ -98,6 +101,7 @@ class TensorParallelRowLinear(FastLinear):
bias=True,
device=None,
dtype=None,
quantize=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
@@ -111,6 +115,7 @@ class TensorParallelRowLinear(FastLinear):
bias=bias,
device=device,
dtype=dtype,
quantize=quantize,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -182,40 +187,46 @@ class TensorParallelEmbedding(nn.Embedding):
return out
class FastLayerNorm(nn.LayerNorm):
    def forward(self, hidden_states, residual=None):
        if hidden_states.shape[-1] > 8192:
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            return super(FastLayerNorm, self).forward(hidden_states), residual
        else:
            (
                normed_hidden_states,
                residual,
                *rest,
            ) = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                self.bias,
                None,
                None,
                None,
                None,
                0.0,
                self.eps,
                1.0,
                0,
                None,
                False,
                False,
            )
            if residual is None:
                residual = hidden_states

            return normed_hidden_states, residual

try:
    import dropout_layer_norm

    class FastLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
            if hidden_states.shape[-1] > 8192:
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                return super().forward(hidden_states), residual
            else:
                (
                    normed_hidden_states,
                    residual,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                if residual is None:
                    residual = hidden_states

                return normed_hidden_states, residual

except ImportError:
    pass
try:
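The layer-norm hunk above also makes the fused dropout_layer_norm kernel optional: its import moves from the top of the file into a try block, so the module still imports on machines without the CUDA extension. A stripped-down sketch of that optional-import pattern with a plain residual-add + nn.LayerNorm fallback (the class name is illustrative; the fused call itself is omitted):

import torch
from torch import nn

try:
    import dropout_layer_norm  # fused residual-add + layer-norm CUDA extension

    HAS_FUSED_LAYER_NORM = True
except ImportError:
    HAS_FUSED_LAYER_NORM = False


class AddLayerNorm(nn.LayerNorm):
    # Fallback behaviour only: residual add followed by a standard LayerNorm.
    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        return super().forward(hidden_states), residual


ln = AddLayerNorm(64)
out, res = ln(torch.randn(4, 64), residual=torch.randn(4, 64))
print(HAS_FUSED_LAYER_NORM, out.shape, res.shape)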