hf_text-generation-inference/server/text_generation_server/utils/convert.py

56 lines
1.9 KiB
Python
Raw Normal View History

import datetime
2023-02-14 05:02:16 -07:00
import torch
import os
2023-02-14 05:02:16 -07:00
from loguru import logger
from pathlib import Path
from safetensors.torch import save_file, _remove_duplicate_names, load_file
from typing import List
2023-02-14 05:02:16 -07:00
def convert_file(pt_file: Path, sf_file: Path):
    """
    Convert a pytorch file to a safetensors file
    This will remove duplicate tensors from the file.

    Unfortunately, this might not respect *transformers* convention.
    Forcing us to check for potentially different keys during load when looking
    for specific tensors (making tensor sharing explicit).

    Args:
        pt_file: Path of the PyTorch checkpoint to read.
        sf_file: Path of the safetensors file to write. Its parent directory
            is created when it does not exist yet.

    Raises:
        RuntimeError: If a tensor read back from ``sf_file`` is not equal to
            the tensor that was written.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only call this on checkpoints from a trusted source (or consider
    # `weights_only=True` if the minimum supported torch version allows it).
    loaded = torch.load(pt_file, map_location="cpu")
    # Some checkpoints wrap the weights under a top-level "state_dict" key.
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]

    # Maps each kept tensor name to the names of the duplicates to drop.
    to_removes = _remove_duplicate_names(loaded)

    # Record every dropped name -> kept name in the file metadata so that the
    # sharing can be recovered at load time.
    metadata = {"format": "pt"}
    for kept_name, to_remove_group in to_removes.items():
        for to_remove in to_remove_group:
            if to_remove not in metadata:
                metadata[to_remove] = kept_name
            del loaded[to_remove]

    # Force tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError; the current directory already exists in that case.
    dirname = os.path.dirname(sf_file)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_file, metadata=metadata)

    # Sanity check: reload the written file and compare every tensor.
    reloaded = load_file(sf_file)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")
2023-02-14 05:02:16 -07:00
def convert_files(pt_files: List[Path], sf_files: List[Path]):
    """
    Convert each pytorch file to its matching safetensors file.

    `pt_files` and `sf_files` are paired by position and must have the same
    length. One progress line with the elapsed time is logged per file.
    """
    assert len(pt_files) == len(sf_files)

    total = len(pt_files)
    pairs = zip(pt_files, sf_files)

    # We do this instead of using tqdm because we want to parse the logs with the launcher
    for position, (source, target) in enumerate(pairs, start=1):
        began_at = datetime.datetime.now()
        convert_file(source, target)
        elapsed = datetime.datetime.now() - began_at
        logger.info(f"Convert: [{position}/{total}] -- Took: {elapsed}")