This commit is contained in:
Ubuntu 2023-05-05 16:31:55 +00:00
parent c126ca01d9
commit c5846ee73a
4 changed files with 92 additions and 70 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.idea .idea
target target
router/tokenizer.json router/tokenizer.json
server/flash-attention

View File

@ -1,4 +1,5 @@
# coding=utf-8 # coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# #
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@ -112,6 +113,9 @@ class FastLinear(nn.Linear):
) -> None: ) -> None:
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
self.quantized = False self.quantized = False
self.weight = self.weight.to(device="meta")
if bias:
self.bias = self.bias.to(device="meta")
self.qlinear = None self.qlinear = None
def prepare_weights(self, layer=None, name=None, quantize: Optional[str] = None): def prepare_weights(self, layer=None, name=None, quantize: Optional[str] = None):
@ -154,8 +158,12 @@ class FastLinear(nn.Linear):
outfeatures=self.out_features, outfeatures=self.out_features,
bias=bool(self.bias), bias=bool(self.bias),
) )
try:
rank = int(os.getenv("RANK")) rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE")) world_size = int(os.getenv("WORLD_SIZE"))
except:
rank = 0
world_size = 1
def get_row_slice(f, name): def get_row_slice(f, name):
slice_ = f.get_slice(name) slice_ = f.get_slice(name)

View File

@ -49,12 +49,13 @@ class FlashLlama(FlashCausalLM):
) )
# We do not use from_pretrained as we modified the model internal module layout # We do not use from_pretrained as we modified the model internal module layout
try: filenames = weight_files(model_id, revision=revision, extension=".safetensors")
filenames = weight_files(model_id, revision, ".bin") # try:
# Local files not found # filenames = weight_files(model_id, revision, ".bin")
except LocalEntryNotFoundError: # # Local files not found
hub_files = weight_hub_files(model_id, revision, ".bin") # except LocalEntryNotFoundError:
filenames = download_weights(hub_files, model_id, revision) # hub_files = weight_hub_files(model_id, revision, ".bin")
# filenames = download_weights(hub_files, model_id, revision)
with init_empty_weights(): with init_empty_weights():
model = FlashLlamaForCausalLM(config) model = FlashLlamaForCausalLM(config)
@ -78,8 +79,15 @@ class FlashLlama(FlashCausalLM):
dtype: torch.dtype, dtype: torch.dtype,
): ):
for filename in filenames: for filename in filenames:
state_dict = torch.load(filename, map_location="cpu") with safe_open(
for key, value in state_dict.items(): filename, framework="pt", device=str(device)
) as f:
for key in f.keys():
# tmp
if "_proj" in key:
continue
value = f.get_tensor(key)
value = value.to(device if not quantize else "cpu").to(dtype) value = value.to(device if not quantize else "cpu").to(dtype)
layer_name = ".".join(key.split(".")[:4]) layer_name = ".".join(key.split(".")[:4])
@ -139,6 +147,10 @@ class FlashLlama(FlashCausalLM):
del value del value
torch.cuda.empty_cache()
model.post_load_weights(quantize)
uninitialized_parameters = [] uninitialized_parameters = []
for n, p in model.named_parameters(): for n, p in model.named_parameters():
if p.data.device == torch.device("meta"): if p.data.device == torch.device("meta"):
@ -148,9 +160,6 @@ class FlashLlama(FlashCausalLM):
f"found uninitialized parameters in model: {uninitialized_parameters}" f"found uninitialized parameters in model: {uninitialized_parameters}"
) )
torch.cuda.empty_cache()
model.post_load_weights(quantize)
class FlashLlamaSharded(FlashLlama): class FlashLlamaSharded(FlashLlama):
def __init__( def __init__(
@ -214,11 +223,14 @@ class FlashLlamaSharded(FlashLlama):
rank: int, rank: int,
world_size: int, world_size: int,
): ):
for file in filenames: for file in filenames:
with safe_open( with safe_open(
file, framework="pt", device=str(device) if not quantize else "cpu" file, framework="pt", device=str(device)
) as f: ) as f:
for name in f.keys(): for name in f.keys():
if "_proj" in name:
continue
slice_ = f.get_slice(name) slice_ = f.get_slice(name)
layer_name = ".".join(name.split(".")[:4]) layer_name = ".".join(name.split(".")[:4])
@ -312,6 +324,10 @@ class FlashLlamaSharded(FlashLlama):
else: else:
module._buffers[param_name] = tensor module._buffers[param_name] = tensor
torch.cuda.empty_cache()
model.post_load_weights(quantize)
uninitialized_parameters = [] uninitialized_parameters = []
for n, p in model.named_parameters(): for n, p in model.named_parameters():
if p.data.device == torch.device("meta"): if p.data.device == torch.device("meta"):
@ -320,6 +336,3 @@ class FlashLlamaSharded(FlashLlama):
raise RuntimeError( raise RuntimeError(
f"found uninitialized parameters in model: {uninitialized_parameters}" f"found uninitialized parameters in model: {uninitialized_parameters}"
) )
torch.cuda.empty_cache()
model.post_load_weights(quantize)

View File

@ -318,7 +318,7 @@ class QuantLinear(nn.Module):
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16).cuda()) self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16).cuda())
self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32).cuda()) self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32).cuda())
if bias: if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16).cuda())
else: else:
self.bias = None self.bias = None