[Type Hint] Unet Models (#330)

* add void check

* remove void, add types for params
This commit is contained in:
Sid Sahai 2022-09-03 03:31:38 -07:00 committed by GitHub
parent 9b704f7688
commit b1fe170642
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 36 deletions

View File

@ -1,4 +1,4 @@
from typing import Dict, Union
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -13,23 +13,23 @@ class UNet2DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
sample_size=None,
in_channels=3,
out_channels=3,
center_input_sample=False,
time_embedding_type="positional",
freq_shift=0,
flip_sin_to_cos=True,
down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels=(224, 448, 672, 896),
layers_per_block=2,
mid_block_scale_factor=1,
downsample_padding=1,
act_fn="silu",
attention_head_dim=8,
norm_num_groups=32,
norm_eps=1e-5,
sample_size: Optional[int] = None,
in_channels: int = 3,
out_channels: int = 3,
center_input_sample: bool = False,
time_embedding_type: str = "positional",
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels: Tuple[int] = (224, 448, 672, 896),
layers_per_block: int = 2,
mid_block_scale_factor: float = 1,
downsample_padding: int = 1,
act_fn: str = "silu",
attention_head_dim: int = 8,
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
):
super().__init__()

View File

@ -1,4 +1,4 @@
from typing import Dict, Union
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -13,23 +13,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
sample_size=None,
in_channels=4,
out_channels=4,
center_input_sample=False,
flip_sin_to_cos=True,
freq_shift=0,
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
block_out_channels=(320, 640, 1280, 1280),
layers_per_block=2,
downsample_padding=1,
mid_block_scale_factor=1,
act_fn="silu",
norm_num_groups=32,
norm_eps=1e-5,
cross_attention_dim=1280,
attention_head_dim=8,
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: int = 8,
):
super().__init__()