From aa5c4c26092773e87cdd1c7563025f6764940845 Mon Sep 17 00:00:00 2001 From: Kamal Raj Date: Wed, 16 Nov 2022 22:33:44 +0530 Subject: [PATCH] doc string args shape fix (#1243) * doc string args shape fix * fix styling --- src/diffusers/models/unet_2d_condition.py | 3 ++- src/diffusers/models/unet_2d_condition_flax.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index becae756..c3f2fb87 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -251,7 +251,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): Args: sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states + encoder_hidden_states (`torch.FloatTensor`): + (batch_size, sequence_length, hidden_size) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index f0e72182..7ca9c191 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -230,9 +230,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ) -> Union[FlaxUNet2DConditionOutput, Tuple]: r""" Args: - sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor + sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor timestep (`jnp.ndarray` or `float` or `int`): timesteps - encoder_hidden_states (`jnp.ndarray`): (channel, height, width) encoder hidden states + encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a plain tuple.