Allow `UNet2DModel` to use arbitrary class embeddings (#2080)

* Allow `UNet2DModel` to use arbitrary class embeddings.

We can currently use class conditioning in `UNet2DConditionModel`, but
not in `UNet2DModel`. However, `UNet2DConditionModel` requires text
conditioning too, which is unrelated to other types of conditioning.
This commit makes it possible for `UNet2DModel` to be conditioned on
entities other than timesteps. This is useful for training /
research purposes. We can currently train models to perform
unconditional image generation or text-to-image generation, but it's not
straightforward to train a model to perform class-conditioned image
generation, if text conditioning is not required.

We could potentiall use `UNet2DConditionModel` for class-conditioning
without text embeddings by using down/up blocks without
cross-conditioning. However:
- The mid block currently requires cross attention.
- We are required to provide `encoder_hidden_states` to `forward`.

* Style

* Align class conditioning, add docstring for `num_class_embeds`.

* Copy docstring to versatile_diffusion UNetFlatConditionModel
This commit is contained in:
Pedro Cuenca 2023-01-26 13:46:32 +01:00 committed by GitHub
parent 0856137337
commit 915a563611
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 0 deletions

View File

@ -70,6 +70,11 @@ class UNet2DModel(ModelMixin, ConfigMixin):
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
"""
@register_to_config
@ -94,6 +99,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
norm_eps: float = 1e-5,
resnet_time_scale_shift: str = "default",
add_attention: bool = True,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
):
super().__init__()
@ -113,6 +120,16 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
else:
self.class_embedding = None
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
@ -190,12 +207,15 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DOutput, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
@ -225,6 +245,16 @@ class UNet2DModel(ModelMixin, ConfigMixin):
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when doing class conditioning")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
skip_sample = sample
sample = self.conv_in(sample)

View File

@ -87,6 +87,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
"""
_supports_gradient_checkpointing = True

View File

@ -168,6 +168,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
"""
_supports_gradient_checkpointing = True