This commit is contained in:
Ubuntu 2023-05-05 16:31:55 +00:00
parent c126ca01d9
commit c5846ee73a
4 changed files with 92 additions and 70 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.idea
target
router/tokenizer.json
server/flash-attention

View File

@ -1,4 +1,5 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@ -112,6 +113,9 @@ class FastLinear(nn.Linear):
) -> None:
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
self.quantized = False
self.weight = self.weight.to(device="meta")
if bias:
self.bias = self.bias.to(device="meta")
self.qlinear = None
def prepare_weights(self, layer=None, name=None, quantize: Optional[str] = None):
@ -154,8 +158,12 @@ class FastLinear(nn.Linear):
outfeatures=self.out_features,
bias=bool(self.bias),
)
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
try:
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
except:
rank = 0
world_size = 1
def get_row_slice(f, name):
slice_ = f.get_slice(name)

View File

@ -49,12 +49,13 @@ class FlashLlama(FlashCausalLM):
)
# We do not use from_pretrained as we modified the model internal module layout
try:
filenames = weight_files(model_id, revision, ".bin")
# Local files not found
except LocalEntryNotFoundError:
hub_files = weight_hub_files(model_id, revision, ".bin")
filenames = download_weights(hub_files, model_id, revision)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
# try:
# filenames = weight_files(model_id, revision, ".bin")
# # Local files not found
# except LocalEntryNotFoundError:
# hub_files = weight_hub_files(model_id, revision, ".bin")
# filenames = download_weights(hub_files, model_id, revision)
with init_empty_weights():
model = FlashLlamaForCausalLM(config)
@ -78,66 +79,77 @@ class FlashLlama(FlashCausalLM):
dtype: torch.dtype,
):
for filename in filenames:
state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items():
value = value.to(device if not quantize else "cpu").to(dtype)
with safe_open(
filename, framework="pt", device=str(device)
) as f:
layer_name = ".".join(key.split(".")[:4])
for key in f.keys():
# tmp
if "_proj" in key:
continue
value = f.get_tensor(key)
value = value.to(device if not quantize else "cpu").to(dtype)
# Fused qkv
if "q_proj" in key or "k_proj" in key or "v_proj" in key:
final_key = layer_name + ".query_key_value.weight"
layer_name = ".".join(key.split(".")[:4])
# Fused gate and up projs
elif "gate_proj" in key or "up_proj" in key:
final_key = layer_name + ".gate_up_proj.weight"
else:
final_key = key
# Fused qkv
if "q_proj" in key or "k_proj" in key or "v_proj" in key:
final_key = layer_name + ".query_key_value.weight"
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "query_key_value" in final_key:
module._parameters[param_name] = value.new_empty(
(value.shape[0] * 3, value.shape[1])
)
# Init gate and up proj
elif "gate_up_proj" in final_key:
module._parameters[param_name] = value.new_empty(
(value.shape[0] * 2, value.shape[1])
)
# Copy to correct slice
if "q_proj" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "k_proj" in key:
module._parameters[param_name][
value.shape[0] : value.shape[0] * 2
] = value
elif "v_proj" in key:
module._parameters[param_name][value.shape[0] * 2 :] = value
elif "gate_proj" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "up_proj" in key:
module._parameters[param_name][value.shape[0] :] = value
# Fused gate and up projs
elif "gate_proj" in key or "up_proj" in key:
final_key = layer_name + ".gate_up_proj.weight"
else:
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
else:
module._buffers[param_name] = value
final_key = key
del value
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "query_key_value" in final_key:
module._parameters[param_name] = value.new_empty(
(value.shape[0] * 3, value.shape[1])
)
# Init gate and up proj
elif "gate_up_proj" in final_key:
module._parameters[param_name] = value.new_empty(
(value.shape[0] * 2, value.shape[1])
)
# Copy to correct slice
if "q_proj" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "k_proj" in key:
module._parameters[param_name][
value.shape[0] : value.shape[0] * 2
] = value
elif "v_proj" in key:
module._parameters[param_name][value.shape[0] * 2 :] = value
elif "gate_proj" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "up_proj" in key:
module._parameters[param_name][value.shape[0] :] = value
else:
if current_parameter_tensor.shape != value.shape:
raise ValueError(
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
)
module._parameters[param_name] = value
else:
module._buffers[param_name] = value
del value
torch.cuda.empty_cache()
model.post_load_weights(quantize)
uninitialized_parameters = []
for n, p in model.named_parameters():
@ -148,9 +160,6 @@ class FlashLlama(FlashCausalLM):
f"found uninitialized parameters in model: {uninitialized_parameters}"
)
torch.cuda.empty_cache()
model.post_load_weights(quantize)
class FlashLlamaSharded(FlashLlama):
def __init__(
@ -214,11 +223,14 @@ class FlashLlamaSharded(FlashLlama):
rank: int,
world_size: int,
):
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if not quantize else "cpu"
file, framework="pt", device=str(device)
) as f:
for name in f.keys():
if "_proj" in name:
continue
slice_ = f.get_slice(name)
layer_name = ".".join(name.split(".")[:4])
@ -312,6 +324,10 @@ class FlashLlamaSharded(FlashLlama):
else:
module._buffers[param_name] = tensor
torch.cuda.empty_cache()
model.post_load_weights(quantize)
uninitialized_parameters = []
for n, p in model.named_parameters():
if p.data.device == torch.device("meta"):
@ -320,6 +336,3 @@ class FlashLlamaSharded(FlashLlama):
raise RuntimeError(
f"found uninitialized parameters in model: {uninitialized_parameters}"
)
torch.cuda.empty_cache()
model.post_load_weights(quantize)

View File

@ -318,7 +318,7 @@ class QuantLinear(nn.Module):
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16).cuda())
self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32).cuda())
if bias:
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16).cuda())
else:
self.bias = None