2023-03-07 10:52:22 -07:00
|
|
|
from text_generation_server.utils.hub import (
|
|
|
|
download_weights,
|
|
|
|
weight_hub_files,
|
|
|
|
weight_files,
|
|
|
|
)
|
2023-02-14 05:02:16 -07:00
|
|
|
|
2023-03-07 10:52:22 -07:00
|
|
|
from text_generation_server.utils.convert import convert_files
|
2023-02-14 05:02:16 -07:00
|
|
|
|
|
|
|
|
|
|
|
def test_convert_files():
    """Check that converted safetensors files are picked up by ``weight_files``.

    Downloads the ``.bin`` weight files of ``bigscience/bloom-560m``, converts
    each one with ``convert_files`` to a sibling ``.safetensors`` path, and
    asserts that every converted path appears in ``weight_files(model_id)``.
    """
    model_id = "bigscience/bloom-560m"
    pt_prefix = "pytorch_"

    def _safetensors_path(pt_path):
        # Drop the literal "pytorch_" prefix only. ``str.lstrip("pytorch_")``
        # would instead strip any leading run of the characters
        # {p, y, t, o, r, c, h, _}, mangling stems whose remainder begins with
        # one of those letters (e.g. "pytorch_top" -> "").
        stem = pt_path.stem
        if stem.startswith(pt_prefix):
            stem = stem[len(pt_prefix):]
        return pt_path.parent / f"{stem}.safetensors"

    pt_filenames = weight_hub_files(model_id, extension=".bin")
    local_pt_files = download_weights(pt_filenames, model_id)
    local_st_files = [_safetensors_path(p) for p in local_pt_files]

    convert_files(local_pt_files, local_st_files, discard_names=[])

    found_st_files = weight_files(model_id)

    # Every file we just converted must be discoverable.
    assert all(p in found_st_files for p in local_st_files)
|