rename register to register_to_config
This commit is contained in:
parent
0ffda1dfcc
commit
5e6f500038
|
@ -50,7 +50,7 @@ class ConfigMixin:
|
|||
"""
|
||||
config_name = None
|
||||
|
||||
def register(self, **kwargs):
|
||||
def register_to_config(self, **kwargs):
|
||||
if self.config_name is None:
|
||||
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
|
||||
kwargs["_class_name"] = self.__class__.__name__
|
||||
|
|
|
@ -188,7 +188,7 @@ class UNetModel(ModelMixin, ConfigMixin):
|
|||
resolution=256,
|
||||
):
|
||||
super().__init__()
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
ch=ch,
|
||||
out_ch=out_ch,
|
||||
ch_mult=ch_mult,
|
||||
|
|
|
@ -689,7 +689,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
|
|||
resblock_updown=resblock_updown,
|
||||
transformer_dim=transformer_dim,
|
||||
)
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
model_channels=model_channels,
|
||||
|
@ -780,7 +780,7 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
|
|||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
resblock_updown=resblock_updown,
|
||||
)
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
model_channels=model_channels,
|
||||
|
|
|
@ -126,7 +126,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
|||
def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
|
||||
super(UNetGradTTSModel, self).__init__()
|
||||
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
dim=dim,
|
||||
dim_mults=dim_mults,
|
||||
groups=groups,
|
||||
|
|
|
@ -746,7 +746,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
|
|||
super().__init__()
|
||||
|
||||
# register all __init__ params with self.register
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
image_size=image_size,
|
||||
in_channels=in_channels,
|
||||
model_channels=model_channels,
|
||||
|
|
|
@ -77,13 +77,13 @@ class DiffusionPipeline(ConfigMixin):
|
|||
register_dict = {name: (library, class_name)}
|
||||
|
||||
# save model index config
|
||||
self.register(**register_dict)
|
||||
self.register_to_config(**register_dict)
|
||||
|
||||
# set models
|
||||
setattr(self, name, module)
|
||||
|
||||
register_dict = {"_module": self.__module__.split(".")[-1]}
|
||||
self.register(**register_dict)
|
||||
self.register_to_config(**register_dict)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||
self.save_config(save_directory)
|
||||
|
|
|
@ -655,7 +655,7 @@ class VQModel(ModelMixin, ConfigMixin):
|
|||
super().__init__()
|
||||
|
||||
# register all __init__ params with self.register
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
ch=ch,
|
||||
out_ch=out_ch,
|
||||
num_res_blocks=num_res_blocks,
|
||||
|
@ -786,7 +786,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
|||
super().__init__()
|
||||
|
||||
# register all __init__ params with self.register
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
ch=ch,
|
||||
out_ch=out_ch,
|
||||
num_res_blocks=num_res_blocks,
|
||||
|
|
|
@ -232,7 +232,7 @@ class DiffWave(ModelMixin, ConfigMixin):
|
|||
super().__init__()
|
||||
|
||||
# register all init arguments with self.register
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
in_channels=in_channels,
|
||||
res_channels=res_channels,
|
||||
skip_channels=skip_channels,
|
||||
|
|
|
@ -355,7 +355,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
|
|||
):
|
||||
super(TextEncoder, self).__init__()
|
||||
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
n_vocab=n_vocab,
|
||||
n_feats=n_feats,
|
||||
n_channels=n_channels,
|
||||
|
|
|
@ -656,7 +656,7 @@ class VQModel(ModelMixin, ConfigMixin):
|
|||
super().__init__()
|
||||
|
||||
# register all __init__ params with self.register
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
ch=ch,
|
||||
out_ch=out_ch,
|
||||
num_res_blocks=num_res_blocks,
|
||||
|
@ -787,7 +787,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
|||
super().__init__()
|
||||
|
||||
# register all __init__ params with self.register
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
ch=ch,
|
||||
out_ch=out_ch,
|
||||
num_res_blocks=num_res_blocks,
|
||||
|
|
|
@ -57,7 +57,7 @@ class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
|
|||
beta_schedule="squaredcos_cap_v2",
|
||||
):
|
||||
super().__init__()
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
timesteps=timesteps,
|
||||
beta_schedule=beta_schedule,
|
||||
)
|
||||
|
|
|
@ -32,7 +32,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
tensor_format="np",
|
||||
):
|
||||
super().__init__()
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
timesteps=timesteps,
|
||||
beta_start=beta_start,
|
||||
beta_end=beta_end,
|
||||
|
|
|
@ -33,7 +33,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
tensor_format="np",
|
||||
):
|
||||
super().__init__()
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
timesteps=timesteps,
|
||||
beta_start=beta_start,
|
||||
beta_end=beta_end,
|
||||
|
|
|
@ -25,7 +25,7 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
|
|||
tensor_format="np",
|
||||
):
|
||||
super().__init__()
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
timesteps=timesteps,
|
||||
beta_start=beta_start,
|
||||
beta_end=beta_end,
|
||||
|
|
|
@ -29,7 +29,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
tensor_format="np",
|
||||
):
|
||||
super().__init__()
|
||||
self.register(
|
||||
self.register_to_config(
|
||||
timesteps=timesteps,
|
||||
beta_start=beta_start,
|
||||
beta_end=beta_end,
|
||||
|
|
|
@ -57,7 +57,7 @@ class ConfigTester(unittest.TestCase):
|
|||
d="for diffusion",
|
||||
e=[1, 3],
|
||||
):
|
||||
self.register(a=a, b=b, c=c, d=d, e=e)
|
||||
self.register_to_config(a=a, b=b, c=c, d=d, e=e)
|
||||
|
||||
obj = SampleObject()
|
||||
config = obj.config
|
||||
|
|
Loading…
Reference in New Issue