parent
4ff4d4db12
commit
f1b9ee7ed9
|
@ -170,7 +170,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||||
timestep: Union[torch.Tensor, float, int],
|
timestep: Union[torch.Tensor, float, int],
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
) -> Union[UNet2DOutput, Tuple]:
|
) -> Union[UNet2DOutput, Tuple]:
|
||||||
"""r
|
r"""
|
||||||
Args:
|
Args:
|
||||||
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||||
timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
|
timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
|
||||||
|
|
|
@ -215,7 +215,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
train: bool = False,
|
train: bool = False,
|
||||||
) -> Union[FlaxUNet2DConditionOutput, Tuple]:
|
) -> Union[FlaxUNet2DConditionOutput, Tuple]:
|
||||||
"""r
|
r"""
|
||||||
Args:
|
Args:
|
||||||
sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor
|
sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor
|
||||||
timestep (`jnp.ndarray` or `float` or `int`): timesteps
|
timestep (`jnp.ndarray` or `float` or `int`): timesteps
|
||||||
|
|
Loading…
Reference in New Issue