fix(server): harden the weights choice to save on disk. (#561)
- Look at the `transformers` base class for `_keys_to_ignore_on_load_missing` or `_tied_weights_keys`, the standard attributes that select the keys NOT to save on disk (since they are ignored on load). - Modified the safetensors code (to be reflected upstream in safetensors, even though it is an internal function). - Will not work for trust_remote_code=True repos (like santacoder). Should help with: https://github.com/huggingface/text-generation-inference/issues/555, https://github.com/huggingface/text-generation-inference/pull/501, https://github.com/huggingface/text-generation-inference/issues/556 and https://github.com/huggingface/text-generation-inference/issues/482#issuecomment-1623713593
This commit is contained in:
parent
31b36cca21
commit
e943a294bc
|
@ -14,7 +14,7 @@ def test_convert_files():
|
|||
local_st_files = [
|
||||
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
|
||||
]
|
||||
convert_files(local_pt_files, local_st_files)
|
||||
convert_files(local_pt_files, local_st_files, discard_names=[])
|
||||
|
||||
found_st_files = weight_files(model_id)
|
||||
|
||||
|
|
|
@ -160,8 +160,26 @@ def download_weights(
|
|||
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
|
||||
for p in local_pt_files
|
||||
]
|
||||
try:
|
||||
from transformers import AutoConfig
|
||||
import transformers
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
)
|
||||
architecture = config.architectures[0]
|
||||
|
||||
class_ = getattr(transformers, architecture)
|
||||
|
||||
# Name for this variable depends on transformers version.
|
||||
discard_names = getattr(class_, "_tied_weights_keys", [])
|
||||
discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))
|
||||
|
||||
except Exception as e:
|
||||
discard_names = []
|
||||
# Convert pytorch weights to safetensors
|
||||
utils.convert_files(local_pt_files, local_st_files)
|
||||
utils.convert_files(local_pt_files, local_st_files, discard_names)
|
||||
|
||||
|
||||
@app.command()
|
||||
|
|
|
@ -4,11 +4,56 @@ import os
|
|||
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
from safetensors.torch import save_file, _remove_duplicate_names, load_file
|
||||
from typing import List
|
||||
from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete
|
||||
from typing import List, Dict
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def convert_file(pt_file: Path, sf_file: Path):
|
||||
def _remove_duplicate_names(
    state_dict: Dict[str, torch.Tensor],
    *,
    preferred_names: List[str] = None,
    discard_names: List[str] = None,
) -> Dict[str, List[str]]:
    """For each group of tensors sharing storage, pick one name to keep.

    Shared (tied) tensors cannot all be saved to safetensors; this selects,
    per shared group, the single name whose tensor will be written, and
    returns the names to drop.

    Args:
        state_dict: Mapping of tensor names to tensors (may contain aliases
            of the same underlying storage).
        preferred_names: Names to favor as the kept name when possible
            (e.g. names already present in an on-disk file). ``None`` means
            no preference.
        discard_names: Names to avoid keeping (e.g. keys the model class
            ignores on load, such as tied weights). ``None`` means none.

    Returns:
        Mapping of each kept name to the list of alias names to remove
        from the state dict before saving.

    Raises:
        RuntimeError: If a shared group contains no tensor covering the
            entire underlying storage, so no name can safely represent it.
    """
    # Normalize optional list arguments into sets for fast membership tests.
    if preferred_names is None:
        preferred_names = []
    preferred_names = set(preferred_names)
    if discard_names is None:
        discard_names = []
    discard_names = set(discard_names)

    shareds = _find_shared_tensors(state_dict)
    to_remove = defaultdict(list)
    for shared in shareds:
        # Only a tensor covering the whole storage can stand in for the group.
        complete_names = {name for name in shared if _is_complete(state_dict[name])}
        if not complete_names:
            raise RuntimeError(
                f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
            )

        # Deterministic default: lexicographically first complete name.
        keep_name = sorted(complete_names)[0]

        # Mechanism to preferentially select keys to keep
        # coming from the on-disk file to allow
        # loading models saved with a different choice
        # of keep_name
        preferred = complete_names.difference(discard_names)
        if preferred:
            keep_name = sorted(preferred)[0]

        # An explicit preference (e.g. on-disk names) overrides the default.
        if preferred_names:
            preferred = preferred_names.intersection(complete_names)
            if preferred:
                keep_name = sorted(preferred)[0]
        for name in sorted(shared):
            if name != keep_name:
                to_remove[keep_name].append(name)
    return to_remove
|
||||
|
||||
|
||||
def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
|
||||
"""
|
||||
Convert a pytorch file to a safetensors file
|
||||
This will remove duplicate tensors from the file.
|
||||
|
@ -20,7 +65,7 @@ def convert_file(pt_file: Path, sf_file: Path):
|
|||
loaded = torch.load(pt_file, map_location="cpu")
|
||||
if "state_dict" in loaded:
|
||||
loaded = loaded["state_dict"]
|
||||
to_removes = _remove_duplicate_names(loaded)
|
||||
to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)
|
||||
|
||||
metadata = {"format": "pt"}
|
||||
for kept_name, to_remove_group in to_removes.items():
|
||||
|
@ -42,7 +87,7 @@ def convert_file(pt_file: Path, sf_file: Path):
|
|||
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
|
||||
|
||||
def convert_files(pt_files: List[Path], sf_files: List[Path]):
|
||||
def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]):
|
||||
assert len(pt_files) == len(sf_files)
|
||||
|
||||
N = len(pt_files)
|
||||
|
@ -50,6 +95,6 @@ def convert_files(pt_files: List[Path], sf_files: List[Path]):
|
|||
|
||||
for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
|
||||
start = datetime.datetime.now()
|
||||
convert_file(pt_file, sf_file)
|
||||
convert_file(pt_file, sf_file, discard_names)
|
||||
elapsed = datetime.datetime.now() - start
|
||||
logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}")
|
||||
|
|
Loading…
Reference in New Issue