[Safetensors] Make sure metadata is saved (#2506)
* [Safetensors] Make sure metadata is saved * make style
This commit is contained in:
parent
7f43f65235
commit
0e975e5ff6
|
@ -291,9 +291,6 @@ class ModelMixin(torch.nn.Module):
|
||||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||||
return
|
return
|
||||||
|
|
||||||
if save_function is None:
|
|
||||||
save_function = safetensors.torch.save_file if safe_serialization else torch.save
|
|
||||||
|
|
||||||
os.makedirs(save_directory, exist_ok=True)
|
os.makedirs(save_directory, exist_ok=True)
|
||||||
|
|
||||||
model_to_save = self
|
model_to_save = self
|
||||||
|
@ -310,7 +307,12 @@ class ModelMixin(torch.nn.Module):
|
||||||
weights_name = _add_variant(weights_name, variant)
|
weights_name = _add_variant(weights_name, variant)
|
||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
save_function(state_dict, os.path.join(save_directory, weights_name))
|
if safe_serialization:
|
||||||
|
safetensors.torch.save_file(
|
||||||
|
state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
torch.save(state_dict, os.path.join(save_directory, weights_name))
|
||||||
|
|
||||||
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue