# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import tempfile
import unittest
import unittest.mock as mock
from typing import Dict, List, Tuple

import numpy as np
import torch
from requests.exceptions import HTTPError

from diffusers.models import ModelMixin, UNet2DConditionModel
from diffusers.training_utils import EMAModel
from diffusers.utils import torch_device


class ModelUtilsTest(unittest.TestCase):
    """Tests for model loading utilities, exercised through `UNet2DConditionModel.from_pretrained`."""

    def test_accelerate_loading_error_message(self):
        """Loading a checkpoint with missing weights raises a `ValueError` naming a missing key."""
        with self.assertRaises(ValueError) as error_context:
            UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")

        # make sure that error message states what keys are missing
        assert "conv_out.bias" in str(error_context.exception)

    def test_cached_files_are_used_when_no_internet(self):
        """A previously loaded model can be loaded again while `requests.request` returns a 500 response."""
        # A mock response for an HTTP head request to emulate server down
        response_mock = mock.Mock()
        response_mock.status_code = 500
        response_mock.headers = {}
        response_mock.raise_for_status.side_effect = HTTPError
        response_mock.json.return_value = {}

        # Download this model to make sure it's in the cache.
        orig_model = UNet2DConditionModel.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet"
        )

        # Under the mock environment we get a 500 error when trying to reach the model.
        with mock.patch("requests.request", return_value=response_mock):
            # Load the same model again, this time restricted to local files only.
            model = UNet2DConditionModel.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", local_files_only=True
            )

        for p1, p2 in zip(orig_model.parameters(), model.parameters()):
            if p1.data.ne(p2.data).sum() > 0:
                assert False, "Parameters not the same!"


class ModelTesterMixin:
    """Common model tests shared by concrete model test cases.

    The methods read the following from `self`, so the class this mixin is combined with must provide them:

    - `model_class`: the model class under test.
    - `prepare_init_args_and_inputs_for_common()`: returns `(init_dict, inputs_dict)`; the model is built with
      `model_class(**init_dict)` and called with `model(**inputs_dict)`.
    - `dummy_input`: keyword inputs used for the warmup forward pass on mps.
    - `output_shape`: per-sample output shape (a tuple), used to build the training target.

    The `assert*` helpers of `unittest.TestCase` are also called on `self`.
    """

    def test_from_save_pretrained(self):
        """A model reloaded via `save_pretrained`/`from_pretrained` produces the same forward output."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)

        with torch.no_grad():
            # Warmup pass when using mps (see #372)
            if torch_device == "mps" and isinstance(model, ModelMixin):
                _ = model(**self.dummy_input)
                _ = new_model(**self.dummy_input)

            image = model(**inputs_dict)
            if isinstance(image, dict):
                image = image.sample

            new_image = new_model(**inputs_dict)

            if isinstance(new_image, dict):
                new_image = new_image.sample

        # NOTE: despite the name, this is the *summed* absolute difference over all elements.
        max_diff = (image - new_image).abs().sum().item()
        self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

    def test_from_save_pretrained_variant(self):
        """A model saved with `variant="fp16"` only loads with that variant and round-trips its outputs."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, variant="fp16")
            new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")

            # non-variant cannot be loaded
            with self.assertRaises(OSError) as error_context:
                self.model_class.from_pretrained(tmpdirname)

            # make sure that error message states which weights file could not be found
            assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(error_context.exception)

            new_model.to(torch_device)

        with torch.no_grad():
            # Warmup pass when using mps (see #372)
            if torch_device == "mps" and isinstance(model, ModelMixin):
                _ = model(**self.dummy_input)
                _ = new_model(**self.dummy_input)

            image = model(**inputs_dict)
            if isinstance(image, dict):
                image = image.sample

            new_image = new_model(**inputs_dict)

            if isinstance(new_image, dict):
                new_image = new_image.sample

        # NOTE: despite the name, this is the *summed* absolute difference over all elements.
        max_diff = (image - new_image).abs().sum().item()
        self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

    def test_from_save_pretrained_dtype(self):
        """`from_pretrained(..., torch_dtype=dtype)` yields that dtype with and without `low_cpu_mem_usage`."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
            # bfloat16 is skipped when running on mps
            if torch_device == "mps" and dtype == torch.bfloat16:
                continue
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.to(dtype)
                model.save_pretrained(tmpdirname)
                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
                assert new_model.dtype == dtype
                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
                assert new_model.dtype == dtype

    def test_determinism(self):
        """Two forward passes with the same inputs differ by at most 1e-5 (NaN entries are dropped)."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            # Warmup pass when using mps (see #372)
            if torch_device == "mps" and isinstance(model, ModelMixin):
                model(**self.dummy_input)

            first = model(**inputs_dict)
            if isinstance(first, dict):
                first = first.sample

            second = model(**inputs_dict)
            if isinstance(second, dict):
                second = second.sample

        out_1 = first.cpu().numpy()
        out_2 = second.cpu().numpy()
        # Each output is filtered with its own NaN mask before the comparison.
        out_1 = out_1[~np.isnan(out_1)]
        out_2 = out_2[~np.isnan(out_2)]
        max_diff = np.amax(np.abs(out_1 - out_2))
        self.assertLessEqual(max_diff, 1e-5)

    def test_output(self):
        """The forward output is not None and has the same shape as `inputs_dict["sample"]`."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_forward_with_norm_groups(self):
        """Same check as `test_output`, with `norm_num_groups=16` and `block_out_channels=(16, 32)`."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["norm_num_groups"] = 16
        init_dict["block_out_channels"] = (16, 32)

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_forward_signature(self):
        """The first two parameters of `forward` are named `sample` and `timestep`."""
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        signature = inspect.signature(model.forward)
        # signature.parameters is an OrderedDict => so arg_names order is deterministic
        arg_names = [*signature.parameters.keys()]

        expected_arg_names = ["sample", "timestep"]
        self.assertListEqual(arg_names[:2], expected_arg_names)

    def test_model_from_pretrained(self):
        """A reloaded model has the same parameter shapes and the same output shape as the original."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # test if the model can be loaded from the config
        # and has all the expected shape
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)
            new_model.eval()

        # check if all parameters shape are the same
        for param_name in model.state_dict().keys():
            param_1 = model.state_dict()[param_name]
            param_2 = new_model.state_dict()[param_name]
            self.assertEqual(param_1.shape, param_2.shape)

        with torch.no_grad():
            output_1 = model(**inputs_dict)

            if isinstance(output_1, dict):
                output_1 = output_1.sample

            output_2 = new_model(**inputs_dict)

            if isinstance(output_2, dict):
                output_2 = output_2.sample

        self.assertEqual(output_1.shape, output_2.shape)

    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
    def test_training(self):
        """A forward pass in train mode followed by an MSE loss backward pass runs without error."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.sample

        # Random target: batch size taken from the input sample, remaining dims from `self.output_shape`.
        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
    def test_ema_training(self):
        """Same as `test_training`, followed by an `EMAModel.step` over the model parameters."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        ema_model = EMAModel(model.parameters())

        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.sample

        # Random target: batch size taken from the input sample, remaining dims from `self.output_shape`.
        noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()
        ema_model.step(model.parameters())

    def test_outputs_equivalence(self):
        """The outputs returned with `return_dict=False` match the default (dict-style) outputs."""

        def set_nan_tensor_to_zero(t):
            """Return a copy of `t` with NaN entries replaced by zero, on `t`'s original device."""
            # Temporary fallback until `aten::_index_put_impl_` is implemented in mps
            # Track progress in https://github.com/pytorch/pytorch/issues/77764
            device = t.device
            if device.type == "mps":
                t = t.to("cpu")
            else:
                # Work on a copy so the caller's tensor keeps its NaNs; the failure message below
                # inspects the raw outputs and would otherwise never report a NaN.
                t = t.clone()
            t[t != t] = 0
            return t.to(device)

        def recursive_check(tuple_object, dict_object):
            """Walk both outputs in parallel and compare the leaf tensors with `torch.allclose`."""
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    torch.allclose(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                    ),
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}. Dict has"
                        f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object).any()}."
                    ),
                )

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            # Warmup pass when using mps (see #372)
            if torch_device == "mps" and isinstance(model, ModelMixin):
                model(**self.dummy_input)

            outputs_dict = model(**inputs_dict)
            outputs_tuple = model(**inputs_dict, return_dict=False)

        recursive_check(outputs_tuple, outputs_dict)

    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
    def test_enable_disable_gradient_checkpointing(self):
        """Gradient checkpointing is off at init and follows the enable/disable calls."""
        if not self.model_class._supports_gradient_checkpointing:
            return  # Skip test if model does not support gradient checkpointing

        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        # at init model should have gradient checkpointing disabled
        model = self.model_class(**init_dict)
        self.assertFalse(model.is_gradient_checkpointing)

        # check enable works
        model.enable_gradient_checkpointing()
        self.assertTrue(model.is_gradient_checkpointing)

        # check disable works
        model.disable_gradient_checkpointing()
        self.assertFalse(model.is_gradient_checkpointing)

    def test_deprecated_kwargs(self):
        """`**kwargs` in `model_class.__init__` and a non-empty `_deprecated_kwargs` must go together.

        Raises:
            ValueError: if only one of the two is present on the model class.
        """
        has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
        has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0

        if has_kwarg_in_model_class and not has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
                " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
                " []`"
            )

        if not has_kwarg_in_model_class and has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
                f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
                " from `_deprecated_kwargs = []`"
            )