TMP.
This commit is contained in:
parent
edc9ce9beb
commit
7ccb8eefdc
|
@ -29,6 +29,7 @@ from typing import Optional
|
|||
|
||||
# Flash attention imports
|
||||
import flash_attn_cuda
|
||||
import dropout_layer_norm
|
||||
|
||||
from flash_attn.layers.rotary import RotaryEmbedding
|
||||
from text_generation_server.utils.layers import (
|
||||
|
@ -91,12 +92,7 @@ class LlamaRMSNorm(nn.Module):
|
|||
|
||||
|
||||
class FlashLlamaAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
hidden_size,
|
||||
process_group=None,
|
||||
):
|
||||
def __init__(self, num_heads, hidden_size, process_group=None, quantize=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
|
@ -106,8 +102,12 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
if process_group is None:
|
||||
self.query_key_value = FastLinear(hidden_size, 3 * hidden_size, bias=False)
|
||||
self.o_proj = FastLinear(hidden_size, hidden_size, bias=False)
|
||||
self.query_key_value = FastLinear(
|
||||
hidden_size, 3 * hidden_size, bias=False, quantize=quantize
|
||||
)
|
||||
self.o_proj = FastLinear(
|
||||
hidden_size, hidden_size, bias=False, quantize=quantize
|
||||
)
|
||||
else:
|
||||
self.num_heads = self.num_heads // process_group.size()
|
||||
self.query_key_value = TensorParallelColumnLinear(
|
||||
|
@ -115,12 +115,14 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
3 * hidden_size,
|
||||
bias=False,
|
||||
process_group=process_group,
|
||||
quantize=quantize,
|
||||
)
|
||||
self.o_proj = TensorParallelRowLinear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
process_group=process_group,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
@ -194,7 +196,9 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
|
||||
def __init__(
|
||||
self, act, hidden_size, intermediate_size, process_group=None, quantize=None
|
||||
):
|
||||
super().__init__()
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
|
@ -210,9 +214,11 @@ class LlamaMLP(nn.Module):
|
|||
if process_group is None:
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = FastLinear(
|
||||
hidden_size, 2 * intermediate_size, bias=False
|
||||
hidden_size, 2 * intermediate_size, bias=False, quantize=quantize
|
||||
)
|
||||
self.down_proj = FastLinear(
|
||||
intermediate_size, hidden_size, bias=False, quantize=quantize
|
||||
)
|
||||
self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False)
|
||||
self.intermediate_size = intermediate_size
|
||||
else:
|
||||
# Fuse gate and up proj
|
||||
|
@ -221,6 +227,7 @@ class LlamaMLP(nn.Module):
|
|||
2 * intermediate_size,
|
||||
bias=False,
|
||||
process_group=process_group,
|
||||
quantize=quantize,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear(
|
||||
intermediate_size,
|
||||
|
@ -228,6 +235,7 @@ class LlamaMLP(nn.Module):
|
|||
bias=False,
|
||||
process_group=process_group,
|
||||
reduce=True,
|
||||
quantize=quantize,
|
||||
)
|
||||
self.intermediate_size = self.down_proj.in_features
|
||||
|
||||
|
@ -248,11 +256,16 @@ class FlashLlamaLayer(nn.Module):
|
|||
intermediate_size,
|
||||
rms_norm_eps,
|
||||
process_group=None,
|
||||
quantize=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = FlashLlamaAttention(num_heads, hidden_size, process_group)
|
||||
self.mlp = LlamaMLP(act, hidden_size, intermediate_size, process_group)
|
||||
self.self_attn = FlashLlamaAttention(
|
||||
num_heads, hidden_size, process_group, quantize=quantize
|
||||
)
|
||||
self.mlp = LlamaMLP(
|
||||
act, hidden_size, intermediate_size, process_group, quantize=quantize
|
||||
)
|
||||
|
||||
self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
|
@ -309,6 +322,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
self.embed_tokens = TensorParallelEmbedding(
|
||||
config.vocab_size, config.hidden_size, process_group=process_group
|
||||
)
|
||||
self.embed_tokens.add_null_idx()
|
||||
else:
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
|
||||
|
@ -321,6 +335,7 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
config.intermediate_size,
|
||||
config.rms_norm_eps,
|
||||
process_group,
|
||||
quantize=config.quantize,
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
|
@ -332,15 +347,15 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
|
||||
def post_load_weights(self, load_in_8bit: bool = False):
|
||||
if isinstance(self.embed_tokens, TensorParallelEmbedding):
|
||||
self.embed_tokens.add_null_idx()
|
||||
for layer in self.layers:
|
||||
layer: FlashLlamaLayer
|
||||
layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
|
||||
layer.self_attn.o_proj.prepare_weights(load_in_8bit)
|
||||
layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
|
||||
layer.mlp.down_proj.prepare_weights(load_in_8bit)
|
||||
# def post_load_weights(self, load_in_8bit: bool = False):
|
||||
# if isinstance(self.embed_tokens, TensorParallelEmbedding):
|
||||
# self.embed_tokens.add_null_idx()
|
||||
# for layer in self.layers:
|
||||
# layer: FlashLlamaLayer
|
||||
# layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
|
||||
# layer.self_attn.o_proj.prepare_weights(load_in_8bit)
|
||||
# layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
|
||||
# layer.mlp.down_proj.prepare_weights(load_in_8bit)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -429,9 +444,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
else:
|
||||
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def post_load_weights(self, load_in_8bit: bool = False):
|
||||
self.model.post_load_weights(load_in_8bit)
|
||||
self.lm_head.prepare_weights()
|
||||
# def post_load_weights(self, load_in_8bit: bool = False):
|
||||
# self.model.post_load_weights(load_in_8bit)
|
||||
# self.lm_head.prepare_weights()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
|
@ -76,20 +76,20 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
hidden_size, hidden_size, process_group=process_group, reduce=reduce
|
||||
)
|
||||
|
||||
def shuffle_qkv_dims(self):
|
||||
"""Swap dims to avoid an additional permute"""
|
||||
self.query_key_value.weight = torch.nn.Parameter(
|
||||
self.query_key_value.weight.view(
|
||||
self.num_heads, 3, self.head_size, self.hidden_size
|
||||
)
|
||||
.permute(1, 0, 2, 3)
|
||||
.reshape(-1, self.hidden_size)
|
||||
)
|
||||
self.query_key_value.bias = torch.nn.Parameter(
|
||||
self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
|
||||
.permute(1, 0, 2)
|
||||
.reshape(-1)
|
||||
)
|
||||
# def shuffle_qkv_dims(self):
|
||||
# """Swap dims to avoid an additional permute"""
|
||||
# self.query_key_value.weight = torch.nn.Parameter(
|
||||
# self.query_key_value.weight.view(
|
||||
# self.num_heads, 3, self.head_size, self.hidden_size
|
||||
# )
|
||||
# .permute(1, 0, 2, 3)
|
||||
# .reshape(-1, self.hidden_size)
|
||||
# )
|
||||
# self.query_key_value.bias = torch.nn.Parameter(
|
||||
# self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
|
||||
# .permute(1, 0, 2)
|
||||
# .reshape(-1)
|
||||
# )
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -317,6 +317,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||
self.embed_in = TensorParallelEmbedding(
|
||||
config.vocab_size, config.hidden_size, process_group=process_group
|
||||
)
|
||||
self.embed_in.add_null_idx()
|
||||
else:
|
||||
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
|
||||
|
@ -345,28 +346,28 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||
self.head_size = self.layers[0].attention.head_size
|
||||
self.num_heads = self.layers[0].attention.num_heads
|
||||
|
||||
def post_load_weights(self, load_in_8bit=False):
|
||||
if isinstance(self.embed_in, TensorParallelEmbedding):
|
||||
self.embed_in.add_null_idx()
|
||||
for layer in self.layers:
|
||||
layer: FlashNeoXLayer
|
||||
layer.attention.shuffle_qkv_dims()
|
||||
layer.attention.query_key_value.prepare_weights(load_in_8bit)
|
||||
layer.attention.dense.prepare_weights(load_in_8bit)
|
||||
layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
|
||||
layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
|
||||
# def post_load_weights(self, load_in_8bit=False):
|
||||
# if isinstance(self.embed_in, TensorParallelEmbedding):
|
||||
# self.embed_in.add_null_idx()
|
||||
# for layer in self.layers:
|
||||
# layer: FlashNeoXLayer
|
||||
# layer.attention.shuffle_qkv_dims()
|
||||
# layer.attention.query_key_value.prepare_weights(load_in_8bit)
|
||||
# layer.attention.dense.prepare_weights(load_in_8bit)
|
||||
# layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
|
||||
# layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
|
||||
# to do it for us
|
||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
model = super(FlashGPTNeoXModel, cls).from_pretrained(
|
||||
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
|
||||
)
|
||||
# @classmethod
|
||||
# def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# # Pop here as we will replace the layer in our own logic and don't want from_pretrained
|
||||
# # to do it for us
|
||||
# load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
# model = super(FlashGPTNeoXModel, cls).from_pretrained(
|
||||
# pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
|
||||
# )
|
||||
|
||||
model.post_load_weights(load_in_8bit)
|
||||
return model
|
||||
# model.post_load_weights(load_in_8bit)
|
||||
# return model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -451,26 +452,30 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
|||
config.hidden_size,
|
||||
config.vocab_size // process_group.size(),
|
||||
bias=False,
|
||||
quantize=config.quantize,
|
||||
)
|
||||
else:
|
||||
self.embed_out = FastLinear(
|
||||
config.hidden_size, config.vocab_size, bias=False
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
quantize=config.quantize,
|
||||
)
|
||||
|
||||
def post_load_weights(self, load_in_8bit=False):
|
||||
self.gpt_neox.post_load_weights(load_in_8bit)
|
||||
self.embed_out.prepare_weights()
|
||||
# def post_load_weights(self, load_in_8bit=False):
|
||||
# self.gpt_neox.post_load_weights(load_in_8bit)
|
||||
# self.embed_out.prepare_weights()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
|
||||
# to do it for us
|
||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
|
||||
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
|
||||
)
|
||||
model.post_load_weights(load_in_8bit)
|
||||
return model
|
||||
# @classmethod
|
||||
# def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# # Pop here as we will replace the layer in our own logic and don't want from_pretrained
|
||||
# # to do it for us
|
||||
# load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
# model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
|
||||
# pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
|
||||
# )
|
||||
# model.post_load_weights(load_in_8bit)
|
||||
# return model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
|
@ -24,6 +24,7 @@ class FlashMQAttention(torch.nn.Module):
|
|||
num_heads,
|
||||
hidden_size,
|
||||
process_group=None,
|
||||
quantize=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
|
@ -33,15 +34,20 @@ class FlashMQAttention(torch.nn.Module):
|
|||
self.softmax_scale = self.head_size ** (-0.5)
|
||||
|
||||
if process_group is None:
|
||||
self.c_attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
|
||||
self.c_proj = FastLinear(hidden_size, hidden_size)
|
||||
self.c_attn = FastLinear(
|
||||
hidden_size, hidden_size + 2 * self.head_size, quantize=quantize
|
||||
)
|
||||
self.c_proj = FastLinear(hidden_size, hidden_size, quantize=quantize)
|
||||
else:
|
||||
self.num_heads = self.num_heads // process_group.size()
|
||||
self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
|
||||
self.c_attn = FastLinear(
|
||||
hidden_size, self.head_size * (self.num_heads + 2), quantize=quantize
|
||||
)
|
||||
self.c_proj = TensorParallelRowLinear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
process_group=process_group,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
@ -123,7 +129,9 @@ class FlashMQAttention(torch.nn.Module):
|
|||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
|
||||
def __init__(
|
||||
self, act, hidden_size, intermediate_size, process_group=None, quantize=None
|
||||
):
|
||||
super().__init__()
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
|
@ -137,18 +145,20 @@ class MLP(nn.Module):
|
|||
)
|
||||
|
||||
if process_group is None:
|
||||
self.c_fc = FastLinear(hidden_size, intermediate_size)
|
||||
self.c_proj = FastLinear(intermediate_size, hidden_size)
|
||||
self.c_fc = FastLinear(hidden_size, intermediate_size, quantize=quantize)
|
||||
self.c_proj = FastLinear(intermediate_size, hidden_size, quantize=quantize)
|
||||
else:
|
||||
self.c_fc = TensorParallelColumnLinear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
process_group=process_group,
|
||||
quantize=quantize,
|
||||
)
|
||||
self.c_proj = TensorParallelRowLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
process_group=process_group,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
|
@ -167,20 +177,20 @@ class Block(nn.Module):
|
|||
intermediate_size,
|
||||
layer_norm_eps,
|
||||
process_group=None,
|
||||
quantize=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.ln_1 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
self.ln_2 = FastLayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
self.attn = FlashMQAttention(
|
||||
num_heads,
|
||||
hidden_size,
|
||||
process_group,
|
||||
num_heads, hidden_size, process_group, quantize=quantize
|
||||
)
|
||||
self.mlp = MLP(
|
||||
act,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
process_group,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
@ -231,12 +241,14 @@ class FlashSantacoderModel(nn.Module):
|
|||
reduce=False,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.wte.add_null_idx()
|
||||
self.wpe = TensorParallelEmbedding(
|
||||
config.max_position_embeddings,
|
||||
config.hidden_size,
|
||||
reduce=False,
|
||||
process_group=process_group,
|
||||
)
|
||||
self.wpe.add_null_idx()
|
||||
else:
|
||||
self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
|
@ -252,6 +264,7 @@ class FlashSantacoderModel(nn.Module):
|
|||
else 4 * config.hidden_size,
|
||||
config.layer_norm_epsilon,
|
||||
process_group,
|
||||
quantize=config.quantize,
|
||||
)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
|
@ -261,16 +274,13 @@ class FlashSantacoderModel(nn.Module):
|
|||
self.head_size = self.h[0].attn.head_size
|
||||
self.num_heads = self.h[0].attn.num_heads
|
||||
|
||||
def post_load_weights(self, load_in_8bit: bool = False):
|
||||
if self.tp_embeddings:
|
||||
self.wte.add_null_idx()
|
||||
self.wpe.add_null_idx()
|
||||
for layer in self.h:
|
||||
layer: Block
|
||||
layer.attn.c_attn.prepare_weights(load_in_8bit)
|
||||
layer.attn.c_proj.prepare_weights(load_in_8bit)
|
||||
layer.mlp.c_fc.prepare_weights(load_in_8bit)
|
||||
layer.mlp.c_proj.prepare_weights(load_in_8bit)
|
||||
# def post_load_weights(self, load_in_8bit: bool = False):
|
||||
# for layer in self.h:
|
||||
# layer: Block
|
||||
# layer.attn.c_attn.prepare_weights(load_in_8bit)
|
||||
# layer.attn.c_proj.prepare_weights(load_in_8bit)
|
||||
# layer.mlp.c_fc.prepare_weights(load_in_8bit)
|
||||
# layer.mlp.c_proj.prepare_weights(load_in_8bit)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -343,13 +353,16 @@ class FlashSantacoderForCausalLM(nn.Module):
|
|||
config.hidden_size,
|
||||
config.vocab_size // process_group.size(),
|
||||
bias=False,
|
||||
quantize=None,
|
||||
)
|
||||
else:
|
||||
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
|
||||
self.lm_head = FastLinear(
|
||||
config.hidden_size, config.vocab_size, bias=False, quantize=None
|
||||
)
|
||||
|
||||
def post_load_weights(self, load_in_8bit: bool = False):
|
||||
self.transformer.post_load_weights(load_in_8bit)
|
||||
self.lm_head.prepare_weights()
|
||||
# def post_load_weights(self, load_in_8bit: bool = False):
|
||||
# self.transformer.post_load_weights(load_in_8bit)
|
||||
# self.lm_head.prepare_weights()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
|
@ -140,7 +140,7 @@ class FlashLlama(FlashCausalLM):
|
|||
del value
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
# model.post_load_weights(quantize)
|
||||
|
||||
|
||||
class FlashLlamaSharded(FlashLlama):
|
||||
|
@ -307,4 +307,4 @@ class FlashLlamaSharded(FlashLlama):
|
|||
module._buffers[param_name] = tensor
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
# model.post_load_weights(quantize)
|
||||
|
|
|
@ -152,4 +152,4 @@ class FlashNeoXSharded(FlashNeoX):
|
|||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
|
||||
model.post_load_weights(quantize)
|
||||
# model.post_load_weights(quantize)
|
||||
|
|
|
@ -160,7 +160,7 @@ class FlashSantacoder(FlashCausalLM):
|
|||
del value
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
# model.post_load_weights(quantize)
|
||||
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
||||
|
@ -378,4 +378,4 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||
|
||||
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights(quantize)
|
||||
# model.post_load_weights(quantize)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
|
||||
from torch import nn
|
||||
import dropout_layer_norm
|
||||
import torch.nn.functional as F
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
|
@ -18,12 +18,11 @@ class FastLinear(nn.Linear):
|
|||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
quantize=None,
|
||||
) -> None:
|
||||
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
|
||||
self.quantized = False
|
||||
self.quantize = quantize
|
||||
self.bnb_linear = None
|
||||
|
||||
def prepare_weights(self, quantize: bool = False):
|
||||
if quantize == "bitsandbytes":
|
||||
if not HAS_BITS_AND_BYTES:
|
||||
raise ImportError(
|
||||
|
@ -33,6 +32,7 @@ class FastLinear(nn.Linear):
|
|||
)
|
||||
|
||||
self.quantized = True
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
self.bnb_linear = Linear8bitLt(
|
||||
self.in_features,
|
||||
self.out_features,
|
||||
|
@ -51,12 +51,13 @@ class FastLinear(nn.Linear):
|
|||
elif quantize == "gptq":
|
||||
raise NotImplementedError("`gptq` is not implemented for now")
|
||||
elif quantize is None:
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
self.weight = nn.Parameter(self.weight.T)
|
||||
else:
|
||||
raise ValueError(f"Unexpected quantize `{quantize}`")
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if self.quantized:
|
||||
if self.quantize:
|
||||
return self.bnb_linear(input)
|
||||
else:
|
||||
if self.bias is not None:
|
||||
|
@ -73,6 +74,7 @@ class TensorParallelColumnLinear(FastLinear):
|
|||
bias=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
quantize=None,
|
||||
):
|
||||
self.process_group = process_group
|
||||
self.tp_world_size = process_group.size()
|
||||
|
@ -85,6 +87,7 @@ class TensorParallelColumnLinear(FastLinear):
|
|||
bias=bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
|
||||
|
@ -98,6 +101,7 @@ class TensorParallelRowLinear(FastLinear):
|
|||
bias=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
quantize=None,
|
||||
):
|
||||
self.process_group = process_group
|
||||
self.tp_world_size = process_group.size()
|
||||
|
@ -111,6 +115,7 @@ class TensorParallelRowLinear(FastLinear):
|
|||
bias=bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
quantize=quantize,
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -182,40 +187,46 @@ class TensorParallelEmbedding(nn.Embedding):
|
|||
return out
|
||||
|
||||
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if hidden_states.shape[-1] > 8192:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
try:
|
||||
import dropout_layer_norm
|
||||
|
||||
return super(FastLayerNorm, self).forward(hidden_states), residual
|
||||
else:
|
||||
(
|
||||
normed_hidden_states,
|
||||
residual,
|
||||
*rest,
|
||||
) = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.weight,
|
||||
self.bias,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.eps,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
if residual is None:
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if hidden_states.shape[-1] > 8192:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
return normed_hidden_states, residual
|
||||
return super().forward(hidden_states), residual
|
||||
else:
|
||||
(
|
||||
normed_hidden_states,
|
||||
residual,
|
||||
*rest,
|
||||
) = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.weight,
|
||||
self.bias,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.eps,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
|
||||
return normed_hidden_states, residual
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
try:
|
||||
|
|
Loading…
Reference in New Issue