feat(server): Update convert logic. (#483)
Should be more robust to shared tensors (ok when using `from_pretrained). But forcing us to add new checks in our loading code (since the chosen key to keep might be different from `transformers`). --------- Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
This commit is contained in:
parent
c9c65ab323
commit
49b4b33e80
|
@ -52,7 +52,8 @@ class FlashSantacoderSharded(FlashCausalLM):
|
|||
torch.distributed.barrier(group=self.process_group)
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group
|
||||
filenames, device=device, dtype=dtype, process_group=self.process_group,
|
||||
aliases = {"transformer.wte.weight": ["lm_head.weight"]}
|
||||
)
|
||||
|
||||
model = FlashSantacoderForCausalLM(config, weights)
|
||||
|
|
|
@ -1,76 +1,45 @@
|
|||
import datetime
|
||||
import torch
|
||||
import os
|
||||
|
||||
from collections import defaultdict
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
from safetensors.torch import save_file
|
||||
from safetensors import safe_open
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def check_file_size(source_file: Path, target_file: Path):
|
||||
"""
|
||||
Check that two files are close in size
|
||||
"""
|
||||
source_file_size = source_file.stat().st_size
|
||||
target_file_size = target_file.stat().st_size
|
||||
|
||||
if (source_file_size - target_file_size) / source_file_size > 0.05:
|
||||
raise RuntimeError(
|
||||
f"""The file size different is more than 5%:
|
||||
- {source_file}: {source_file_size}
|
||||
- {target_file}: {target_file_size}
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
|
||||
"""
|
||||
For a Dict of tensors, check if two or more tensors point to the same underlying memory and
|
||||
remove them
|
||||
"""
|
||||
ptrs = defaultdict(list)
|
||||
for k, v in tensors.items():
|
||||
ptrs[v.data_ptr()].append(k)
|
||||
|
||||
# Iterate over all found memory addresses
|
||||
for ptr, names in ptrs.items():
|
||||
if len(names) > 1:
|
||||
# Multiple tensors are point to the same memory
|
||||
# Only keep the first tensor
|
||||
for name in names[1:]:
|
||||
tensors.pop(name)
|
||||
from safetensors.torch import save_file, _remove_duplicate_names, load_file
|
||||
from typing import List
|
||||
|
||||
|
||||
def convert_file(pt_file: Path, sf_file: Path):
|
||||
"""
|
||||
Convert a pytorch file to a safetensors file
|
||||
This will remove duplicate tensors from the file.
|
||||
|
||||
Unfortunately, this might not respect *transformers* convention.
|
||||
Forcing us to check for potentially different keys during load when looking
|
||||
for specific tensors (making tensor sharing explicit).
|
||||
"""
|
||||
logger.info(f"Convert {pt_file} to {sf_file}.")
|
||||
loaded = torch.load(pt_file, map_location="cpu")
|
||||
if "state_dict" in loaded:
|
||||
loaded = loaded["state_dict"]
|
||||
to_removes = _remove_duplicate_names(loaded)
|
||||
|
||||
pt_state = torch.load(pt_file, map_location="cpu")
|
||||
if "state_dict" in pt_state:
|
||||
pt_state = pt_state["state_dict"]
|
||||
metadata = {"format": "pt"}
|
||||
for kept_name, to_remove_group in to_removes.items():
|
||||
for to_remove in to_remove_group:
|
||||
if to_remove not in metadata:
|
||||
metadata[to_remove] = kept_name
|
||||
del loaded[to_remove]
|
||||
# Force tensors to be contiguous
|
||||
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
||||
|
||||
remove_shared_pointers(pt_state)
|
||||
|
||||
# Tensors need to be contiguous
|
||||
pt_state = {k: v.contiguous() for k, v in pt_state.items()}
|
||||
|
||||
sf_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
save_file(pt_state, str(sf_file), metadata={"format": "pt"})
|
||||
|
||||
# Check that both files are close in size
|
||||
check_file_size(pt_file, sf_file)
|
||||
|
||||
# Load safetensors state
|
||||
for k in pt_state:
|
||||
pt_tensor = pt_state[k]
|
||||
with safe_open(sf_file, framework="pt") as f:
|
||||
sf_tensor = f.get_tensor(k)
|
||||
if not torch.equal(pt_tensor, sf_tensor):
|
||||
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
dirname = os.path.dirname(sf_file)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
save_file(loaded, sf_file, metadata=metadata)
|
||||
reloaded = load_file(sf_file)
|
||||
for k in loaded:
|
||||
pt_tensor = loaded[k]
|
||||
sf_tensor = reloaded[k]
|
||||
if not torch.equal(pt_tensor, sf_tensor):
|
||||
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
|
||||
|
||||
def convert_files(pt_files: List[Path], sf_files: List[Path]):
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, Dict, Optional
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
class Weights:
|
||||
def __init__(self, filenames: List[Path], device, dtype, process_group):
|
||||
def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
|
||||
routing = {}
|
||||
for filename in filenames:
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
|
@ -14,6 +14,9 @@ class Weights:
|
|||
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
|
||||
)
|
||||
routing[k] = filename
|
||||
if aliases is None:
|
||||
aliases = {}
|
||||
self.aliases = aliases
|
||||
self.routing = routing
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
@ -27,14 +30,19 @@ class Weights:
|
|||
|
||||
return self._handles[filename]
|
||||
|
||||
def get_filename(self, tensor_name: str) -> str:
|
||||
def get_filename(self, tensor_name: str) -> (str, str):
|
||||
filename = self.routing.get(tensor_name, None)
|
||||
if filename is None:
|
||||
aliases = self.aliases.get(tensor_name, [])
|
||||
for alias in aliases:
|
||||
filename = self.routing.get(alias, None)
|
||||
if filename is not None:
|
||||
return str(filename), alias
|
||||
raise RuntimeError(f"weight {tensor_name} does not exist")
|
||||
return str(filename)
|
||||
return str(filename), tensor_name
|
||||
|
||||
def _get_slice(self, tensor_name: str):
|
||||
filename = self.get_filename(tensor_name)
|
||||
filename, tensor_name= self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
slice_ = f.get_slice(tensor_name)
|
||||
return slice_
|
||||
|
@ -43,7 +51,7 @@ class Weights:
|
|||
return self._get_slice(tensor_name).get_shape()
|
||||
|
||||
def get_tensor(self, tensor_name: str):
|
||||
filename = self.get_filename(tensor_name)
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
f = self._get_handle(filename)
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
tensor = tensor.to(dtype=self.dtype)
|
||||
|
@ -51,7 +59,7 @@ class Weights:
|
|||
return tensor
|
||||
|
||||
def get_sharded(self, tensor_name: str, dim: int):
|
||||
filename = self.get_filename(tensor_name)
|
||||
filename, tensor_name = self.get_filename(tensor_name)
|
||||
world_size = self.process_group.size()
|
||||
rank = self.process_group.rank()
|
||||
|
||||
|
|
Loading…
Reference in New Issue