feat(server): Update convert logic. (#483)

Should be more robust to shared tensors (OK when using
`from_pretrained`), but this forces us to add new checks in our loading
code, since the key chosen to keep may differ from the one
`transformers` expects.

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-41-161.ec2.internal>
Nicolas Patry 2023-06-23 12:40:46 +02:00 committed by GitHub
parent c9c65ab323
commit 49b4b33e80
3 changed files with 46 additions and 68 deletions
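
For context, a minimal sketch (not part of the commit; tensor names are illustrative) of why shared tensors force this change: tied input embedding and output head give two state-dict keys backed by one storage, and recent safetensors versions refuse to serialize that, so the converter has to drop one of the names.

# Illustration only: a tied embedding / LM head sharing one storage.
import torch
from safetensors.torch import save_file

wte = torch.zeros(10, 4)
state_dict = {
    "transformer.wte.weight": wte,
    "lm_head.weight": wte,  # tied: same underlying memory as wte
}

try:
    # safetensors rejects shared storage at save time, which is why the
    # converter below must deduplicate names before writing the file.
    save_file(state_dict, "tied.safetensors")
except RuntimeError as err:
    print(f"save_file refused the shared tensors: {err}")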


@@ -52,7 +52,8 @@ class FlashSantacoderSharded(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(
-            filenames, device=device, dtype=dtype, process_group=self.process_group
+            filenames, device=device, dtype=dtype, process_group=self.process_group,
+            aliases = {"transformer.wte.weight": ["lm_head.weight"]}
         )

         model = FlashSantacoderForCausalLM(config, weights)
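
A side note on the direction of the mapping (my reading of the lookup added in weights.py further down): the dict key is the name the model code asks for, and the list holds the names the tensor may actually carry in the converted file.

# Sketch of the alias direction only; names mirror the hunk above.
aliases = {"transformer.wte.weight": ["lm_head.weight"]}

requested = "transformer.wte.weight"
available = {"lm_head.weight"}  # keys actually present in the .safetensors file
resolved = requested if requested in available else next(
    (a for a in aliases.get(requested, []) if a in available), None
)
print(resolved)  # -> "lm_head.weight"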


@@ -1,76 +1,45 @@
 import datetime
 import torch
+import os

-from collections import defaultdict
 from loguru import logger
 from pathlib import Path
-from safetensors.torch import save_file
-from safetensors import safe_open
-from typing import Dict, List
-
-
-def check_file_size(source_file: Path, target_file: Path):
-    """
-    Check that two files are close in size
-    """
-    source_file_size = source_file.stat().st_size
-    target_file_size = target_file.stat().st_size
-
-    if (source_file_size - target_file_size) / source_file_size > 0.05:
-        raise RuntimeError(
-            f"""The file size different is more than 5%:
-         - {source_file}: {source_file_size}
-         - {target_file}: {target_file_size}
-         """
-        )
-
-
-def remove_shared_pointers(tensors: Dict[str, torch.Tensor]):
-    """
-    For a Dict of tensors, check if two or more tensors point to the same underlying memory and
-    remove them
-    """
-    ptrs = defaultdict(list)
-    for k, v in tensors.items():
-        ptrs[v.data_ptr()].append(k)
-
-    # Iterate over all found memory addresses
-    for ptr, names in ptrs.items():
-        if len(names) > 1:
-            # Multiple tensors are point to the same memory
-            # Only keep the first tensor
-            for name in names[1:]:
-                tensors.pop(name)
+from safetensors.torch import save_file, _remove_duplicate_names, load_file
+from typing import List


 def convert_file(pt_file: Path, sf_file: Path):
     """
     Convert a pytorch file to a safetensors file
+    This will remove duplicate tensors from the file.
+
+    Unfortunately, this might not respect *transformers* convention.
+    Forcing us to check for potentially different keys during load when looking
+    for specific tensors (making tensor sharing explicit).
     """
-    logger.info(f"Convert {pt_file} to {sf_file}.")
-
-    pt_state = torch.load(pt_file, map_location="cpu")
-    if "state_dict" in pt_state:
-        pt_state = pt_state["state_dict"]
-
-    remove_shared_pointers(pt_state)
-
-    # Tensors need to be contiguous
-    pt_state = {k: v.contiguous() for k, v in pt_state.items()}
-
-    sf_file.parent.mkdir(parents=True, exist_ok=True)
-    save_file(pt_state, str(sf_file), metadata={"format": "pt"})
-
-    # Check that both files are close in size
-    check_file_size(pt_file, sf_file)
-
-    # Load safetensors state
-    for k in pt_state:
-        pt_tensor = pt_state[k]
-        with safe_open(sf_file, framework="pt") as f:
-            sf_tensor = f.get_tensor(k)
-            if not torch.equal(pt_tensor, sf_tensor):
-                raise RuntimeError(f"The output tensors do not match for key {k}")
+    loaded = torch.load(pt_file, map_location="cpu")
+    if "state_dict" in loaded:
+        loaded = loaded["state_dict"]
+    to_removes = _remove_duplicate_names(loaded)
+
+    metadata = {"format": "pt"}
+    for kept_name, to_remove_group in to_removes.items():
+        for to_remove in to_remove_group:
+            if to_remove not in metadata:
+                metadata[to_remove] = kept_name
+            del loaded[to_remove]
+    # Force tensors to be contiguous
+    loaded = {k: v.contiguous() for k, v in loaded.items()}
+
+    dirname = os.path.dirname(sf_file)
+    os.makedirs(dirname, exist_ok=True)
+    save_file(loaded, sf_file, metadata=metadata)
+    reloaded = load_file(sf_file)
+    for k in loaded:
+        pt_tensor = loaded[k]
+        sf_tensor = reloaded[k]
+        if not torch.equal(pt_tensor, sf_tensor):
+            raise RuntimeError(f"The output tensors do not match for key {k}")


 def convert_files(pt_files: List[Path], sf_files: List[Path]):
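
The deduplication itself is delegated to `_remove_duplicate_names`, a private helper in `safetensors.torch`. Roughly, it groups tensors by underlying storage and keeps a single name per group; a simplified approximation is sketched below (not the real implementation, which applies extra rules when picking the surviving key).

# Simplified stand-in for duplicate-name detection: group state-dict entries
# by data pointer and mark all but one name in each group for removal.
from collections import defaultdict
from typing import Dict, List
import torch

def find_duplicate_names(state_dict: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
    groups = defaultdict(list)
    for name, tensor in state_dict.items():
        groups[tensor.data_ptr()].append(name)
    # Keep the first name of each shared group; the rest get dropped and
    # recorded in the safetensors metadata so loaders can find them again.
    return {names[0]: names[1:] for names in groups.values() if len(names) > 1}

shared = torch.zeros(3, 3)
print(find_duplicate_names(
    {"transformer.wte.weight": shared, "lm_head.weight": shared, "bias": torch.ones(2)}
))
# -> {'transformer.wte.weight': ['lm_head.weight']}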


@@ -1,10 +1,10 @@
 from pathlib import Path
-from typing import List
+from typing import List, Dict, Optional
 from safetensors import safe_open


 class Weights:
-    def __init__(self, filenames: List[Path], device, dtype, process_group):
+    def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]] = None):
         routing = {}
         for filename in filenames:
             with safe_open(filename, framework="pytorch") as f:
@@ -14,6 +14,9 @@ class Weights:
                         f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                     )
                 routing[k] = filename
+        if aliases is None:
+            aliases = {}
+        self.aliases = aliases
         self.routing = routing
         self.device = device
         self.dtype = dtype
@@ -27,14 +30,19 @@ class Weights:
         return self._handles[filename]

-    def get_filename(self, tensor_name: str) -> str:
+    def get_filename(self, tensor_name: str) -> (str, str):
         filename = self.routing.get(tensor_name, None)
         if filename is None:
+            aliases = self.aliases.get(tensor_name, [])
+            for alias in aliases:
+                filename = self.routing.get(alias, None)
+                if filename is not None:
+                    return str(filename), alias
             raise RuntimeError(f"weight {tensor_name} does not exist")
-        return str(filename)
+        return str(filename), tensor_name

     def _get_slice(self, tensor_name: str):
-        filename = self.get_filename(tensor_name)
+        filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
         return slice_

@@ -43,7 +51,7 @@ class Weights:
         return self._get_slice(tensor_name).get_shape()

     def get_tensor(self, tensor_name: str):
-        filename = self.get_filename(tensor_name)
+        filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
         tensor = tensor.to(dtype=self.dtype)
@@ -51,7 +59,7 @@ class Weights:
         return tensor

     def get_sharded(self, tensor_name: str, dim: int):
-        filename = self.get_filename(tensor_name)
+        filename, tensor_name = self.get_filename(tensor_name)
         world_size = self.process_group.size()
         rank = self.process_group.rank()
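
To make the new lookup concrete, here is a self-contained sketch (temporary file, illustrative tensor names, not code from the repo) that mirrors what `get_filename` now does: build the routing table from the keys actually present on disk, then retry a miss through the alias list before raising.

# Standalone sketch of the alias-fallback lookup introduced above.
import os
import tempfile
import torch
from safetensors.torch import save_file
from safetensors import safe_open

tmpdir = tempfile.mkdtemp()
path = os.path.join(tmpdir, "model.safetensors")
# Pretend conversion kept only "lm_head.weight" out of a tied pair.
save_file({"lm_head.weight": torch.zeros(4, 2)}, path)

routing = {}
with safe_open(path, framework="pt") as f:
    for k in f.keys():
        routing[k] = path

aliases = {"transformer.wte.weight": ["lm_head.weight"]}

def get_filename(tensor_name):
    filename = routing.get(tensor_name)
    if filename is None:
        # Miss: try every alias registered for the requested name.
        for alias in aliases.get(tensor_name, []):
            filename = routing.get(alias)
            if filename is not None:
                return filename, alias
        raise RuntimeError(f"weight {tensor_name} does not exist")
    return filename, tensor_name

print(get_filename("transformer.wte.weight"))  # resolves via the alias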