diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index ccbb218b..35f5dc34 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -70,6 +70,11 @@ class UNet2DModel(ModelMixin, ConfigMixin): norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization. resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately + summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. + num_class_embeds (`int`, *optional*, defaults to None): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. """ @register_to_config @@ -94,6 +99,8 @@ class UNet2DModel(ModelMixin, ConfigMixin): norm_eps: float = 1e-5, resnet_time_scale_shift: str = "default", add_attention: bool = True, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, ): super().__init__() @@ -113,6 +120,16 @@ class UNet2DModel(ModelMixin, ConfigMixin): self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + self.down_blocks = nn.ModuleList([]) self.mid_block = None self.up_blocks = nn.ModuleList([]) @@ -190,12 +207,15 @@ class UNet2DModel(ModelMixin, ConfigMixin): self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], + class_labels: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DOutput, Tuple]: r""" Args: sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + class_labels (`torch.FloatTensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. @@ -225,6 +245,16 @@ class UNet2DModel(ModelMixin, ConfigMixin): t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when doing class conditioning") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + # 2. pre-process skip_sample = sample sample = self.conv_in(sample) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index ee7bb812..bfdb50b7 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -87,6 +87,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. + num_class_embeds (`int`, *optional*, defaults to None): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. """ _supports_gradient_checkpointing = True diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index e52ebcba..dd9988be 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -168,6 +168,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. + num_class_embeds (`int`, *optional*, defaults to None): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. """ _supports_gradient_checkpointing = True