From 0c0c222432586d226659f283cb2203cc2f365758 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Sun, 18 Sep 2022 19:35:37 +0200 Subject: [PATCH] FlaxUNet2DConditionOutput @flax.struct.dataclass (#550) --- src/diffusers/models/unet_2d_condition_flax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 1ac68e10..636a7ef9 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass from typing import Tuple, Union +import flax import flax.linen as nn import jax import jax.numpy as jnp @@ -19,7 +19,7 @@ from .unet_blocks_flax import ( ) -@dataclass +@flax.struct.dataclass class FlaxUNet2DConditionOutput(BaseOutput): """ Args: