import torch
import os
import tempfile
import json

from typing import BinaryIO
from joblib import Parallel, delayed
from functools import partial
from pathlib import Path
from tqdm import tqdm

from huggingface_hub import hf_hub_url
from huggingface_hub.file_download import _request_wrapper, hf_raise_for_status


def match_suffix(text, suffix):
    # True when `text` ends with `suffix` (equivalent to str.endswith).
    return text[-len(suffix) :] == suffix


def http_get(
    url: str,
    temp_file: BinaryIO,
    *,
    timeout=10.0,
    max_retries=0,
):
    """
    Download a remote file. Does not swallow errors; raises errors tailored to the Hugging Face Hub.
    """
    r = _request_wrapper(
        method="GET",
        url=url,
        stream=True,
        timeout=timeout,
        max_retries=max_retries,
    )
    hf_raise_for_status(r)
    for chunk in r.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            temp_file.write(chunk)


def cache_download_url(url: str, root_dir: Path):
    filename = root_dir / url.split("/")[-1]

    if not filename.exists():
        # Download to a named temporary file in the target directory, then move it
        # into place so a partially downloaded file never shadows the final name.
        temp_file_manager = partial(
            tempfile.NamedTemporaryFile, mode="wb", dir=root_dir, delete=False
        )
        with temp_file_manager() as temp_file:
            http_get(url, temp_file)

        os.replace(temp_file.name, filename)
    return filename


def prepare_weights(
    model_name: str, cache_path: Path, save_path: Path, tp_world_size: int
):
    save_paths = [
        save_path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
        for tp_rank in range(tp_world_size)
    ]

    if all(save_path.exists() for save_path in save_paths):
        print("Weights are already prepared")
        return save_paths

    cache_path.mkdir(parents=True, exist_ok=True)
    if model_name == "bigscience/bloom-560m":
        url = hf_hub_url(model_name, filename="pytorch_model.bin")
        cache_download_url(url, cache_path)

    elif model_name == "bigscience/bloom":
        url = hf_hub_url(model_name, filename="pytorch_model.bin.index.json")
        index_path = cache_download_url(url, cache_path)
        with index_path.open("r") as f:
            index = json.load(f)

        # Get unique file names
        weight_files = list(
            set([filename for filename in index["weight_map"].values()])
        )
        urls = [hf_hub_url(model_name, filename=filename) for filename in weight_files]

        Parallel(n_jobs=5)(
            delayed(cache_download_url)(url, cache_path) for url in tqdm(urls)
        )
    else:
        raise ValueError(f"Unknown model name: {model_name}")

    shards_state_dicts = [{} for _ in range(tp_world_size)]

    for weight_path in tqdm(Path(cache_path).glob("*.bin")):
        state_dict = torch.load(weight_path, map_location="cpu")

        keys = list(state_dict.keys())
        for state_name in keys:
            state = state_dict[state_name]
            if any(
                match_suffix(state_name, candidate)
                for candidate in [
                    "self_attention.query_key_value.weight",
                    "self_attention.query_key_value.bias",
                    "mlp.dense_h_to_4h.weight",
                    "mlp.dense_h_to_4h.bias",
                    "word_embeddings.weight",
                ]
            ):
                # Column-parallel parameters: shard along the output dimension (dim 0).
                output_size = state.shape[0]
                assert output_size % tp_world_size == 0
                block_size = output_size // tp_world_size
                sharded_weights = torch.split(state, block_size, dim=0)
                assert len(sharded_weights) == tp_world_size

                for tp_rank, shard in enumerate(sharded_weights):
                    shards_state_dicts[tp_rank][
                        "transformer." + state_name
                    ] = shard.detach().clone()

            elif match_suffix(state_name, "lm_head.weight"):
                # The LM head is also sharded along the output (vocabulary) dimension,
                # but it is stored without the "transformer." prefix.
                output_size = state.shape[0]
                assert output_size % tp_world_size == 0
                block_size = output_size // tp_world_size
                sharded_weights = torch.split(state, block_size, dim=0)
                assert len(sharded_weights) == tp_world_size

                for tp_rank, shard in enumerate(sharded_weights):
                    shards_state_dicts[tp_rank][state_name] = shard.detach().clone()

            elif any(
                match_suffix(state_name, candidate)
                for candidate in [
                    "self_attention.dense.weight",
                    "mlp.dense_4h_to_h.weight",
                ]
            ):
                # Row-parallel weights: shard along the input dimension (dim 1).
                input_size = state.shape[1]
                assert input_size % tp_world_size == 0
                block_size = input_size // tp_world_size
                sharded_weights = torch.split(state, block_size, dim=1)
                assert len(sharded_weights) == tp_world_size
                for tp_rank, shard in enumerate(sharded_weights):
                    shards_state_dicts[tp_rank][
                        "transformer." + state_name
                    ] = shard.detach().clone()

            elif any(
                match_suffix(state_name, candidate)
                for candidate in [
                    "self_attention.dense.bias",
                    "mlp.dense_4h_to_h.bias",
                ]
            ):
                # Row-parallel biases: keep the bias on rank 0 and store zeros on the
                # other ranks, so it is only added once when partial outputs are reduced.
                shards_state_dicts[0][
                    "transformer." + state_name
                ] = state.detach().clone()
                for tp_rank in range(1, tp_world_size):
                    shards_state_dicts[tp_rank][
                        "transformer." + state_name
                    ] = torch.zeros_like(state)

            else:
                # We duplicate all other parameters across tp ranks.
                for tp_rank in range(tp_world_size):
                    shards_state_dicts[tp_rank][
                        "transformer." + state_name
                    ] = state.detach().clone()

            # Free memory as we go: the full checkpoints are large.
            del state_dict[state_name]  # delete key from state_dict
            del state  # delete tensor
        del state_dict

    # Save one shard state_dict per tensor-parallel rank. Note: the original
    # `save_paths.append(save_path)` inside this loop duplicated every entry in
    # the returned list, so it has been removed.
    for tp_rank, (save_path, shard_state_dict) in enumerate(
        zip(save_paths, shards_state_dicts)
    ):
        save_path.parent.mkdir(parents=True, exist_ok=True)
        if save_path.exists():
            print(f"Skipping {save_path} as it already exists")
        else:
            torch.save(shard_state_dict, save_path)

    return save_paths
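

# Minimal usage sketch (not part of the original script): the model name, cache
# directory, output directory, and tensor-parallel world size below are
# assumptions chosen for illustration.
if __name__ == "__main__":
    shard_paths = prepare_weights(
        model_name="bigscience/bloom-560m",
        cache_path=Path("./weights_cache"),
        save_path=Path("./sharded_weights"),
        tp_world_size=2,
    )
    # One ".pty" file is written per tensor-parallel rank; each can later be
    # loaded by the corresponding rank with torch.load(..., map_location="cpu").
    print(shard_paths)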