From 20ce68f945de7860f9854cd7ee680debf4a07fe5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 30 Nov 2022 11:31:50 +0100 Subject: [PATCH] Fix dtype model loading (#1449) * Add test * up * no bfloat16 for mps * fix * rename test --- src/diffusers/modeling_utils.py | 15 ++++++++++++++ tests/models/test_models_unet_1d.py | 8 ++++---- .../test_versatile_diffusion_mega.py | 2 +- tests/test_modeling_common.py | 20 ++++++++++++++++++- tests/test_pipelines.py | 2 +- tests/test_scheduler.py | 8 ++++---- tests/test_scheduler_flax.py | 6 +++--- 7 files changed, 47 insertions(+), 14 deletions(-) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 5f79e7fe..bfcba291 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -472,6 +472,21 @@ class ModelMixin(torch.nn.Module): model = cls.from_config(config, **unused_kwargs) state_dict = load_state_dict(model_file) + dtype = set(v.dtype for v in state_dict.values()) + + if len(dtype) > 1 and torch.float32 not in dtype: + raise ValueError( + f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please" + f" make sure that {model_file} weights have only one dtype." + ) + elif len(dtype) > 1 and torch.float32 in dtype: + dtype = torch.float32 + else: + dtype = dtype.pop() + + # move model to correct dtype + model = model.to(dtype) + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, state_dict, diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index 089d9356..b494c231 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -63,8 +63,8 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase): super().test_outputs_equivalence() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") - def test_from_pretrained_save_pretrained(self): - super().test_from_pretrained_save_pretrained() + def test_from_save_pretrained(self): + super().test_from_save_pretrained() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_model_from_pretrained(self): @@ -183,8 +183,8 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): super().test_outputs_equivalence() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") - def test_from_pretrained_save_pretrained(self): - super().test_from_pretrained_save_pretrained() + def test_from_save_pretrained(self): + super().test_from_save_pretrained() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_model_from_pretrained(self): diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py index 9387d141..31085aeb 100644 --- a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py @@ -42,7 +42,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase): gc.collect() torch.cuda.empty_cache() - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index cad1887f..68ab914b 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -27,7 +27,7 @@ from diffusers.utils import torch_device class ModelTesterMixin: - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) @@ -57,6 +57,24 @@ class ModelTesterMixin: max_diff = (image - new_image).abs().sum().item() self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") + def test_from_save_pretrained_dtype(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + if torch_device == "mps" and dtype == torch.bfloat16: + continue + with tempfile.TemporaryDirectory() as tmpdirname: + model.to(dtype) + model.save_pretrained(tmpdirname) + new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True) + assert new_model.dtype == dtype + new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False) + assert new_model.dtype == dtype + def test_determinism(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 6ae11e12..617bde43 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -659,7 +659,7 @@ class PipelineSlowTests(unittest.TestCase): == "Keyword arguments {'not_used': True} are not expected by DDPMPipeline and will be ignored.\n" ) - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): # 1. Load models model = UNet2DModel( block_out_channels=(32, 64), diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index b6008066..3962de9a 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -334,7 +334,7 @@ class SchedulerCommonTest(unittest.TestCase): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): kwargs = dict(self.forward_default_kwargs) num_inference_steps = kwargs.pop("num_inference_steps", None) @@ -875,7 +875,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): pass def check_over_forward(self, time_step=0, **forward_kwargs): @@ -1068,7 +1068,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): pass def check_over_forward(self, time_step=0, **forward_kwargs): @@ -1745,7 +1745,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): pass def check_over_forward(self, time_step=0, **forward_kwargs): diff --git a/tests/test_scheduler_flax.py b/tests/test_scheduler_flax.py index 5ada689b..da1042f3 100644 --- a/tests/test_scheduler_flax.py +++ b/tests/test_scheduler_flax.py @@ -126,7 +126,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase): assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): kwargs = dict(self.forward_default_kwargs) num_inference_steps = kwargs.pop("num_inference_steps", None) @@ -408,7 +408,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): kwargs = dict(self.forward_default_kwargs) num_inference_steps = kwargs.pop("num_inference_steps", None) @@ -690,7 +690,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): pass def test_scheduler_outputs_equivalence(self):