Dump.
This commit is contained in:
parent
c126ca01d9
commit
c5846ee73a
|
@ -1,3 +1,4 @@
|
||||||
.idea
|
.idea
|
||||||
target
|
target
|
||||||
router/tokenizer.json
|
router/tokenizer.json
|
||||||
|
server/flash-attention
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
|
|
||||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
#
|
#
|
||||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
@ -112,6 +113,9 @@ class FastLinear(nn.Linear):
|
||||||
) -> None:
|
) -> None:
|
||||||
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
|
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
|
||||||
self.quantized = False
|
self.quantized = False
|
||||||
|
self.weight = self.weight.to(device="meta")
|
||||||
|
if bias:
|
||||||
|
self.bias = self.bias.to(device="meta")
|
||||||
self.qlinear = None
|
self.qlinear = None
|
||||||
|
|
||||||
def prepare_weights(self, layer=None, name=None, quantize: Optional[str] = None):
|
def prepare_weights(self, layer=None, name=None, quantize: Optional[str] = None):
|
||||||
|
@ -154,8 +158,12 @@ class FastLinear(nn.Linear):
|
||||||
outfeatures=self.out_features,
|
outfeatures=self.out_features,
|
||||||
bias=bool(self.bias),
|
bias=bool(self.bias),
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
rank = int(os.getenv("RANK"))
|
rank = int(os.getenv("RANK"))
|
||||||
world_size = int(os.getenv("WORLD_SIZE"))
|
world_size = int(os.getenv("WORLD_SIZE"))
|
||||||
|
except:
|
||||||
|
rank = 0
|
||||||
|
world_size = 1
|
||||||
|
|
||||||
def get_row_slice(f, name):
|
def get_row_slice(f, name):
|
||||||
slice_ = f.get_slice(name)
|
slice_ = f.get_slice(name)
|
||||||
|
|
|
@ -49,12 +49,13 @@ class FlashLlama(FlashCausalLM):
|
||||||
)
|
)
|
||||||
|
|
||||||
# We do not use from_pretrained as we modified the model internal module layout
|
# We do not use from_pretrained as we modified the model internal module layout
|
||||||
try:
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
filenames = weight_files(model_id, revision, ".bin")
|
# try:
|
||||||
# Local files not found
|
# filenames = weight_files(model_id, revision, ".bin")
|
||||||
except LocalEntryNotFoundError:
|
# # Local files not found
|
||||||
hub_files = weight_hub_files(model_id, revision, ".bin")
|
# except LocalEntryNotFoundError:
|
||||||
filenames = download_weights(hub_files, model_id, revision)
|
# hub_files = weight_hub_files(model_id, revision, ".bin")
|
||||||
|
# filenames = download_weights(hub_files, model_id, revision)
|
||||||
|
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model = FlashLlamaForCausalLM(config)
|
model = FlashLlamaForCausalLM(config)
|
||||||
|
@ -78,8 +79,15 @@ class FlashLlama(FlashCausalLM):
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
):
|
):
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
state_dict = torch.load(filename, map_location="cpu")
|
with safe_open(
|
||||||
for key, value in state_dict.items():
|
filename, framework="pt", device=str(device)
|
||||||
|
) as f:
|
||||||
|
|
||||||
|
for key in f.keys():
|
||||||
|
# tmp
|
||||||
|
if "_proj" in key:
|
||||||
|
continue
|
||||||
|
value = f.get_tensor(key)
|
||||||
value = value.to(device if not quantize else "cpu").to(dtype)
|
value = value.to(device if not quantize else "cpu").to(dtype)
|
||||||
|
|
||||||
layer_name = ".".join(key.split(".")[:4])
|
layer_name = ".".join(key.split(".")[:4])
|
||||||
|
@ -139,6 +147,10 @@ class FlashLlama(FlashCausalLM):
|
||||||
|
|
||||||
del value
|
del value
|
||||||
|
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
model.post_load_weights(quantize)
|
||||||
|
|
||||||
uninitialized_parameters = []
|
uninitialized_parameters = []
|
||||||
for n, p in model.named_parameters():
|
for n, p in model.named_parameters():
|
||||||
if p.data.device == torch.device("meta"):
|
if p.data.device == torch.device("meta"):
|
||||||
|
@ -148,9 +160,6 @@ class FlashLlama(FlashCausalLM):
|
||||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
model.post_load_weights(quantize)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashLlamaSharded(FlashLlama):
|
class FlashLlamaSharded(FlashLlama):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -214,11 +223,14 @@ class FlashLlamaSharded(FlashLlama):
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
):
|
):
|
||||||
|
|
||||||
for file in filenames:
|
for file in filenames:
|
||||||
with safe_open(
|
with safe_open(
|
||||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
file, framework="pt", device=str(device)
|
||||||
) as f:
|
) as f:
|
||||||
for name in f.keys():
|
for name in f.keys():
|
||||||
|
if "_proj" in name:
|
||||||
|
continue
|
||||||
slice_ = f.get_slice(name)
|
slice_ = f.get_slice(name)
|
||||||
|
|
||||||
layer_name = ".".join(name.split(".")[:4])
|
layer_name = ".".join(name.split(".")[:4])
|
||||||
|
@ -312,6 +324,10 @@ class FlashLlamaSharded(FlashLlama):
|
||||||
else:
|
else:
|
||||||
module._buffers[param_name] = tensor
|
module._buffers[param_name] = tensor
|
||||||
|
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
model.post_load_weights(quantize)
|
||||||
|
|
||||||
uninitialized_parameters = []
|
uninitialized_parameters = []
|
||||||
for n, p in model.named_parameters():
|
for n, p in model.named_parameters():
|
||||||
if p.data.device == torch.device("meta"):
|
if p.data.device == torch.device("meta"):
|
||||||
|
@ -320,6 +336,3 @@ class FlashLlamaSharded(FlashLlama):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
f"found uninitialized parameters in model: {uninitialized_parameters}"
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
model.post_load_weights(quantize)
|
|
||||||
|
|
|
@ -318,7 +318,7 @@ class QuantLinear(nn.Module):
|
||||||
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16).cuda())
|
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16).cuda())
|
||||||
self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32).cuda())
|
self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32).cuda())
|
||||||
if bias:
|
if bias:
|
||||||
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
|
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16).cuda())
|
||||||
else:
|
else:
|
||||||
self.bias = None
|
self.bias = None
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue