import datetime
import torch
import os
from loguru import logger
from pathlib import Path
from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete
from typing import Dict, List, Optional
from collections import defaultdict

def _remove_duplicate_names(
    state_dict: Dict[str, torch.Tensor],
    *,
    preferred_names: Optional[List[str]] = None,
    discard_names: Optional[List[str]] = None,
) -> Dict[str, List[str]]:
    if preferred_names is None:
        preferred_names = []
    preferred_names = set(preferred_names)
    if discard_names is None:
        discard_names = []
    discard_names = set(discard_names)

    shareds = _find_shared_tensors(state_dict)
    to_remove = defaultdict(list)
    for shared in shareds:
        # Only tensors that cover their entire underlying storage can be
        # saved and reloaded without losing data.
        complete_names = set(
            name for name in shared if _is_complete(state_dict[name])
        )
        if not complete_names:
            raise RuntimeError(
                "Error while trying to find names to remove to save state dict, but found no suitable name to keep"
                f" for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model"
                " since you could be storing much more memory than needed. Please refer to"
                " https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue."
            )

        keep_name = sorted(complete_names)[0]

        # Mechanism to preferentially select keys to keep
        # coming from the on-disk file to allow
        # loading models saved with a different choice
        # of keep_name
        preferred = complete_names.difference(discard_names)
        if preferred:
            keep_name = sorted(preferred)[0]

        if preferred_names:
            preferred = preferred_names.intersection(complete_names)
            if preferred:
                keep_name = sorted(preferred)[0]
        for name in sorted(shared):
            if name != keep_name:
                to_remove[keep_name].append(name)
    return to_remove
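
# A minimal behaviour sketch, assuming a hypothetical model whose embedding
# and LM-head weights are tied to the same storage:
#
#     weight = torch.zeros((1024, 1024))
#     state_dict = {"embed.weight": weight, "lm_head.weight": weight}
#     _remove_duplicate_names(state_dict, discard_names=["lm_head.weight"])
#     # -> {"embed.weight": ["lm_head.weight"]}
#
# Only one complete alias per shared storage is kept; the dropped names are
# recorded so the tie can be re-created at load time.
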
def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
"""
Convert a pytorch file to a safetensors file
2023-06-23 04:40:46 -06:00
This will remove duplicate tensors from the file .
2023-02-14 05:02:16 -07:00
2023-06-23 04:40:46 -06:00
Unfortunately , this might not respect * transformers * convention .
Forcing us to check for potentially different keys during load when looking
for specific tensors ( making tensor sharing explicit ) .
"""
    loaded = torch.load(pt_file, map_location="cpu")
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]
    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)

    metadata = {"format": "pt"}
    for kept_name, to_remove_group in to_removes.items():
        for to_remove in to_remove_group:
            if to_remove not in metadata:
                # Record which kept tensor each removed alias shared storage
                # with, so loaders can restore the tie explicitly.
                metadata[to_remove] = kept_name
            del loaded[to_remove]
    # Force tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_file)
    os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_file, metadata=metadata)

    # Sanity check: reload the file and verify that every tensor
    # round-trips exactly.
    reloaded = load_file(sf_file)
    for k in loaded:
        pt_tensor = loaded[k]
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")
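
# Usage sketch (paths and discard names are hypothetical):
#
#     convert_file(
#         Path("model/pytorch_model.bin"),
#         Path("model/model.safetensors"),
#         discard_names=["lm_head.weight"],
#     )
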
def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str]):
    assert len(pt_files) == len(sf_files)
    N = len(pt_files)
    # We do this instead of using tqdm because we want to parse the logs with the launcher
    for i, (pt_file, sf_file) in enumerate(zip(pt_files, sf_files)):
        # Skip blacklisted files (training arguments, not model weights)
        if (
            "arguments" in pt_file.name
            or "args" in pt_file.name
            or "training" in pt_file.name
        ):
            continue

        start = datetime.datetime.now()
        convert_file(pt_file, sf_file, discard_names)
        elapsed = datetime.datetime.now() - start
        logger.info(f"Convert: [{i + 1}/{N}] -- Took: {elapsed}")
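
if __name__ == "__main__":
    # Minimal driver sketch, not part of the original API: convert every
    # "*.bin" checkpoint in a directory given on the command line. The glob
    # pattern and the empty discard list are assumptions; adjust both for
    # the actual model layout.
    import sys

    model_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path(".")
    pt_files = sorted(model_dir.resolve().glob("*.bin"))
    sf_files = [p.with_suffix(".safetensors") for p in pt_files]
    convert_files(pt_files, sf_files, discard_names=[])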