remove if fir from resent block and upsample, downsample for sde unet
This commit is contained in:
parent
7e0fd19ffe
commit
c9bd4d4338
|
@ -614,19 +614,11 @@ class ResnetBlockBigGANpp(nn.Module):
|
||||||
h = self.act(self.GroupNorm_0(x))
|
h = self.act(self.GroupNorm_0(x))
|
||||||
|
|
||||||
if self.up:
|
if self.up:
|
||||||
if self.fir:
|
h = upsample_2d(h, self.fir_kernel, factor=2)
|
||||||
h = upsample_2d(h, self.fir_kernel, factor=2)
|
x = upsample_2d(x, self.fir_kernel, factor=2)
|
||||||
x = upsample_2d(x, self.fir_kernel, factor=2)
|
|
||||||
else:
|
|
||||||
h = naive_upsample_2d(h, factor=2)
|
|
||||||
x = naive_upsample_2d(x, factor=2)
|
|
||||||
elif self.down:
|
elif self.down:
|
||||||
if self.fir:
|
h = downsample_2d(h, self.fir_kernel, factor=2)
|
||||||
h = downsample_2d(h, self.fir_kernel, factor=2)
|
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
|
||||||
else:
|
|
||||||
h = naive_downsample_2d(h, factor=2)
|
|
||||||
x = naive_downsample_2d(x, factor=2)
|
|
||||||
|
|
||||||
h = self.Conv_0(h)
|
h = self.Conv_0(h)
|
||||||
# Add bias to each feature map conditioned on the time embedding
|
# Add bias to each feature map conditioned on the time embedding
|
||||||
|
|
|
@ -417,20 +417,16 @@ class Upsample(nn.Module):
|
||||||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
|
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
out_ch = out_ch if out_ch else in_ch
|
out_ch = out_ch if out_ch else in_ch
|
||||||
if not fir:
|
if with_conv:
|
||||||
if with_conv:
|
self.Conv2d_0 = Conv2d(
|
||||||
self.Conv_0 = conv3x3(in_ch, out_ch)
|
in_ch,
|
||||||
else:
|
out_ch,
|
||||||
if with_conv:
|
kernel=3,
|
||||||
self.Conv2d_0 = Conv2d(
|
up=True,
|
||||||
in_ch,
|
resample_kernel=fir_kernel,
|
||||||
out_ch,
|
use_bias=True,
|
||||||
kernel=3,
|
kernel_init=default_init(),
|
||||||
up=True,
|
)
|
||||||
resample_kernel=fir_kernel,
|
|
||||||
use_bias=True,
|
|
||||||
kernel_init=default_init(),
|
|
||||||
)
|
|
||||||
self.fir = fir
|
self.fir = fir
|
||||||
self.with_conv = with_conv
|
self.with_conv = with_conv
|
||||||
self.fir_kernel = fir_kernel
|
self.fir_kernel = fir_kernel
|
||||||
|
@ -438,15 +434,10 @@ class Upsample(nn.Module):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
B, C, H, W = x.shape
|
B, C, H, W = x.shape
|
||||||
if not self.fir:
|
if not self.with_conv:
|
||||||
h = F.interpolate(x, (H * 2, W * 2), "nearest")
|
h = upsample_2d(x, self.fir_kernel, factor=2)
|
||||||
if self.with_conv:
|
|
||||||
h = self.Conv_0(h)
|
|
||||||
else:
|
else:
|
||||||
if not self.with_conv:
|
h = self.Conv2d_0(x)
|
||||||
h = upsample_2d(x, self.fir_kernel, factor=2)
|
|
||||||
else:
|
|
||||||
h = self.Conv2d_0(x)
|
|
||||||
|
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
@ -455,20 +446,16 @@ class Downsample(nn.Module):
|
||||||
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
|
def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
out_ch = out_ch if out_ch else in_ch
|
out_ch = out_ch if out_ch else in_ch
|
||||||
if not fir:
|
if with_conv:
|
||||||
if with_conv:
|
self.Conv2d_0 = Conv2d(
|
||||||
self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
|
in_ch,
|
||||||
else:
|
out_ch,
|
||||||
if with_conv:
|
kernel=3,
|
||||||
self.Conv2d_0 = Conv2d(
|
down=True,
|
||||||
in_ch,
|
resample_kernel=fir_kernel,
|
||||||
out_ch,
|
use_bias=True,
|
||||||
kernel=3,
|
kernel_init=default_init(),
|
||||||
down=True,
|
)
|
||||||
resample_kernel=fir_kernel,
|
|
||||||
use_bias=True,
|
|
||||||
kernel_init=default_init(),
|
|
||||||
)
|
|
||||||
self.fir = fir
|
self.fir = fir
|
||||||
self.fir_kernel = fir_kernel
|
self.fir_kernel = fir_kernel
|
||||||
self.with_conv = with_conv
|
self.with_conv = with_conv
|
||||||
|
@ -476,17 +463,10 @@ class Downsample(nn.Module):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
B, C, H, W = x.shape
|
B, C, H, W = x.shape
|
||||||
if not self.fir:
|
if not self.with_conv:
|
||||||
if self.with_conv:
|
x = downsample_2d(x, self.fir_kernel, factor=2)
|
||||||
x = F.pad(x, (0, 1, 0, 1))
|
|
||||||
x = self.Conv_0(x)
|
|
||||||
else:
|
|
||||||
x = F.avg_pool2d(x, 2, stride=2)
|
|
||||||
else:
|
else:
|
||||||
if not self.with_conv:
|
x = self.Conv2d_0(x)
|
||||||
x = downsample_2d(x, self.fir_kernel, factor=2)
|
|
||||||
else:
|
|
||||||
x = self.Conv2d_0(x)
|
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue