diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index a74b5ea0..1b7b2356 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -50,7 +50,7 @@ class ConfigMixin: """ config_name = None - def register(self, **kwargs): + def register_to_config(self, **kwargs): if self.config_name is None: raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") kwargs["_class_name"] = self.__class__.__name__ diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index 0f7559ec..16fb509c 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -188,7 +188,7 @@ class UNetModel(ModelMixin, ConfigMixin): resolution=256, ): super().__init__() - self.register( + self.register_to_config( ch=ch, out_ch=out_ch, ch_mult=ch_mult, diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 1c5cf800..abbd7dae 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -689,7 +689,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel): resblock_updown=resblock_updown, transformer_dim=transformer_dim, ) - self.register( + self.register_to_config( in_channels=in_channels, resolution=resolution, model_channels=model_channels, @@ -780,7 +780,7 @@ class GLIDESuperResUNetModel(GLIDEUNetModel): use_scale_shift_norm=use_scale_shift_norm, resblock_updown=resblock_updown, ) - self.register( + self.register_to_config( in_channels=in_channels, resolution=resolution, model_channels=model_channels, diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 67921771..84bd64d9 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -126,7 +126,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000): super(UNetGradTTSModel, self).__init__() - self.register( + self.register_to_config( dim=dim, dim_mults=dim_mults, groups=groups, diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index d6e5a5c0..cca32313 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -746,7 +746,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): super().__init__() # register all __init__ params with self.register - self.register( + self.register_to_config( image_size=image_size, in_channels=in_channels, model_channels=model_channels, diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index ceae102d..6b49f22c 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -77,13 +77,13 @@ class DiffusionPipeline(ConfigMixin): register_dict = {name: (library, class_name)} # save model index config - self.register(**register_dict) + self.register_to_config(**register_dict) # set models setattr(self, name, module) register_dict = {"_module": self.__module__.split(".")[-1]} - self.register(**register_dict) + self.register_to_config(**register_dict) def save_pretrained(self, save_directory: Union[str, os.PathLike]): self.save_config(save_directory) diff --git a/src/diffusers/pipelines/old/latent_diffusion/modeling_vae.py b/src/diffusers/pipelines/old/latent_diffusion/modeling_vae.py index 7b299eee..70247897 100644 --- a/src/diffusers/pipelines/old/latent_diffusion/modeling_vae.py +++ b/src/diffusers/pipelines/old/latent_diffusion/modeling_vae.py @@ -655,7 +655,7 @@ class VQModel(ModelMixin, ConfigMixin): super().__init__() # register all __init__ params with self.register - self.register( + self.register_to_config( ch=ch, out_ch=out_ch, num_res_blocks=num_res_blocks, @@ -786,7 +786,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): super().__init__() # register all __init__ params with self.register - self.register( + self.register_to_config( ch=ch, out_ch=out_ch, num_res_blocks=num_res_blocks, diff --git a/src/diffusers/pipelines/pipeline_bddm.py b/src/diffusers/pipelines/pipeline_bddm.py index de0689ce..82e400ce 100644 --- a/src/diffusers/pipelines/pipeline_bddm.py +++ b/src/diffusers/pipelines/pipeline_bddm.py @@ -232,7 +232,7 @@ class DiffWave(ModelMixin, ConfigMixin): super().__init__() # register all init arguments with self.register - self.register( + self.register_to_config( in_channels=in_channels, res_channels=res_channels, skip_channels=skip_channels, diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index 246661be..c9fad219 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -355,7 +355,7 @@ class TextEncoder(ModelMixin, ConfigMixin): ): super(TextEncoder, self).__init__() - self.register( + self.register_to_config( n_vocab=n_vocab, n_feats=n_feats, n_channels=n_channels, diff --git a/src/diffusers/pipelines/pipeline_latent_diffusion.py b/src/diffusers/pipelines/pipeline_latent_diffusion.py index 10ee253f..9b29faa9 100644 --- a/src/diffusers/pipelines/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/pipeline_latent_diffusion.py @@ -656,7 +656,7 @@ class VQModel(ModelMixin, ConfigMixin): super().__init__() # register all __init__ params with self.register - self.register( + self.register_to_config( ch=ch, out_ch=out_ch, num_res_blocks=num_res_blocks, @@ -787,7 +787,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): super().__init__() # register all __init__ params with self.register - self.register( + self.register_to_config( ch=ch, out_ch=out_ch, num_res_blocks=num_res_blocks, diff --git a/src/diffusers/schedulers/classifier_free_guidance.py b/src/diffusers/schedulers/classifier_free_guidance.py index ec435655..bd7fbe7d 100644 --- a/src/diffusers/schedulers/classifier_free_guidance.py +++ b/src/diffusers/schedulers/classifier_free_guidance.py @@ -57,7 +57,7 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin): beta_schedule="squaredcos_cap_v2", ): super().__init__() - self.register( + self.register_to_config( timesteps=timesteps, beta_schedule=beta_schedule, ) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 88e4725e..8b318089 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -32,7 +32,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): tensor_format="np", ): super().__init__() - self.register( + self.register_to_config( timesteps=timesteps, beta_start=beta_start, beta_end=beta_end, diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index e2fc2890..356eedd6 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -33,7 +33,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): tensor_format="np", ): super().__init__() - self.register( + self.register_to_config( timesteps=timesteps, beta_start=beta_start, beta_end=beta_end, diff --git a/src/diffusers/schedulers/scheduling_grad_tts.py b/src/diffusers/schedulers/scheduling_grad_tts.py index ca42921d..a97ab67b 100644 --- a/src/diffusers/schedulers/scheduling_grad_tts.py +++ b/src/diffusers/schedulers/scheduling_grad_tts.py @@ -25,7 +25,7 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin): tensor_format="np", ): super().__init__() - self.register( + self.register_to_config( timesteps=timesteps, beta_start=beta_start, beta_end=beta_end, diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index ee6f3bcf..9ef40a90 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -29,7 +29,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): tensor_format="np", ): super().__init__() - self.register( + self.register_to_config( timesteps=timesteps, beta_start=beta_start, beta_end=beta_end, diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index e78f01a4..1fa40cbd 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -57,7 +57,7 @@ class ConfigTester(unittest.TestCase): d="for diffusion", e=[1, 3], ): - self.register(a=a, b=b, c=c, d=d, e=e) + self.register_to_config(a=a, b=b, c=c, d=d, e=e) obj = SampleObject() config = obj.config