Allow `UNet2DModel` to use arbitrary class embeddings (#2080)
* Allow `UNet2DModel` to use arbitrary class embeddings. We can currently use class conditioning in `UNet2DConditionModel`, but not in `UNet2DModel`. However, `UNet2DConditionModel` requires text conditioning too, which is unrelated to other types of conditioning. This commit makes it possible for `UNet2DModel` to be conditioned on entities other than timesteps. This is useful for training / research purposes. We can currently train models to perform unconditional image generation or text-to-image generation, but it's not straightforward to train a model to perform class-conditioned image generation, if text conditioning is not required. We could potentially use `UNet2DConditionModel` for class-conditioning without text embeddings by using down/up blocks without cross-conditioning. However: - The mid block currently requires cross attention. - We are required to provide `encoder_hidden_states` to `forward`. * Style * Align class conditioning, add docstring for `num_class_embeds`. * Copy docstring to versatile_diffusion UNetFlatConditionModel
This commit is contained in:
parent
0856137337
commit
915a563611
|
@ -70,6 +70,11 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|||
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
||||
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to None):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
|
@ -94,6 +99,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|||
norm_eps: float = 1e-5,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
add_attention: bool = True,
|
||||
class_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -113,6 +120,16 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|||
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
@ -190,12 +207,15 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DOutput, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
||||
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
||||
|
||||
|
@ -225,6 +245,16 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when doing class conditioning")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
# 2. pre-process
|
||||
skip_sample = sample
|
||||
sample = self.conv_in(sample)
|
||||
|
|
|
@ -87,6 +87,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
|||
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
||||
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to None):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
|
|
@ -168,6 +168,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|||
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
||||
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to None):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
|
Loading…
Reference in New Issue