fix the parameter naming in `self.downsamplers` (#1108)

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Chenguo Lin 2022-11-05 01:05:19 +08:00 committed by GitHub
parent 2c108693cc
commit 5b20d3b3d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 7 deletions

View File

@ -462,7 +462,7 @@ class AttnDownBlock2D(nn.Module):
self.downsamplers = nn.ModuleList(
[
Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
@ -546,7 +546,7 @@ class CrossAttnDownBlock2D(nn.Module):
self.downsamplers = nn.ModuleList(
[
Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
@ -651,7 +651,7 @@ class DownBlock2D(nn.Module):
self.downsamplers = nn.ModuleList(
[
Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
@ -729,7 +729,7 @@ class DownEncoderBlock2D(nn.Module):
self.downsamplers = nn.ModuleList(
[
Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
@ -801,7 +801,7 @@ class AttnDownEncoderBlock2D(nn.Module):
self.downsamplers = nn.ModuleList(
[
Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
@ -886,7 +886,7 @@ class AttnSkipDownBlock2D(nn.Module):
down=True,
kernel="fir",
)
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else:
self.resnet_down = None
@ -966,7 +966,7 @@ class SkipDownBlock2D(nn.Module):
down=True,
kernel="fir",
)
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
else:
self.resnet_down = None