rename register to register_to_config

This commit is contained in:
Patrick von Platen 2022-06-17 10:58:43 +02:00
parent 0ffda1dfcc
commit 5e6f500038
16 changed files with 20 additions and 20 deletions

View File

@ -50,7 +50,7 @@ class ConfigMixin:
""" """
config_name = None config_name = None
def register(self, **kwargs): def register_to_config(self, **kwargs):
if self.config_name is None: if self.config_name is None:
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
kwargs["_class_name"] = self.__class__.__name__ kwargs["_class_name"] = self.__class__.__name__

View File

@ -188,7 +188,7 @@ class UNetModel(ModelMixin, ConfigMixin):
resolution=256, resolution=256,
): ):
super().__init__() super().__init__()
self.register( self.register_to_config(
ch=ch, ch=ch,
out_ch=out_ch, out_ch=out_ch,
ch_mult=ch_mult, ch_mult=ch_mult,

View File

@ -689,7 +689,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
resblock_updown=resblock_updown, resblock_updown=resblock_updown,
transformer_dim=transformer_dim, transformer_dim=transformer_dim,
) )
self.register( self.register_to_config(
in_channels=in_channels, in_channels=in_channels,
resolution=resolution, resolution=resolution,
model_channels=model_channels, model_channels=model_channels,
@ -780,7 +780,7 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown, resblock_updown=resblock_updown,
) )
self.register( self.register_to_config(
in_channels=in_channels, in_channels=in_channels,
resolution=resolution, resolution=resolution,
model_channels=model_channels, model_channels=model_channels,

View File

@ -126,7 +126,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000): def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
super(UNetGradTTSModel, self).__init__() super(UNetGradTTSModel, self).__init__()
self.register( self.register_to_config(
dim=dim, dim=dim,
dim_mults=dim_mults, dim_mults=dim_mults,
groups=groups, groups=groups,

View File

@ -746,7 +746,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
super().__init__() super().__init__()
# register all __init__ params with self.register # register all __init__ params with self.register
self.register( self.register_to_config(
image_size=image_size, image_size=image_size,
in_channels=in_channels, in_channels=in_channels,
model_channels=model_channels, model_channels=model_channels,

View File

@ -77,13 +77,13 @@ class DiffusionPipeline(ConfigMixin):
register_dict = {name: (library, class_name)} register_dict = {name: (library, class_name)}
# save model index config # save model index config
self.register(**register_dict) self.register_to_config(**register_dict)
# set models # set models
setattr(self, name, module) setattr(self, name, module)
register_dict = {"_module": self.__module__.split(".")[-1]} register_dict = {"_module": self.__module__.split(".")[-1]}
self.register(**register_dict) self.register_to_config(**register_dict)
def save_pretrained(self, save_directory: Union[str, os.PathLike]): def save_pretrained(self, save_directory: Union[str, os.PathLike]):
self.save_config(save_directory) self.save_config(save_directory)

View File

@ -655,7 +655,7 @@ class VQModel(ModelMixin, ConfigMixin):
super().__init__() super().__init__()
# register all __init__ params with self.register # register all __init__ params with self.register
self.register( self.register_to_config(
ch=ch, ch=ch,
out_ch=out_ch, out_ch=out_ch,
num_res_blocks=num_res_blocks, num_res_blocks=num_res_blocks,
@ -786,7 +786,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
super().__init__() super().__init__()
# register all __init__ params with self.register # register all __init__ params with self.register
self.register( self.register_to_config(
ch=ch, ch=ch,
out_ch=out_ch, out_ch=out_ch,
num_res_blocks=num_res_blocks, num_res_blocks=num_res_blocks,

View File

@ -232,7 +232,7 @@ class DiffWave(ModelMixin, ConfigMixin):
super().__init__() super().__init__()
# register all init arguments with self.register # register all init arguments with self.register
self.register( self.register_to_config(
in_channels=in_channels, in_channels=in_channels,
res_channels=res_channels, res_channels=res_channels,
skip_channels=skip_channels, skip_channels=skip_channels,

View File

@ -355,7 +355,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
): ):
super(TextEncoder, self).__init__() super(TextEncoder, self).__init__()
self.register( self.register_to_config(
n_vocab=n_vocab, n_vocab=n_vocab,
n_feats=n_feats, n_feats=n_feats,
n_channels=n_channels, n_channels=n_channels,

View File

@ -656,7 +656,7 @@ class VQModel(ModelMixin, ConfigMixin):
super().__init__() super().__init__()
# register all __init__ params with self.register # register all __init__ params with self.register
self.register( self.register_to_config(
ch=ch, ch=ch,
out_ch=out_ch, out_ch=out_ch,
num_res_blocks=num_res_blocks, num_res_blocks=num_res_blocks,
@ -787,7 +787,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
super().__init__() super().__init__()
# register all __init__ params with self.register # register all __init__ params with self.register
self.register( self.register_to_config(
ch=ch, ch=ch,
out_ch=out_ch, out_ch=out_ch,
num_res_blocks=num_res_blocks, num_res_blocks=num_res_blocks,

View File

@ -57,7 +57,7 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
beta_schedule="squaredcos_cap_v2", beta_schedule="squaredcos_cap_v2",
): ):
super().__init__() super().__init__()
self.register( self.register_to_config(
timesteps=timesteps, timesteps=timesteps,
beta_schedule=beta_schedule, beta_schedule=beta_schedule,
) )

View File

@ -32,7 +32,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
tensor_format="np", tensor_format="np",
): ):
super().__init__() super().__init__()
self.register( self.register_to_config(
timesteps=timesteps, timesteps=timesteps,
beta_start=beta_start, beta_start=beta_start,
beta_end=beta_end, beta_end=beta_end,

View File

@ -33,7 +33,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
tensor_format="np", tensor_format="np",
): ):
super().__init__() super().__init__()
self.register( self.register_to_config(
timesteps=timesteps, timesteps=timesteps,
beta_start=beta_start, beta_start=beta_start,
beta_end=beta_end, beta_end=beta_end,

View File

@ -25,7 +25,7 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
tensor_format="np", tensor_format="np",
): ):
super().__init__() super().__init__()
self.register( self.register_to_config(
timesteps=timesteps, timesteps=timesteps,
beta_start=beta_start, beta_start=beta_start,
beta_end=beta_end, beta_end=beta_end,

View File

@ -29,7 +29,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
tensor_format="np", tensor_format="np",
): ):
super().__init__() super().__init__()
self.register( self.register_to_config(
timesteps=timesteps, timesteps=timesteps,
beta_start=beta_start, beta_start=beta_start,
beta_end=beta_end, beta_end=beta_end,

View File

@ -57,7 +57,7 @@ class ConfigTester(unittest.TestCase):
d="for diffusion", d="for diffusion",
e=[1, 3], e=[1, 3],
): ):
self.register(a=a, b=b, c=c, d=d, e=e) self.register_to_config(a=a, b=b, c=c, d=d, e=e)
obj = SampleObject() obj = SampleObject()
config = obj.config config = obj.config