fix issues with loading, add test for pipeline
This commit is contained in:
parent
fe99460b5f
commit
d8287fcd1d
|
@ -190,7 +190,6 @@ class ConfigMixin:
|
|||
def extract_init_dict(cls, config_dict, **kwargs):
|
||||
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
|
||||
expected_keys.remove("self")
|
||||
import ipdb; ipdb.set_trace()
|
||||
init_dict = {}
|
||||
for key in expected_keys:
|
||||
if key in kwargs:
|
||||
|
|
|
@ -56,19 +56,23 @@ class DiffusionPipeline(ConfigMixin):
|
|||
class_name = module.__class__.__name__
|
||||
|
||||
register_dict = {name: (library, class_name)}
|
||||
register_dict["_module"] = self.__module__
|
||||
|
||||
|
||||
# save model index config
|
||||
self.register(**register_dict)
|
||||
|
||||
# set models
|
||||
setattr(self, name, module)
|
||||
|
||||
register_dict = {"_module" : self.__module__.split(".")[-1] + ".py"}
|
||||
self.register(**register_dict)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||
self.save_config(save_directory)
|
||||
|
||||
model_index_dict = self._dict_to_save
|
||||
model_index_dict.pop("_class_name")
|
||||
model_index_dict.pop("_module")
|
||||
|
||||
for name, (library_name, class_name) in self._dict_to_save.items():
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
|
@ -98,12 +102,17 @@ class DiffusionPipeline(ConfigMixin):
|
|||
cached_folder = pretrained_model_name_or_path
|
||||
|
||||
config_dict = cls.get_config_dict(cached_folder)
|
||||
|
||||
module = config_dict["_module"]
|
||||
class_name_ = config_dict["_class_name"]
|
||||
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
|
||||
|
||||
if class_name_ == cls.__name__:
|
||||
pipeline_class = cls
|
||||
else:
|
||||
pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
|
||||
|
||||
|
||||
init_dict, unused = class_obj.extract_init_dict(config_dict, **kwargs)
|
||||
import ipdb; ipdb.set_trace()
|
||||
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
init_kwargs = {}
|
||||
|
||||
|
@ -132,6 +141,5 @@ class DiffusionPipeline(ConfigMixin):
|
|||
|
||||
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
||||
|
||||
|
||||
model = class_obj(**init_kwargs)
|
||||
model = pipeline_class(**init_kwargs)
|
||||
return model
|
||||
|
|
|
@ -22,6 +22,8 @@ from distutils.util import strtobool
|
|||
import torch
|
||||
|
||||
from diffusers import GaussianDDPMScheduler, UNetModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from models.vision.ddpm.modeling_ddpm import DDPM
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
@ -199,3 +201,46 @@ class SamplerTesterMixin(unittest.TestCase):
|
|||
assert image.shape == (1, 3, 256, 256)
|
||||
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||
assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3
|
||||
|
||||
|
||||
class PipelineTesterMixin(unittest.TestCase):
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
# 1. Load models
|
||||
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
|
||||
schedular = GaussianDDPMScheduler(timesteps=10)
|
||||
|
||||
ddpm = DDPM(model, schedular)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
ddpm.save_pretrained(tmpdirname)
|
||||
new_ddpm = DDPM.from_pretrained(tmpdirname)
|
||||
|
||||
generator = torch.Generator()
|
||||
generator = generator.manual_seed(669472945848556)
|
||||
|
||||
image = ddpm(generator)
|
||||
generator = generator.manual_seed(669472945848556)
|
||||
new_image = new_ddpm(generator)
|
||||
|
||||
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
|
||||
|
||||
|
||||
@slow
|
||||
def test_from_pretrained_hub(self):
|
||||
model_path = "fusing/ddpm-cifar10"
|
||||
|
||||
ddpm = DDPM.from_pretrained(model_path)
|
||||
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
|
||||
|
||||
ddpm.noise_scheduler.num_timesteps = 10
|
||||
ddpm_from_hub.noise_scheduler.num_timesteps = 10
|
||||
|
||||
|
||||
generator = torch.Generator(device=torch_device)
|
||||
generator = generator.manual_seed(669472945848556)
|
||||
|
||||
image = ddpm(generator)
|
||||
generator = generator.manual_seed(669472945848556)
|
||||
new_image = ddpm_from_hub(generator)
|
||||
|
||||
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
|
||||
|
|
Loading…
Reference in New Issue