fix(server): harden the choice of which weights to save on disk. (#561)

- Look at the `transformers` model class to check for
  `_tied_weights_keys` or `_keys_to_ignore_on_load_missing`, which are the
  standard attributes used to select the keys NOT to save on disk (since
  they are ignored at load time anyway). A sketch of this lookup follows
  the list.

- Vendored and modified the safetensors `_remove_duplicate_names` helper
  (the change should also be reflected upstream in safetensors, even
  though it is an internal function). A toy example of the new behavior
  follows the links below.
  
- Will not work for `trust_remote_code=True` repos (like santacoder),
  since their model class cannot be looked up inside `transformers`
  itself.
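
For reference, a minimal sketch of what that class lookup amounts to (it
mirrors the logic added to `download_weights` below; `model_id` and
`revision` are assumed to already be defined, and the example values in the
comments are only illustrative):

    from transformers import AutoConfig
    import transformers

    config = AutoConfig.from_pretrained(model_id, revision=revision)
    architecture = config.architectures[0]        # e.g. "GPT2LMHeadModel"
    class_ = getattr(transformers, architecture)  # raises for remote-code models

    # Either attribute may be missing depending on the transformers version;
    # copying the list avoids mutating the class attribute in place.
    discard_names = list(getattr(class_, "_tied_weights_keys", []) or [])
    discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []) or [])
    # Typically ends up containing something like ["lm_head.weight"] for
    # models with tied input/output embeddings.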

Should help with:
https://github.com/huggingface/text-generation-inference/issues/555
https://github.com/huggingface/text-generation-inference/pull/501
https://github.com/huggingface/text-generation-inference/issues/556
https://github.com/huggingface/text-generation-inference/issues/482#issuecomment-1623713593
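
To make the selection mechanism concrete, here is a toy sketch of how the
vendored `_remove_duplicate_names` (defined in the convert changes below)
behaves on a state dict with tied weights; the tensor names are invented
for the example and the helper is assumed to be in scope:

    import torch

    shared = torch.zeros((10, 4))
    # Two names pointing at the same storage, as with tied embeddings.
    state_dict = {"wte.weight": shared, "lm_head.weight": shared}

    # Without discard names the alphabetically first complete name wins,
    # so "lm_head.weight" is kept and "wte.weight" is dropped on save.
    _remove_duplicate_names(state_dict, discard_names=[])
    # -> {"lm_head.weight": ["wte.weight"]}

    # With the names gathered from the transformers class, the tied head
    # is discarded instead: the model re-ties it at load time, while the
    # embedding weight has to actually be present in the file.
    _remove_duplicate_names(state_dict, discard_names=["lm_head.weight"])
    # -> {"wte.weight": ["lm_head.weight"]}
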
Nicolas Patry 2023-07-07 14:50:12 +02:00 committed by GitHub
parent 31b36cca21
commit e943a294bc
3 changed files with 71 additions and 8 deletions

View File

@@ -14,7 +14,7 @@ def test_convert_files():
     local_st_files = [
         p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
     ]
-    convert_files(local_pt_files, local_st_files)
+    convert_files(local_pt_files, local_st_files, discard_names=[])

     found_st_files = weight_files(model_id)

View File

@@ -160,8 +160,26 @@ def download_weights(
         p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
         for p in local_pt_files
     ]
+    try:
+        from transformers import AutoConfig
+        import transformers
+
+        config = AutoConfig.from_pretrained(
+            model_id,
+            revision=revision,
+        )
+        architecture = config.architectures[0]
+
+        class_ = getattr(transformers, architecture)
+        # Name for this variable depends on the transformers version.
+        discard_names = getattr(class_, "_tied_weights_keys", [])
+        discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))
+    except Exception as e:
+        discard_names = []
     # Convert pytorch weights to safetensors
-    utils.convert_files(local_pt_files, local_st_files)
+    utils.convert_files(local_pt_files, local_st_files, discard_names)

 @app.command()

View File

@@ -4,11 +4,56 @@ import os
 from loguru import logger
 from pathlib import Path
-from safetensors.torch import save_file, _remove_duplicate_names, load_file
-from typing import List
+from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete
+from typing import List, Dict
+from collections import defaultdict


-def convert_file(pt_file: Path, sf_file: Path):
+def _remove_duplicate_names(
+    state_dict: Dict[str, torch.Tensor],
+    *,
+    preferred_names: List[str] = None,
+    discard_names: List[str] = None,
+) -> Dict[str, List[str]]:
+    if preferred_names is None:
+        preferred_names = []
+    preferred_names = set(preferred_names)
+    if discard_names is None:
+        discard_names = []
+    discard_names = set(discard_names)
+
+    shareds = _find_shared_tensors(state_dict)
+    to_remove = defaultdict(list)
+    for shared in shareds:
+        complete_names = set(
+            [name for name in shared if _is_complete(state_dict[name])]
+        )
+        if not complete_names:
+            raise RuntimeError(
+                f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
+            )
+
+        keep_name = sorted(list(complete_names))[0]
+
+        # Mechanism to preferentially select keys to keep
+        # coming from the on-disk file to allow
+        # loading models saved with a different choice
+        # of keep_name
+        preferred = complete_names.difference(discard_names)
+        if preferred:
+            keep_name = sorted(list(preferred))[0]
+
+        if preferred_names:
+            preferred = preferred_names.intersection(complete_names)
+            if preferred:
+                keep_name = sorted(list(preferred))[0]
+        for name in sorted(shared):
+            if name != keep_name:
+                to_remove[keep_name].append(name)
+    return to_remove
+
+
+def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
     """
     Convert a pytorch file to a safetensors file
     This will remove duplicate tensors from the file.
@@ -20,7 +65,7 @@ def convert_file(pt_file: Path, sf_file: Path):
     loaded = torch.load(pt_file, map_location="cpu")
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
-    to_removes = _remove_duplicate_names(loaded)
+    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)

     metadata = {"format": "pt"}
     for kept_name, to_remove_group in to_removes.items():
@@ -42,7 +87,7 @@ def convert_file(pt_file: Path, sf_file: Path):
             raise RuntimeError(f"The output tensors do not match for key {k}")


-def convert_files(pt_files: List[Path], sf_files: List[Path]):
+def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]):
     assert len(pt_files) == len(sf_files)

     N = len(pt_files)
@@ -50,6 +95,6 @@ def convert_files(pt_files: List[Path], sf_files: List[Path]):
     for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
         start = datetime.datetime.now()
-        convert_file(pt_file, sf_file)
+        convert_file(pt_file, sf_file, discard_names)
         elapsed = datetime.datetime.now() - start
         logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}")