Fix dtype model loading (#1449)
* Add test
* up
* no bfloat16 for mps
* fix
* rename test
This commit is contained in:
parent
110ffe2589
commit
20ce68f945
|
@ -472,6 +472,21 @@ class ModelMixin(torch.nn.Module):
|
|||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
state_dict = load_state_dict(model_file)
|
||||
dtype = set(v.dtype for v in state_dict.values())
|
||||
|
||||
if len(dtype) > 1 and torch.float32 not in dtype:
|
||||
raise ValueError(
|
||||
f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
|
||||
f" make sure that {model_file} weights have only one dtype."
|
||||
)
|
||||
elif len(dtype) > 1 and torch.float32 in dtype:
|
||||
dtype = torch.float32
|
||||
else:
|
||||
dtype = dtype.pop()
|
||||
|
||||
# move model to correct dtype
|
||||
model = model.to(dtype)
|
||||
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
model,
|
||||
state_dict,
|
||||
|
|
|
@ -63,8 +63,8 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
super().test_outputs_equivalence()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
super().test_from_pretrained_save_pretrained()
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_model_from_pretrained(self):
|
||||
|
@ -183,8 +183,8 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
|
|||
super().test_outputs_equivalence()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
super().test_from_pretrained_save_pretrained()
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_model_from_pretrained(self):
|
||||
|
|
|
@ -42,7 +42,7 @@ class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
|
|||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
|
|
@ -27,7 +27,7 @@ from diffusers.utils import torch_device
|
|||
|
||||
|
||||
class ModelTesterMixin:
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
|
@ -57,6 +57,24 @@ class ModelTesterMixin:
|
|||
max_diff = (image - new_image).abs().sum().item()
|
||||
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
|
||||
|
||||
def test_from_save_pretrained_dtype(self):
    """Check that a saved model is reloaded in the dtype it was saved with.

    For each of float32, float16 and bfloat16 the model is cast to that
    dtype, written with ``save_pretrained`` and read back with
    ``from_pretrained`` -- once with ``low_cpu_mem_usage=True`` and once with
    ``low_cpu_mem_usage=False``. The reloaded model's ``dtype`` must equal
    the dtype the model had when it was saved.
    """
    init_dict, _ = self.prepare_init_args_and_inputs_for_common()

    model = self.model_class(**init_dict)
    model.to(torch_device)
    model.eval()

    for dtype in [torch.float32, torch.float16, torch.bfloat16]:
        # bfloat16 is not exercised when the test device is "mps".
        if torch_device == "mps" and dtype == torch.bfloat16:
            continue

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.to(dtype)
            model.save_pretrained(tmpdirname)

            # Both loading paths must preserve the dtype found in the checkpoint.
            for low_cpu_mem_usage in (True, False):
                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=low_cpu_mem_usage)
                # unittest assertion (as used elsewhere in this mixin) instead of a
                # bare `assert`: it is not stripped under `-O` and the message names
                # the failing dtype / loading path.
                self.assertEqual(
                    new_model.dtype,
                    dtype,
                    f"Model saved as {dtype} was reloaded as {new_model.dtype}"
                    f" (low_cpu_mem_usage={low_cpu_mem_usage})",
                )
|
||||
|
||||
def test_determinism(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
|
|
@ -659,7 +659,7 @@ class PipelineSlowTests(unittest.TestCase):
|
|||
== "Keyword arguments {'not_used': True} are not expected by DDPMPipeline and will be ignored.\n"
|
||||
)
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
# 1. Load models
|
||||
model = UNet2DModel(
|
||||
block_out_channels=(32, 64),
|
||||
|
|
|
@ -334,7 +334,7 @@ class SchedulerCommonTest(unittest.TestCase):
|
|||
|
||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
@ -875,7 +875,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
|||
|
||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
pass
|
||||
|
||||
def check_over_forward(self, time_step=0, **forward_kwargs):
|
||||
|
@ -1068,7 +1068,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
|||
|
||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
pass
|
||||
|
||||
def check_over_forward(self, time_step=0, **forward_kwargs):
|
||||
|
@ -1745,7 +1745,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
|
|||
|
||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
pass
|
||||
|
||||
def check_over_forward(self, time_step=0, **forward_kwargs):
|
||||
|
|
|
@ -126,7 +126,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
|
|||
|
||||
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
@ -408,7 +408,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
|
|||
|
||||
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
@ -690,7 +690,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
|
|||
|
||||
assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
def test_from_save_pretrained(self):
|
||||
pass
|
||||
|
||||
def test_scheduler_outputs_equivalence(self):
|
||||
|
|
Loading…
Reference in New Issue