parent
4ff4d4db12
commit
f1b9ee7ed9
|
@ -170,7 +170,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|||
timestep: Union[torch.Tensor, float, int],
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DOutput, Tuple]:
|
||||
"""r
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||
timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
|
||||
|
|
|
@ -215,7 +215,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
|||
return_dict: bool = True,
|
||||
train: bool = False,
|
||||
) -> Union[FlaxUNet2DConditionOutput, Tuple]:
|
||||
"""r
|
||||
r"""
|
||||
Args:
|
||||
sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor
|
||||
timestep (`jnp.ndarray` or `float` or `int`): timesteps
|
||||
|
|
Loading…
Reference in New Issue