diffusers/tests/test_modeling_common.py

379 lines
15 KiB
Python

# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import tempfile
import unittest
import unittest.mock as mock
from typing import Dict, List, Tuple
import numpy as np
import torch
from requests.exceptions import HTTPError
from diffusers.models import ModelMixin, UNet2DConditionModel
from diffusers.training_utils import EMAModel
from diffusers.utils import torch_device
class ModelUtilsTest(unittest.TestCase):
def test_accelerate_loading_error_message(self):
with self.assertRaises(ValueError) as error_context:
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
# make sure that error message states what keys are missing
assert "conv_out.bias" in str(error_context.exception)
def test_cached_files_are_used_when_no_internet(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
orig_model = UNet2DConditionModel.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet"
)
# Under the mock environment we get a 500 error when trying to reach the model.
with mock.patch("requests.request", return_value=response_mock):
# Download this model to make sure it's in the cache.
model = UNet2DConditionModel.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", local_files_only=True
)
for p1, p2 in zip(orig_model.parameters(), model.parameters()):
if p1.data.ne(p2.data).sum() > 0:
assert False, "Parameters not the same!"
class ModelTesterMixin:
def test_from_save_pretrained(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device)
with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
_ = model(**self.dummy_input)
_ = new_model(**self.dummy_input)
image = model(**inputs_dict)
if isinstance(image, dict):
image = image.sample
new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image.sample
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
def test_from_save_pretrained_variant(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, variant="fp16")
new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
# non-variant cannot be loaded
with self.assertRaises(OSError) as error_context:
self.model_class.from_pretrained(tmpdirname)
# make sure that error message states what keys are missing
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(error_context.exception)
new_model.to(torch_device)
with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
_ = model(**self.dummy_input)
_ = new_model(**self.dummy_input)
image = model(**inputs_dict)
if isinstance(image, dict):
image = image.sample
new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image.sample
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
def test_from_save_pretrained_dtype(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
if torch_device == "mps" and dtype == torch.bfloat16:
continue
with tempfile.TemporaryDirectory() as tmpdirname:
model.to(dtype)
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
assert new_model.dtype == dtype
def test_determinism(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
model(**self.dummy_input)
first = model(**inputs_dict)
if isinstance(first, dict):
first = first.sample
second = model(**inputs_dict)
if isinstance(second, dict):
second = second.sample
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def test_output(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_forward_signature(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["sample", "timestep"]
self.assertListEqual(arg_names[:2], expected_arg_names)
def test_model_from_pretrained(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
# test if the model can be loaded from the config
# and has all the expected shape
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device)
new_model.eval()
# check if all parameters shape are the same
for param_name in model.state_dict().keys():
param_1 = model.state_dict()[param_name]
param_2 = new_model.state_dict()[param_name]
self.assertEqual(param_1.shape, param_2.shape)
with torch.no_grad():
output_1 = model(**inputs_dict)
if isinstance(output_1, dict):
output_1 = output_1.sample
output_2 = new_model(**inputs_dict)
if isinstance(output_2, dict):
output_2 = output_2.sample
self.assertEqual(output_1.shape, output_2.shape)
@unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
def test_training(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
@unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
def test_ema_training(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
ema_model = EMAModel(model.parameters())
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
ema_model.step(model.parameters())
def test_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
# Temporary fallback until `aten::_index_put_impl_` is implemented in mps
# Track progress in https://github.com/pytorch/pytorch/issues/77764
device = t.device
if device.type == "mps":
t = t.to("cpu")
t[t != t] = 0
return t.to(device)
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
model(**self.dummy_input)
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
def test_enable_disable_gradient_checkpointing(self):
if not self.model_class._supports_gradient_checkpointing:
return # Skip test if model does not support gradient checkpointing
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
# at init model should have gradient checkpointing disabled
model = self.model_class(**init_dict)
self.assertFalse(model.is_gradient_checkpointing)
# check enable works
model.enable_gradient_checkpointing()
self.assertTrue(model.is_gradient_checkpointing)
# check disable works
model.disable_gradient_checkpointing()
self.assertFalse(model.is_gradient_checkpointing)
def test_deprecated_kwargs(self):
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
if has_kwarg_in_model_class and not has_deprecated_kwarg:
raise ValueError(
f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
" no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)
if not has_kwarg_in_model_class and has_deprecated_kwarg:
raise ValueError(
f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)