FlaxUNet2DConditionOutput @flax.struct.dataclass (#550)
This commit is contained in:
parent
d09bbae515
commit
0c0c222432
|
@ -1,6 +1,6 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Tuple, Union
|
||||
|
||||
import flax
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
@ -19,7 +19,7 @@ from .unet_blocks_flax import (
|
|||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@flax.struct.dataclass
|
||||
class FlaxUNet2DConditionOutput(BaseOutput):
|
||||
"""
|
||||
Args:
|
||||
|
|
Loading…
Reference in New Issue