[Type Hint] Unet Models (#330)
* add void check * remove void, add types for params
This commit is contained in:
parent
9b704f7688
commit
b1fe170642
|
@ -1,4 +1,4 @@
|
|||
from typing import Dict, Union
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -13,23 +13,23 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size=None,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
center_input_sample=False,
|
||||
time_embedding_type="positional",
|
||||
freq_shift=0,
|
||||
flip_sin_to_cos=True,
|
||||
down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
|
||||
block_out_channels=(224, 448, 672, 896),
|
||||
layers_per_block=2,
|
||||
mid_block_scale_factor=1,
|
||||
downsample_padding=1,
|
||||
act_fn="silu",
|
||||
attention_head_dim=8,
|
||||
norm_num_groups=32,
|
||||
norm_eps=1e-5,
|
||||
sample_size: Optional[int] = None,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
center_input_sample: bool = False,
|
||||
time_embedding_type: str = "positional",
|
||||
freq_shift: int = 0,
|
||||
flip_sin_to_cos: bool = True,
|
||||
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
|
||||
block_out_channels: Tuple[int] = (224, 448, 672, 896),
|
||||
layers_per_block: int = 2,
|
||||
mid_block_scale_factor: float = 1,
|
||||
downsample_padding: int = 1,
|
||||
act_fn: str = "silu",
|
||||
attention_head_dim: int = 8,
|
||||
norm_num_groups: int = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Dict, Union
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -13,23 +13,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
|||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size=None,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
center_input_sample=False,
|
||||
flip_sin_to_cos=True,
|
||||
freq_shift=0,
|
||||
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
|
||||
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
block_out_channels=(320, 640, 1280, 1280),
|
||||
layers_per_block=2,
|
||||
downsample_padding=1,
|
||||
mid_block_scale_factor=1,
|
||||
act_fn="silu",
|
||||
norm_num_groups=32,
|
||||
norm_eps=1e-5,
|
||||
cross_attention_dim=1280,
|
||||
attention_head_dim=8,
|
||||
sample_size: Optional[int] = None,
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 4,
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: int = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
attention_head_dim: int = 8,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
|
Loading…
Reference in New Issue