From 915a56361157b02f1429f5cbd60094a83b1b0ff4 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 26 Jan 2023 13:46:32 +0100 Subject: [PATCH] Allow `UNet2DModel` to use arbitrary class embeddings (#2080) * Allow `UNet2DModel` to use arbitrary class embeddings. We can currently use class conditioning in `UNet2DConditionModel`, but not in `UNet2DModel`. However, `UNet2DConditionModel` requires text conditioning too, which is unrelated to other types of conditioning. This commit makes it possible for `UNet2DModel` to be conditioned on entities other than timesteps. This is useful for training / research purposes. We can currently train models to perform unconditional image generation or text-to-image generation, but it's not straightforward to train a model to perform class-conditioned image generation, if text conditioning is not required. We could potentiall use `UNet2DConditionModel` for class-conditioning without text embeddings by using down/up blocks without cross-conditioning. However: - The mid block currently requires cross attention. - We are required to provide `encoder_hidden_states` to `forward`. * Style * Align class conditioning, add docstring for `num_class_embeds`. * Copy docstring to versatile_diffusion UNetFlatConditionModel --- src/diffusers/models/unet_2d.py | 30 +++++++++++++++++++ src/diffusers/models/unet_2d_condition.py | 3 ++ .../versatile_diffusion/modeling_text_unet.py | 3 ++ 3 files changed, 36 insertions(+) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index ccbb218b..35f5dc34 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -70,6 +70,11 @@ class UNet2DModel(ModelMixin, ConfigMixin): norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization. resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately + summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. + num_class_embeds (`int`, *optional*, defaults to None): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. """ @register_to_config @@ -94,6 +99,8 @@ class UNet2DModel(ModelMixin, ConfigMixin): norm_eps: float = 1e-5, resnet_time_scale_shift: str = "default", add_attention: bool = True, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, ): super().__init__() @@ -113,6 +120,16 @@ class UNet2DModel(ModelMixin, ConfigMixin): self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + self.down_blocks = nn.ModuleList([]) self.mid_block = None self.up_blocks = nn.ModuleList([]) @@ -190,12 +207,15 @@ class UNet2DModel(ModelMixin, ConfigMixin): self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], + class_labels: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DOutput, Tuple]: r""" Args: sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + class_labels (`torch.FloatTensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. @@ -225,6 +245,16 @@ class UNet2DModel(ModelMixin, ConfigMixin): t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when doing class conditioning") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + # 2. pre-process skip_sample = sample sample = self.conv_in(sample) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index ee7bb812..bfdb50b7 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -87,6 +87,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. + num_class_embeds (`int`, *optional*, defaults to None): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. """ _supports_gradient_checkpointing = True diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index e52ebcba..dd9988be 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -168,6 +168,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. + num_class_embeds (`int`, *optional*, defaults to None): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. """ _supports_gradient_checkpointing = True