remove if fir from resent block and upsample, downsample for sde unet
This commit is contained in:
parent
7e0fd19ffe
commit
c9bd4d4338
|
@ -614,19 +614,11 @@ class ResnetBlockBigGANpp(nn.Module):
|
|||
h = self.act(self.GroupNorm_0(x))
|
||||
|
||||
if self.up:
|
||||
if self.fir:
|
||||
h = upsample_2d(h, self.fir_kernel, factor=2)
|
||||
x = upsample_2d(x, self.fir_kernel, factor=2)
|
||||
else:
|
||||
h = naive_upsample_2d(h, factor=2)
|
||||
x = naive_upsample_2d(x, factor=2)
|
||||
h = upsample_2d(h, self.fir_kernel, factor=2)
|
||||
x = upsample_2d(x, self.fir_kernel, factor=2)
|
||||
elif self.down:
|
||||
if self.fir:
|
||||
h = downsample_2d(h, self.fir_kernel, factor=2)
|
||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||
else:
|
||||
h = naive_downsample_2d(h, factor=2)
|
||||
x = naive_downsample_2d(x, factor=2)
|
||||
h = downsample_2d(h, self.fir_kernel, factor=2)
|
||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||
|
||||
h = self.Conv_0(h)
|
||||
# Add bias to each feature map conditioned on the time embedding
|
||||
|
|
|
@ -417,20 +417,16 @@ class Upsample(nn.Module):
|
|||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
|
||||
super().__init__()
|
||||
out_ch = out_ch if out_ch else in_ch
|
||||
if not fir:
|
||||
if with_conv:
|
||||
self.Conv_0 = conv3x3(in_ch, out_ch)
|
||||
else:
|
||||
if with_conv:
|
||||
self.Conv2d_0 = Conv2d(
|
||||
in_ch,
|
||||
out_ch,
|
||||
kernel=3,
|
||||
up=True,
|
||||
resample_kernel=fir_kernel,
|
||||
use_bias=True,
|
||||
kernel_init=default_init(),
|
||||
)
|
||||
if with_conv:
|
||||
self.Conv2d_0 = Conv2d(
|
||||
in_ch,
|
||||
out_ch,
|
||||
kernel=3,
|
||||
up=True,
|
||||
resample_kernel=fir_kernel,
|
||||
use_bias=True,
|
||||
kernel_init=default_init(),
|
||||
)
|
||||
self.fir = fir
|
||||
self.with_conv = with_conv
|
||||
self.fir_kernel = fir_kernel
|
||||
|
@ -438,15 +434,10 @@ class Upsample(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
if not self.fir:
|
||||
h = F.interpolate(x, (H * 2, W * 2), "nearest")
|
||||
if self.with_conv:
|
||||
h = self.Conv_0(h)
|
||||
if not self.with_conv:
|
||||
h = upsample_2d(x, self.fir_kernel, factor=2)
|
||||
else:
|
||||
if not self.with_conv:
|
||||
h = upsample_2d(x, self.fir_kernel, factor=2)
|
||||
else:
|
||||
h = self.Conv2d_0(x)
|
||||
h = self.Conv2d_0(x)
|
||||
|
||||
return h
|
||||
|
||||
|
@ -455,20 +446,16 @@ class Downsample(nn.Module):
|
|||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
|
||||
super().__init__()
|
||||
out_ch = out_ch if out_ch else in_ch
|
||||
if not fir:
|
||||
if with_conv:
|
||||
self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
|
||||
else:
|
||||
if with_conv:
|
||||
self.Conv2d_0 = Conv2d(
|
||||
in_ch,
|
||||
out_ch,
|
||||
kernel=3,
|
||||
down=True,
|
||||
resample_kernel=fir_kernel,
|
||||
use_bias=True,
|
||||
kernel_init=default_init(),
|
||||
)
|
||||
if with_conv:
|
||||
self.Conv2d_0 = Conv2d(
|
||||
in_ch,
|
||||
out_ch,
|
||||
kernel=3,
|
||||
down=True,
|
||||
resample_kernel=fir_kernel,
|
||||
use_bias=True,
|
||||
kernel_init=default_init(),
|
||||
)
|
||||
self.fir = fir
|
||||
self.fir_kernel = fir_kernel
|
||||
self.with_conv = with_conv
|
||||
|
@ -476,17 +463,10 @@ class Downsample(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
if not self.fir:
|
||||
if self.with_conv:
|
||||
x = F.pad(x, (0, 1, 0, 1))
|
||||
x = self.Conv_0(x)
|
||||
else:
|
||||
x = F.avg_pool2d(x, 2, stride=2)
|
||||
if not self.with_conv:
|
||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||
else:
|
||||
if not self.with_conv:
|
||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||
else:
|
||||
x = self.Conv2d_0(x)
|
||||
x = self.Conv2d_0(x)
|
||||
|
||||
return x
|
||||
|
||||
|
|
Loading…
Reference in New Issue