fix issues with loading, add test for pipeline

This commit is contained in:
patil-suraj 2022-06-07 15:39:47 +02:00
parent fe99460b5f
commit d8287fcd1d
4 changed files with 59 additions and 7 deletions

View File

@ -190,7 +190,6 @@ class ConfigMixin:
def extract_init_dict(cls, config_dict, **kwargs):
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
expected_keys.remove("self")
import ipdb; ipdb.set_trace()
init_dict = {}
for key in expected_keys:
if key in kwargs:

View File

@ -56,7 +56,7 @@ class DiffusionPipeline(ConfigMixin):
class_name = module.__class__.__name__
register_dict = {name: (library, class_name)}
register_dict["_module"] = self.__module__
# save model index config
self.register(**register_dict)
@ -64,11 +64,15 @@ class DiffusionPipeline(ConfigMixin):
# set models
setattr(self, name, module)
register_dict = {"_module" : self.__module__.split(".")[-1] + ".py"}
self.register(**register_dict)
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
self.save_config(save_directory)
model_index_dict = self._dict_to_save
model_index_dict.pop("_class_name")
model_index_dict.pop("_module")
for name, (library_name, class_name) in self._dict_to_save.items():
importable_classes = LOADABLE_CLASSES[library_name]
@ -98,12 +102,17 @@ class DiffusionPipeline(ConfigMixin):
cached_folder = pretrained_model_name_or_path
config_dict = cls.get_config_dict(cached_folder)
module = config_dict["_module"]
class_name_ = config_dict["_class_name"]
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
init_dict, unused = class_obj.extract_init_dict(config_dict, **kwargs)
import ipdb; ipdb.set_trace()
if class_name_ == cls.__name__:
pipeline_class = cls
else:
pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {}
@ -132,6 +141,5 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
model = class_obj(**init_kwargs)
model = pipeline_class(**init_kwargs)
return model

0
tests/__init__.py Normal file
View File

View File

@ -22,6 +22,8 @@ from distutils.util import strtobool
import torch
from diffusers import GaussianDDPMScheduler, UNetModel
from diffusers.pipeline_utils import DiffusionPipeline
from models.vision.ddpm.modeling_ddpm import DDPM
global_rng = random.Random()
@ -199,3 +201,46 @@ class SamplerTesterMixin(unittest.TestCase):
assert image.shape == (1, 3, 256, 256)
image_slice = image[0, -1, -3:, -3:].cpu()
assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3
class PipelineTesterMixin(unittest.TestCase):
def test_from_pretrained_save_pretrained(self):
# 1. Load models
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
schedular = GaussianDDPMScheduler(timesteps=10)
ddpm = DDPM(model, schedular)
with tempfile.TemporaryDirectory() as tmpdirname:
ddpm.save_pretrained(tmpdirname)
new_ddpm = DDPM.from_pretrained(tmpdirname)
generator = torch.Generator()
generator = generator.manual_seed(669472945848556)
image = ddpm(generator)
generator = generator.manual_seed(669472945848556)
new_image = new_ddpm(generator)
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
@slow
def test_from_pretrained_hub(self):
model_path = "fusing/ddpm-cifar10"
ddpm = DDPM.from_pretrained(model_path)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
ddpm.noise_scheduler.num_timesteps = 10
ddpm_from_hub.noise_scheduler.num_timesteps = 10
generator = torch.Generator(device=torch_device)
generator = generator.manual_seed(669472945848556)
image = ddpm(generator)
generator = generator.manual_seed(669472945848556)
new_image = ddpm_from_hub(generator)
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"