Remove leftover ipdb breakpoints and let DiffusionPipeline.from_pretrained reuse the calling class.

Summary of the changes in this patch:
  * src/diffusers/configuration_utils.py: drop the `ipdb.set_trace()` call from
    ConfigMixin.extract_init_dict.
  * src/diffusers/pipeline_utils.py: stop storing `self.__module__` under `_module`
    next to each sub-model entry; instead register `_module` in its own register()
    call as the last dotted component of `self.__module__` plus ".py".
    save_pretrained now pops `_module` (in addition to `_class_name`) before it
    iterates over the saved entries. from_pretrained uses `cls` directly when the
    stored `_class_name` equals `cls.__name__` and only falls back to
    get_class_from_dynamic_module otherwise; the second ipdb breakpoint is removed.
  * tests/__init__.py: new empty file.
  * tests/test_modeling_utils.py: add PipelineTesterMixin with a save/load
    round-trip test and a @slow test that loads "fusing/ddpm-cifar10".

NOTE(review): this patch was recovered from a copy whose line breaks and indentation
had been collapsed. Record boundaries were rebuilt so that every hunk matches the
line counts in its @@ header; Python indentation and the exact position of blank
context lines are inferred -- confirm against the target tree before applying.

diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index ca61120f..721f13a2 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -190,7 +190,6 @@ class ConfigMixin:
     def extract_init_dict(cls, config_dict, **kwargs):
         expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
         expected_keys.remove("self")
-        import ipdb; ipdb.set_trace()
         init_dict = {}
         for key in expected_keys:
             if key in kwargs:
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index ba3a823e..1d9a2fd9 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -56,19 +56,23 @@ class DiffusionPipeline(ConfigMixin):
             class_name = module.__class__.__name__
 
             register_dict = {name: (library, class_name)}
-            register_dict["_module"] = self.__module__
+
 
             # save model index config
             self.register(**register_dict)
 
             # set models
             setattr(self, name, module)
+
+        register_dict = {"_module" : self.__module__.split(".")[-1] + ".py"}
+        self.register(**register_dict)
 
     def save_pretrained(self, save_directory: Union[str, os.PathLike]):
         self.save_config(save_directory)
 
         model_index_dict = self._dict_to_save
         model_index_dict.pop("_class_name")
+        model_index_dict.pop("_module")
 
         for name, (library_name, class_name) in self._dict_to_save.items():
             importable_classes = LOADABLE_CLASSES[library_name]
@@ -98,12 +102,17 @@ class DiffusionPipeline(ConfigMixin):
             cached_folder = pretrained_model_name_or_path
 
         config_dict = cls.get_config_dict(cached_folder)
+
         module = config_dict["_module"]
         class_name_ = config_dict["_class_name"]
-        class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
+
+        if class_name_ == cls.__name__:
+            pipeline_class = cls
+        else:
+            pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
+
 
-        init_dict, unused = class_obj.extract_init_dict(config_dict, **kwargs)
-        import ipdb; ipdb.set_trace()
+        init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
 
         init_kwargs = {}
 
@@ -132,6 +141,5 @@ class DiffusionPipeline(ConfigMixin):
 
             init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)
 
-
-        model = class_obj(**init_kwargs)
+        model = pipeline_class(**init_kwargs)
         return model
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 6dce91ae..04e7ddc5 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -22,6 +22,8 @@ from distutils.util import strtobool
 import torch
 
 from diffusers import GaussianDDPMScheduler, UNetModel
+from diffusers.pipeline_utils import DiffusionPipeline
+from models.vision.ddpm.modeling_ddpm import DDPM
 
 
 global_rng = random.Random()
@@ -199,3 +201,46 @@ class SamplerTesterMixin(unittest.TestCase):
         assert image.shape == (1, 3, 256, 256)
         image_slice = image[0, -1, -3:, -3:].cpu()
         assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3
+
+
+class PipelineTesterMixin(unittest.TestCase):
+    def test_from_pretrained_save_pretrained(self):
+        # 1. Load models
+        model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
+        schedular = GaussianDDPMScheduler(timesteps=10)
+
+        ddpm = DDPM(model, schedular)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ddpm.save_pretrained(tmpdirname)
+            new_ddpm = DDPM.from_pretrained(tmpdirname)
+
+        generator = torch.Generator()
+        generator = generator.manual_seed(669472945848556)
+
+        image = ddpm(generator)
+        generator = generator.manual_seed(669472945848556)
+        new_image = new_ddpm(generator)
+
+        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
+
+
+    @slow
+    def test_from_pretrained_hub(self):
+        model_path = "fusing/ddpm-cifar10"
+
+        ddpm = DDPM.from_pretrained(model_path)
+        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
+
+        ddpm.noise_scheduler.num_timesteps = 10
+        ddpm_from_hub.noise_scheduler.num_timesteps = 10
+
+
+        generator = torch.Generator(device=torch_device)
+        generator = generator.manual_seed(669472945848556)
+
+        image = ddpm(generator)
+        generator = generator.manual_seed(669472945848556)
+        new_image = ddpm_from_hub(generator)
+
+        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"