remove if fir from resent block and upsample, downsample for sde unet

This commit is contained in:
patil-suraj 2022-06-30 11:41:06 +02:00
parent 7e0fd19ffe
commit c9bd4d4338
2 changed files with 30 additions and 58 deletions

View File

@ -614,19 +614,11 @@ class ResnetBlockBigGANpp(nn.Module):
h = self.act(self.GroupNorm_0(x))
if self.up:
if self.fir:
h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2)
else:
h = naive_upsample_2d(h, factor=2)
x = naive_upsample_2d(x, factor=2)
h = upsample_2d(h, self.fir_kernel, factor=2)
x = upsample_2d(x, self.fir_kernel, factor=2)
elif self.down:
if self.fir:
h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2)
else:
h = naive_downsample_2d(h, factor=2)
x = naive_downsample_2d(x, factor=2)
h = downsample_2d(h, self.fir_kernel, factor=2)
x = downsample_2d(x, self.fir_kernel, factor=2)
h = self.Conv_0(h)
# Add bias to each feature map conditioned on the time embedding

View File

@ -417,20 +417,16 @@ class Upsample(nn.Module):
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_ch = out_ch if out_ch else in_ch
if not fir:
if with_conv:
self.Conv_0 = conv3x3(in_ch, out_ch)
else:
if with_conv:
self.Conv2d_0 = Conv2d(
in_ch,
out_ch,
kernel=3,
up=True,
resample_kernel=fir_kernel,
use_bias=True,
kernel_init=default_init(),
)
if with_conv:
self.Conv2d_0 = Conv2d(
in_ch,
out_ch,
kernel=3,
up=True,
resample_kernel=fir_kernel,
use_bias=True,
kernel_init=default_init(),
)
self.fir = fir
self.with_conv = with_conv
self.fir_kernel = fir_kernel
@ -438,15 +434,10 @@ class Upsample(nn.Module):
def forward(self, x):
B, C, H, W = x.shape
if not self.fir:
h = F.interpolate(x, (H * 2, W * 2), "nearest")
if self.with_conv:
h = self.Conv_0(h)
if not self.with_conv:
h = upsample_2d(x, self.fir_kernel, factor=2)
else:
if not self.with_conv:
h = upsample_2d(x, self.fir_kernel, factor=2)
else:
h = self.Conv2d_0(x)
h = self.Conv2d_0(x)
return h
@ -455,20 +446,16 @@ class Downsample(nn.Module):
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_ch = out_ch if out_ch else in_ch
if not fir:
if with_conv:
self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
else:
if with_conv:
self.Conv2d_0 = Conv2d(
in_ch,
out_ch,
kernel=3,
down=True,
resample_kernel=fir_kernel,
use_bias=True,
kernel_init=default_init(),
)
if with_conv:
self.Conv2d_0 = Conv2d(
in_ch,
out_ch,
kernel=3,
down=True,
resample_kernel=fir_kernel,
use_bias=True,
kernel_init=default_init(),
)
self.fir = fir
self.fir_kernel = fir_kernel
self.with_conv = with_conv
@ -476,17 +463,10 @@ class Downsample(nn.Module):
def forward(self, x):
B, C, H, W = x.shape
if not self.fir:
if self.with_conv:
x = F.pad(x, (0, 1, 0, 1))
x = self.Conv_0(x)
else:
x = F.avg_pool2d(x, 2, stride=2)
if not self.with_conv:
x = downsample_2d(x, self.fir_kernel, factor=2)
else:
if not self.with_conv:
x = downsample_2d(x, self.fir_kernel, factor=2)
else:
x = self.Conv2d_0(x)
x = self.Conv2d_0(x)
return x