fix issues with loading, add test for pipeline
This commit is contained in:
parent
fe99460b5f
commit
d8287fcd1d
|
@ -190,7 +190,6 @@ class ConfigMixin:
|
||||||
def extract_init_dict(cls, config_dict, **kwargs):
|
def extract_init_dict(cls, config_dict, **kwargs):
|
||||||
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
|
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
|
||||||
expected_keys.remove("self")
|
expected_keys.remove("self")
|
||||||
import ipdb; ipdb.set_trace()
|
|
||||||
init_dict = {}
|
init_dict = {}
|
||||||
for key in expected_keys:
|
for key in expected_keys:
|
||||||
if key in kwargs:
|
if key in kwargs:
|
||||||
|
|
|
@ -56,19 +56,23 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
class_name = module.__class__.__name__
|
class_name = module.__class__.__name__
|
||||||
|
|
||||||
register_dict = {name: (library, class_name)}
|
register_dict = {name: (library, class_name)}
|
||||||
register_dict["_module"] = self.__module__
|
|
||||||
|
|
||||||
# save model index config
|
# save model index config
|
||||||
self.register(**register_dict)
|
self.register(**register_dict)
|
||||||
|
|
||||||
# set models
|
# set models
|
||||||
setattr(self, name, module)
|
setattr(self, name, module)
|
||||||
|
|
||||||
|
register_dict = {"_module" : self.__module__.split(".")[-1] + ".py"}
|
||||||
|
self.register(**register_dict)
|
||||||
|
|
||||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||||
self.save_config(save_directory)
|
self.save_config(save_directory)
|
||||||
|
|
||||||
model_index_dict = self._dict_to_save
|
model_index_dict = self._dict_to_save
|
||||||
model_index_dict.pop("_class_name")
|
model_index_dict.pop("_class_name")
|
||||||
|
model_index_dict.pop("_module")
|
||||||
|
|
||||||
for name, (library_name, class_name) in self._dict_to_save.items():
|
for name, (library_name, class_name) in self._dict_to_save.items():
|
||||||
importable_classes = LOADABLE_CLASSES[library_name]
|
importable_classes = LOADABLE_CLASSES[library_name]
|
||||||
|
@ -98,12 +102,17 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
cached_folder = pretrained_model_name_or_path
|
cached_folder = pretrained_model_name_or_path
|
||||||
|
|
||||||
config_dict = cls.get_config_dict(cached_folder)
|
config_dict = cls.get_config_dict(cached_folder)
|
||||||
|
|
||||||
module = config_dict["_module"]
|
module = config_dict["_module"]
|
||||||
class_name_ = config_dict["_class_name"]
|
class_name_ = config_dict["_class_name"]
|
||||||
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
|
|
||||||
|
if class_name_ == cls.__name__:
|
||||||
|
pipeline_class = cls
|
||||||
|
else:
|
||||||
|
pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
|
||||||
|
|
||||||
|
|
||||||
init_dict, unused = class_obj.extract_init_dict(config_dict, **kwargs)
|
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||||
import ipdb; ipdb.set_trace()
|
|
||||||
|
|
||||||
init_kwargs = {}
|
init_kwargs = {}
|
||||||
|
|
||||||
|
@ -132,6 +141,5 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
|
|
||||||
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
||||||
|
|
||||||
|
model = pipeline_class(**init_kwargs)
|
||||||
model = class_obj(**init_kwargs)
|
|
||||||
return model
|
return model
|
||||||
|
|
|
@ -22,6 +22,8 @@ from distutils.util import strtobool
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from diffusers import GaussianDDPMScheduler, UNetModel
|
from diffusers import GaussianDDPMScheduler, UNetModel
|
||||||
|
from diffusers.pipeline_utils import DiffusionPipeline
|
||||||
|
from models.vision.ddpm.modeling_ddpm import DDPM
|
||||||
|
|
||||||
|
|
||||||
global_rng = random.Random()
|
global_rng = random.Random()
|
||||||
|
@ -199,3 +201,46 @@ class SamplerTesterMixin(unittest.TestCase):
|
||||||
assert image.shape == (1, 3, 256, 256)
|
assert image.shape == (1, 3, 256, 256)
|
||||||
image_slice = image[0, -1, -3:, -3:].cpu()
|
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||||
assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3
|
assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineTesterMixin(unittest.TestCase):
|
||||||
|
def test_from_pretrained_save_pretrained(self):
|
||||||
|
# 1. Load models
|
||||||
|
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
|
||||||
|
schedular = GaussianDDPMScheduler(timesteps=10)
|
||||||
|
|
||||||
|
ddpm = DDPM(model, schedular)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
ddpm.save_pretrained(tmpdirname)
|
||||||
|
new_ddpm = DDPM.from_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
generator = torch.Generator()
|
||||||
|
generator = generator.manual_seed(669472945848556)
|
||||||
|
|
||||||
|
image = ddpm(generator)
|
||||||
|
generator = generator.manual_seed(669472945848556)
|
||||||
|
new_image = new_ddpm(generator)
|
||||||
|
|
||||||
|
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
|
||||||
|
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_from_pretrained_hub(self):
|
||||||
|
model_path = "fusing/ddpm-cifar10"
|
||||||
|
|
||||||
|
ddpm = DDPM.from_pretrained(model_path)
|
||||||
|
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
|
||||||
|
|
||||||
|
ddpm.noise_scheduler.num_timesteps = 10
|
||||||
|
ddpm_from_hub.noise_scheduler.num_timesteps = 10
|
||||||
|
|
||||||
|
|
||||||
|
generator = torch.Generator(device=torch_device)
|
||||||
|
generator = generator.manual_seed(669472945848556)
|
||||||
|
|
||||||
|
image = ddpm(generator)
|
||||||
|
generator = generator.manual_seed(669472945848556)
|
||||||
|
new_image = ddpm_from_hub(generator)
|
||||||
|
|
||||||
|
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
|
||||||
|
|
Loading…
Reference in New Issue