import datetime
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional

import torch
from loguru import logger
from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete


def _remove_duplicate_names(
    state_dict: Dict[str, torch.Tensor],
    *,
    preferred_names: Optional[List[str]] = None,
    discard_names: Optional[List[str]] = None,
) -> Dict[str, List[str]]:
    """
    Decide, for every group of names sharing a tensor, which name to keep.

    Groups come from ``_find_shared_tensors(state_dict)``. Within a group only
    names whose tensor passes ``_is_complete`` are candidates. The kept name is
    chosen with increasing priority:

    1. the alphabetically first candidate,
    2. the alphabetically first candidate that is not in ``discard_names``,
    3. the alphabetically first candidate that is in ``preferred_names``.

    Note: when a group holds a single name whose tensor is not complete, the
    entry in ``state_dict`` is replaced in place by a ``clone()`` of it.

    Args:
        state_dict: mapping of tensor name to tensor; may be mutated (see note).
        preferred_names: names to keep in priority when they are candidates.
        discard_names: names to avoid keeping when another candidate exists.

    Returns:
        Mapping of kept name to the sorted list of other names in its group,
        i.e. the names that should be removed before saving.

    Raises:
        RuntimeError: if a group of several names has no complete candidate.
    """
    preferred_set = set(preferred_names) if preferred_names is not None else set()
    discard_set = set(discard_names) if discard_names is not None else set()

    shareds = _find_shared_tensors(state_dict)
    to_remove = defaultdict(list)
    for shared in shareds:
        complete_names = {name for name in shared if _is_complete(state_dict[name])}
        if not complete_names:
            if len(shared) == 1:
                # Force contiguous
                name = next(iter(shared))
                state_dict[name] = state_dict[name].clone()
                complete_names = {name}
            else:
                raise RuntimeError(
                    f"Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {shared}. "
                    "None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. "
                    "Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
                )

        # Default choice: alphabetically first complete name.
        keep_name = min(complete_names)

        # Mechanism to preferentially select keys to keep
        # coming from the on-disk file to allow
        # loading models saved with a different choice
        # of keep_name
        preferred = complete_names.difference(discard_set)
        if preferred:
            keep_name = min(preferred)

        # Explicitly preferred names win over the discard-based choice.
        if preferred_set:
            preferred = preferred_set.intersection(complete_names)
            if preferred:
                keep_name = min(preferred)

        for name in sorted(shared):
            if name != keep_name:
                to_remove[keep_name].append(name)
    return to_remove


def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
    """
    Convert a pytorch file to a safetensors file
    This will remove duplicate tensors from the file.

    Unfortunately, this might not respect *transformers* convention.
    Forcing us to check for potentially different keys during load when looking
    for specific tensors (making tensor sharing explicit).

    Each removed name is recorded in the safetensors metadata as
    ``removed_name -> kept_name``, next to ``"format": "pt"``. After saving,
    the file is loaded back and every tensor is compared with the one written.

    Args:
        pt_file: path of the pytorch checkpoint to read.
        sf_file: path of the safetensors file to write; its parent directory
            is created when missing.
        discard_names: names to avoid keeping among duplicated tensors.

    Raises:
        RuntimeError: if a reloaded tensor differs from the one that was saved,
            or if no suitable name can be kept for a group of shared tensors.
    """
    loaded = torch.load(pt_file, map_location="cpu", weights_only=True)
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)

    metadata = {"format": "pt"}
    for kept_name, to_remove_group in to_removes.items():
        for to_remove in to_remove_group:
            if to_remove not in metadata:
                metadata[to_remove] = kept_name
            del loaded[to_remove]
    # Force tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_file)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError; the current directory needs no creation anyway.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_file, metadata=metadata)

    # Round-trip check: what was written must read back identically.
    reloaded = load_file(sf_file)
    for k, pt_tensor in loaded.items():
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")


def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]):
    """
    Convert several pytorch files to safetensors files, logging the progress.

    ``pt_files[i]`` is converted to ``sf_files[i]``. Files whose name contains
    "arguments", "args" or "training" are skipped. Skipped files still count
    in the ``[i/N]`` progress counter of the log line.

    Args:
        pt_files: paths of the pytorch checkpoints to read.
        sf_files: paths of the safetensors files to write, same length as
            ``pt_files``.
        discard_names: forwarded to ``convert_file`` for every file.
    """
    assert len(pt_files) == len(sf_files)

    N = len(pt_files)
    # We do this instead of using tqdm because we want to parse the logs with the launcher

    for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
        # Skip blacklisted files
        if (
            "arguments" in pt_file.name
            or "args" in pt_file.name
            or "training" in pt_file.name
        ):
            continue

        start = datetime.datetime.now()
        convert_file(pt_file, sf_file, discard_names)
        elapsed = datetime.datetime.now() - start
        logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}")