is_safetensors_compatible refactor (#2499)
* is_safetensors_compatible refactor * files list comma
This commit is contained in:
parent
a75ac3fa8d
commit
856dad57bb
|
@ -129,21 +129,49 @@ class AudioPipelineOutput(BaseOutput):
|
|||
|
||||
|
||||
def is_safetensors_compatible(filenames, variant=None) -> bool:
|
||||
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
|
||||
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
|
||||
"""
|
||||
Checking for safetensors compatibility:
|
||||
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
|
||||
files to know which safetensors files are needed.
|
||||
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
|
||||
|
||||
for pt_filename in pt_filenames:
|
||||
_variant = f".{variant}" if (variant is not None and variant in pt_filename) else ""
|
||||
prefix, raw = os.path.split(pt_filename)
|
||||
if raw == f"pytorch_model{_variant}.bin":
|
||||
# transformers specific
|
||||
sf_filename = os.path.join(prefix, f"model{_variant}.safetensors")
|
||||
Converting default pytorch serialized filenames to safetensors serialized filenames:
|
||||
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
|
||||
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
|
||||
extension is replaced with ".safetensors"
|
||||
"""
|
||||
pt_filenames = []
|
||||
|
||||
sf_filenames = set()
|
||||
|
||||
for filename in filenames:
|
||||
_, extension = os.path.splitext(filename)
|
||||
|
||||
if extension == ".bin":
|
||||
pt_filenames.append(filename)
|
||||
elif extension == ".safetensors":
|
||||
sf_filenames.add(filename)
|
||||
|
||||
for filename in pt_filenames:
|
||||
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam'
|
||||
path, filename = os.path.split(filename)
|
||||
filename, extension = os.path.splitext(filename)
|
||||
|
||||
if filename == "pytorch_model":
|
||||
filename = "model"
|
||||
elif filename == f"pytorch_model.{variant}":
|
||||
filename = f"model.{variant}"
|
||||
else:
|
||||
sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
|
||||
if is_safetensors_compatible and sf_filename not in filenames:
|
||||
logger.warning(f"{sf_filename} not found")
|
||||
is_safetensors_compatible = False
|
||||
return is_safetensors_compatible
|
||||
filename = filename
|
||||
|
||||
expected_sf_filename = os.path.join(path, filename)
|
||||
expected_sf_filename = f"{expected_sf_filename}.safetensors"
|
||||
|
||||
if expected_sf_filename not in sf_filenames:
|
||||
logger.warning(f"{expected_sf_filename} not found")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]:
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
import unittest
|
||||
|
||||
from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
|
||||
|
||||
|
||||
class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||
def test_all_is_compatible(self):
|
||||
filenames = [
|
||||
"safety_checker/pytorch_model.bin",
|
||||
"safety_checker/model.safetensors",
|
||||
"vae/diffusion_pytorch_model.bin",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
"text_encoder/pytorch_model.bin",
|
||||
"text_encoder/model.safetensors",
|
||||
"unet/diffusion_pytorch_model.bin",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
self.assertTrue(is_safetensors_compatible(filenames))
|
||||
|
||||
def test_diffusers_model_is_compatible(self):
|
||||
filenames = [
|
||||
"unet/diffusion_pytorch_model.bin",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
self.assertTrue(is_safetensors_compatible(filenames))
|
||||
|
||||
def test_diffusers_model_is_not_compatible(self):
|
||||
filenames = [
|
||||
"safety_checker/pytorch_model.bin",
|
||||
"safety_checker/model.safetensors",
|
||||
"vae/diffusion_pytorch_model.bin",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
"text_encoder/pytorch_model.bin",
|
||||
"text_encoder/model.safetensors",
|
||||
"unet/diffusion_pytorch_model.bin",
|
||||
# Removed: 'unet/diffusion_pytorch_model.safetensors',
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames))
|
||||
|
||||
def test_transformer_model_is_compatible(self):
|
||||
filenames = [
|
||||
"text_encoder/pytorch_model.bin",
|
||||
"text_encoder/model.safetensors",
|
||||
]
|
||||
self.assertTrue(is_safetensors_compatible(filenames))
|
||||
|
||||
def test_transformer_model_is_not_compatible(self):
|
||||
filenames = [
|
||||
"safety_checker/pytorch_model.bin",
|
||||
"safety_checker/model.safetensors",
|
||||
"vae/diffusion_pytorch_model.bin",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
"text_encoder/pytorch_model.bin",
|
||||
# Removed: 'text_encoder/model.safetensors',
|
||||
"unet/diffusion_pytorch_model.bin",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames))
|
||||
|
||||
def test_all_is_compatible_variant(self):
|
||||
filenames = [
|
||||
"safety_checker/pytorch_model.fp16.bin",
|
||||
"safety_checker/model.fp16.safetensors",
|
||||
"vae/diffusion_pytorch_model.fp16.bin",
|
||||
"vae/diffusion_pytorch_model.fp16.safetensors",
|
||||
"text_encoder/pytorch_model.fp16.bin",
|
||||
"text_encoder/model.fp16.safetensors",
|
||||
"unet/diffusion_pytorch_model.fp16.bin",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
]
|
||||
variant = "fp16"
|
||||
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
|
||||
|
||||
def test_diffusers_model_is_compatible_variant(self):
|
||||
filenames = [
|
||||
"unet/diffusion_pytorch_model.fp16.bin",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
]
|
||||
variant = "fp16"
|
||||
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
|
||||
|
||||
def test_diffusers_model_is_compatible_variant_partial(self):
|
||||
# pass variant but use the non-variant filenames
|
||||
filenames = [
|
||||
"unet/diffusion_pytorch_model.bin",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
variant = "fp16"
|
||||
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
|
||||
|
||||
def test_diffusers_model_is_not_compatible_variant(self):
|
||||
filenames = [
|
||||
"safety_checker/pytorch_model.fp16.bin",
|
||||
"safety_checker/model.fp16.safetensors",
|
||||
"vae/diffusion_pytorch_model.fp16.bin",
|
||||
"vae/diffusion_pytorch_model.fp16.safetensors",
|
||||
"text_encoder/pytorch_model.fp16.bin",
|
||||
"text_encoder/model.fp16.safetensors",
|
||||
"unet/diffusion_pytorch_model.fp16.bin",
|
||||
# Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
|
||||
]
|
||||
variant = "fp16"
|
||||
self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
|
||||
|
||||
def test_transformer_model_is_compatible_variant(self):
|
||||
filenames = [
|
||||
"text_encoder/pytorch_model.fp16.bin",
|
||||
"text_encoder/model.fp16.safetensors",
|
||||
]
|
||||
variant = "fp16"
|
||||
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
|
||||
|
||||
def test_transformer_model_is_compatible_variant_partial(self):
|
||||
# pass variant but use the non-variant filenames
|
||||
filenames = [
|
||||
"text_encoder/pytorch_model.bin",
|
||||
"text_encoder/model.safetensors",
|
||||
]
|
||||
variant = "fp16"
|
||||
self.assertTrue(is_safetensors_compatible(filenames, variant=variant))
|
||||
|
||||
def test_transformer_model_is_not_compatible_variant(self):
|
||||
filenames = [
|
||||
"safety_checker/pytorch_model.fp16.bin",
|
||||
"safety_checker/model.fp16.safetensors",
|
||||
"vae/diffusion_pytorch_model.fp16.bin",
|
||||
"vae/diffusion_pytorch_model.fp16.safetensors",
|
||||
"text_encoder/pytorch_model.fp16.bin",
|
||||
# 'text_encoder/model.fp16.safetensors',
|
||||
"unet/diffusion_pytorch_model.fp16.bin",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
]
|
||||
variant = "fp16"
|
||||
self.assertFalse(is_safetensors_compatible(filenames, variant=variant))
|
Loading…
Reference in New Issue