Fix dtype model loading (#1449)

* Add test

* up

* no bfloat16 for mps

* fix

* rename test
Patrick von Platen 2022-11-30 11:31:50 +01:00 committed by GitHub
parent 110ffe2589
commit 20ce68f945
7 changed files with 47 additions and 14 deletions
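
In user terms, the bug meant a checkpoint saved in half precision could come back from from_pretrained in a different dtype. Below is a minimal sketch of the round trip this commit guarantees, using UNet2DModel with its default config as an arbitrary example of a ModelMixin subclass:

    import tempfile

    import torch
    from diffusers import UNet2DModel

    model = UNet2DModel().to(torch.float16)

    with tempfile.TemporaryDirectory() as tmpdir:
        model.save_pretrained(tmpdir)  # weights are stored in float16
        reloaded = UNet2DModel.from_pretrained(tmpdir)
        # with this fix, the checkpoint's dtype is detected and applied on load
        assert reloaded.dtype == torch.float16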


@@ -472,6 +472,21 @@ class ModelMixin(torch.nn.Module):
         model = cls.from_config(config, **unused_kwargs)
 
         state_dict = load_state_dict(model_file)
+        dtype = set(v.dtype for v in state_dict.values())
+
+        if len(dtype) > 1 and torch.float32 not in dtype:
+            raise ValueError(
+                f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
+                f" make sure that {model_file} weights have only one dtype."
+            )
+        elif len(dtype) > 1 and torch.float32 in dtype:
+            dtype = torch.float32
+        else:
+            dtype = dtype.pop()
+
+        # move model to correct dtype
+        model = model.to(dtype)
+
         model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
             model,
             state_dict,
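
The rule added above: a state dict that mixes dtypes without float32 is rejected, a mix that includes float32 falls back to float32, and a uniform state dict yields its single dtype. A standalone sketch of that rule for illustration (resolve_state_dict_dtype is a hypothetical helper, not part of the diffusers API):

    import torch

    def resolve_state_dict_dtype(state_dict, model_file="checkpoint"):
        dtypes = set(v.dtype for v in state_dict.values())
        if len(dtypes) > 1 and torch.float32 not in dtypes:
            # e.g. float16 mixed with bfloat16 cannot be reconciled safely
            raise ValueError(f"{model_file} has a mixture of incompatible dtypes {dtypes}.")
        elif len(dtypes) > 1:
            # mixed dtypes that include float32: fall back to full precision
            return torch.float32
        return dtypes.pop()

    # a checkpoint saved entirely in float16 resolves to float16
    sd = {"weight": torch.ones(2, dtype=torch.float16), "bias": torch.zeros(2, dtype=torch.float16)}
    assert resolve_state_dict_dtype(sd) == torch.float16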


@@ -63,8 +63,8 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
         super().test_outputs_equivalence()
 
     @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
-    def test_from_pretrained_save_pretrained(self):
-        super().test_from_pretrained_save_pretrained()
+    def test_from_save_pretrained(self):
+        super().test_from_save_pretrained()
 
     @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_model_from_pretrained(self):
@@ -183,8 +183,8 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
         super().test_outputs_equivalence()
 
     @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
-    def test_from_pretrained_save_pretrained(self):
-        super().test_from_pretrained_save_pretrained()
+    def test_from_save_pretrained(self):
+        super().test_from_save_pretrained()
 
     @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
     def test_model_from_pretrained(self):


@@ -42,7 +42,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16)
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)


@@ -27,7 +27,7 @@ from diffusers.utils import torch_device
 
 
 class ModelTesterMixin:
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
         model = self.model_class(**init_dict)
@@ -57,6 +57,24 @@ class ModelTesterMixin:
         max_diff = (image - new_image).abs().sum().item()
         self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
 
+    def test_from_save_pretrained_dtype(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
+            if torch_device == "mps" and dtype == torch.bfloat16:
+                continue
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.to(dtype)
+                model.save_pretrained(tmpdirname)
+                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True)
+                assert new_model.dtype == dtype
+                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False)
+                assert new_model.dtype == dtype
+
     def test_determinism(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
         model = self.model_class(**init_dict)
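
Note that the bfloat16 case is skipped on mps, matching the "no bfloat16 for mps" commit message: at the time of this commit, PyTorch's MPS backend did not support bfloat16 tensors. The test also reloads through both low_cpu_mem_usage code paths, so the dtype inference is exercised in both loading implementations.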


@@ -659,7 +659,7 @@ class PipelineSlowTests(unittest.TestCase):
             == "Keyword arguments {'not_used': True} are not expected by DDPMPipeline and will be ignored.\n"
         )
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         # 1. Load models
         model = UNet2DModel(
             block_out_channels=(32, 64),


@@ -334,7 +334,7 @@ class SchedulerCommonTest(unittest.TestCase):
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         kwargs = dict(self.forward_default_kwargs)
 
         num_inference_steps = kwargs.pop("num_inference_steps", None)
@@ -875,7 +875,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         pass
 
     def check_over_forward(self, time_step=0, **forward_kwargs):
@@ -1068,7 +1068,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         pass
 
     def check_over_forward(self, time_step=0, **forward_kwargs):
@@ -1745,7 +1745,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         pass
 
     def check_over_forward(self, time_step=0, **forward_kwargs):


@@ -126,7 +126,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
 
         assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         kwargs = dict(self.forward_default_kwargs)
 
         num_inference_steps = kwargs.pop("num_inference_steps", None)
@@ -408,7 +408,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
 
         assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         kwargs = dict(self.forward_default_kwargs)
 
         num_inference_steps = kwargs.pop("num_inference_steps", None)
@@ -690,7 +690,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
 
         assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def test_from_pretrained_save_pretrained(self):
+    def test_from_save_pretrained(self):
         pass
 
     def test_scheduler_outputs_equivalence(self):