safetensors optional for now
This commit is contained in:
parent
210cb4c128
commit
ac90cf38c6
|
@ -4,7 +4,6 @@ import sys
|
||||||
import gc
|
import gc
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file, save_file
|
|
||||||
import re
|
import re
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
@ -149,6 +148,10 @@ def torch_load(model_filename, model_info, map_override=None):
|
||||||
# safely load weights
|
# safely load weights
|
||||||
# TODO: safetensors supports zero copy fast load to gpu, see issue #684.
|
# TODO: safetensors supports zero copy fast load to gpu, see issue #684.
|
||||||
# GPU only for now, see https://github.com/huggingface/safetensors/issues/95
|
# GPU only for now, see https://github.com/huggingface/safetensors/issues/95
|
||||||
|
try:
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}")
|
||||||
return load_file(model_filename, device='cuda')
|
return load_file(model_filename, device='cuda')
|
||||||
else:
|
else:
|
||||||
return torch.load(model_filename, map_location=map_override)
|
return torch.load(model_filename, map_location=map_override)
|
||||||
|
@ -157,6 +160,10 @@ def torch_save(model, output_filename):
|
||||||
basename, exttype = os.path.splitext(output_filename)
|
basename, exttype = os.path.splitext(output_filename)
|
||||||
if(checkpoint_types[exttype] == 'safetensors'):
|
if(checkpoint_types[exttype] == 'safetensors'):
|
||||||
# [===== >] Reticulating brines...
|
# [===== >] Reticulating brines...
|
||||||
|
try:
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(f"Export as safetensors selected, yet it is not installed, use `pip install safetensors`: {e}")
|
||||||
save_file(model, output_filename, metadata={"format": "pt"})
|
save_file(model, output_filename, metadata={"format": "pt"})
|
||||||
else:
|
else:
|
||||||
torch.save(model, output_filename)
|
torch.save(model, output_filename)
|
||||||
|
|
|
@ -28,4 +28,3 @@ kornia
|
||||||
lark
|
lark
|
||||||
inflection
|
inflection
|
||||||
GitPython
|
GitPython
|
||||||
safetensors
|
|
||||||
|
|
Loading…
Reference in New Issue