diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 1ac68e10..636a7ef9 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass from typing import Tuple, Union +import flax import flax.linen as nn import jax import jax.numpy as jnp @@ -19,7 +19,7 @@ from .unet_blocks_flax import ( ) -@dataclass +@flax.struct.dataclass class FlaxUNet2DConditionOutput(BaseOutput): """ Args: