Add resnet_time_scale_shift to VD layers (#1757)
This commit is contained in:
parent
8890758823
commit
dc7cd893fd
|
@ -33,6 +33,7 @@ def get_down_block(
|
||||||
use_linear_projection=False,
|
use_linear_projection=False,
|
||||||
only_cross_attention=False,
|
only_cross_attention=False,
|
||||||
upcast_attention=False,
|
upcast_attention=False,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
):
|
):
|
||||||
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
||||||
if down_block_type == "DownBlockFlat":
|
if down_block_type == "DownBlockFlat":
|
||||||
|
@ -46,6 +47,7 @@ def get_down_block(
|
||||||
resnet_act_fn=resnet_act_fn,
|
resnet_act_fn=resnet_act_fn,
|
||||||
resnet_groups=resnet_groups,
|
resnet_groups=resnet_groups,
|
||||||
downsample_padding=downsample_padding,
|
downsample_padding=downsample_padding,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
)
|
)
|
||||||
elif down_block_type == "CrossAttnDownBlockFlat":
|
elif down_block_type == "CrossAttnDownBlockFlat":
|
||||||
if cross_attention_dim is None:
|
if cross_attention_dim is None:
|
||||||
|
@ -65,6 +67,7 @@ def get_down_block(
|
||||||
dual_cross_attention=dual_cross_attention,
|
dual_cross_attention=dual_cross_attention,
|
||||||
use_linear_projection=use_linear_projection,
|
use_linear_projection=use_linear_projection,
|
||||||
only_cross_attention=only_cross_attention,
|
only_cross_attention=only_cross_attention,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
)
|
)
|
||||||
raise ValueError(f"{down_block_type} is not supported.")
|
raise ValueError(f"{down_block_type} is not supported.")
|
||||||
|
|
||||||
|
@ -86,6 +89,7 @@ def get_up_block(
|
||||||
use_linear_projection=False,
|
use_linear_projection=False,
|
||||||
only_cross_attention=False,
|
only_cross_attention=False,
|
||||||
upcast_attention=False,
|
upcast_attention=False,
|
||||||
|
resnet_time_scale_shift="default",
|
||||||
):
|
):
|
||||||
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
||||||
if up_block_type == "UpBlockFlat":
|
if up_block_type == "UpBlockFlat":
|
||||||
|
@ -99,6 +103,7 @@ def get_up_block(
|
||||||
resnet_eps=resnet_eps,
|
resnet_eps=resnet_eps,
|
||||||
resnet_act_fn=resnet_act_fn,
|
resnet_act_fn=resnet_act_fn,
|
||||||
resnet_groups=resnet_groups,
|
resnet_groups=resnet_groups,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
)
|
)
|
||||||
elif up_block_type == "CrossAttnUpBlockFlat":
|
elif up_block_type == "CrossAttnUpBlockFlat":
|
||||||
if cross_attention_dim is None:
|
if cross_attention_dim is None:
|
||||||
|
@ -118,6 +123,7 @@ def get_up_block(
|
||||||
dual_cross_attention=dual_cross_attention,
|
dual_cross_attention=dual_cross_attention,
|
||||||
use_linear_projection=use_linear_projection,
|
use_linear_projection=use_linear_projection,
|
||||||
only_cross_attention=only_cross_attention,
|
only_cross_attention=only_cross_attention,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
)
|
)
|
||||||
raise ValueError(f"{up_block_type} is not supported.")
|
raise ValueError(f"{up_block_type} is not supported.")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue