hf_text-generation-inference/server/tests/utils/test_convert.py

18 lines
610 B
Python

from text_generation.utils.hub import download_weights, weight_hub_files, weight_files
from text_generation.utils.convert import convert_files
def test_convert_files():
model_id = "bigscience/bloom-560m"
pt_filenames = weight_hub_files(model_id, extension=".bin")
local_pt_files = download_weights(pt_filenames, model_id)
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
]
convert_files(local_pt_files, local_st_files)
found_st_files = weight_files(model_id)
assert all([p in found_st_files for p in local_st_files])