FlaxUNet2DConditionOutput @flax.struct.dataclass (#550)

This commit is contained in:
Mishig Davaadorj 2022-09-18 19:35:37 +02:00 committed by GitHub
parent d09bbae515
commit 0c0c222432
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -1,6 +1,6 @@
from dataclasses import dataclass
from typing import Tuple, Union
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
@ -19,7 +19,7 @@ from .unet_blocks_flax import (
)
@dataclass
@flax.struct.dataclass
class FlaxUNet2DConditionOutput(BaseOutput):
"""
Args: