remove if fir from resent block and upsample, downsample for sde unet
This commit is contained in:
parent
7e0fd19ffe
commit
c9bd4d4338
|
@ -614,19 +614,11 @@ class ResnetBlockBigGANpp(nn.Module):
|
||||||
h = self.act(self.GroupNorm_0(x))
|
h = self.act(self.GroupNorm_0(x))
|
||||||
|
|
||||||
if self.up:
|
if self.up:
|
||||||
if self.fir:
|
|
||||||
h = upsample_2d(h, self.fir_kernel, factor=2)
|
h = upsample_2d(h, self.fir_kernel, factor=2)
|
||||||
x = upsample_2d(x, self.fir_kernel, factor=2)
|
x = upsample_2d(x, self.fir_kernel, factor=2)
|
||||||
else:
|
|
||||||
h = naive_upsample_2d(h, factor=2)
|
|
||||||
x = naive_upsample_2d(x, factor=2)
|
|
||||||
elif self.down:
|
elif self.down:
|
||||||
if self.fir:
|
|
||||||
h = downsample_2d(h, self.fir_kernel, factor=2)
|
h = downsample_2d(h, self.fir_kernel, factor=2)
|
||||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||||
else:
|
|
||||||
h = naive_downsample_2d(h, factor=2)
|
|
||||||
x = naive_downsample_2d(x, factor=2)
|
|
||||||
|
|
||||||
h = self.Conv_0(h)
|
h = self.Conv_0(h)
|
||||||
# Add bias to each feature map conditioned on the time embedding
|
# Add bias to each feature map conditioned on the time embedding
|
||||||
|
|
|
@ -417,10 +417,6 @@ class Upsample(nn.Module):
|
||||||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
|
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
out_ch = out_ch if out_ch else in_ch
|
out_ch = out_ch if out_ch else in_ch
|
||||||
if not fir:
|
|
||||||
if with_conv:
|
|
||||||
self.Conv_0 = conv3x3(in_ch, out_ch)
|
|
||||||
else:
|
|
||||||
if with_conv:
|
if with_conv:
|
||||||
self.Conv2d_0 = Conv2d(
|
self.Conv2d_0 = Conv2d(
|
||||||
in_ch,
|
in_ch,
|
||||||
|
@ -438,11 +434,6 @@ class Upsample(nn.Module):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
B, C, H, W = x.shape
|
B, C, H, W = x.shape
|
||||||
if not self.fir:
|
|
||||||
h = F.interpolate(x, (H * 2, W * 2), "nearest")
|
|
||||||
if self.with_conv:
|
|
||||||
h = self.Conv_0(h)
|
|
||||||
else:
|
|
||||||
if not self.with_conv:
|
if not self.with_conv:
|
||||||
h = upsample_2d(x, self.fir_kernel, factor=2)
|
h = upsample_2d(x, self.fir_kernel, factor=2)
|
||||||
else:
|
else:
|
||||||
|
@ -455,10 +446,6 @@ class Downsample(nn.Module):
|
||||||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
|
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
out_ch = out_ch if out_ch else in_ch
|
out_ch = out_ch if out_ch else in_ch
|
||||||
if not fir:
|
|
||||||
if with_conv:
|
|
||||||
self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
|
|
||||||
else:
|
|
||||||
if with_conv:
|
if with_conv:
|
||||||
self.Conv2d_0 = Conv2d(
|
self.Conv2d_0 = Conv2d(
|
||||||
in_ch,
|
in_ch,
|
||||||
|
@ -476,13 +463,6 @@ class Downsample(nn.Module):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
B, C, H, W = x.shape
|
B, C, H, W = x.shape
|
||||||
if not self.fir:
|
|
||||||
if self.with_conv:
|
|
||||||
x = F.pad(x, (0, 1, 0, 1))
|
|
||||||
x = self.Conv_0(x)
|
|
||||||
else:
|
|
||||||
x = F.avg_pool2d(x, 2, stride=2)
|
|
||||||
else:
|
|
||||||
if not self.with_conv:
|
if not self.with_conv:
|
||||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue