Quick fix to include non-FIR kernels for SDE-VP
This commit is contained in:
parent
11667d08d3
commit
dcb9070bc2
|
@ -237,24 +237,23 @@ class ResnetBlock(nn.Module):
|
|||
elif non_linearity == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
# if up:
|
||||
# self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
|
||||
# self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
|
||||
# elif down:
|
||||
# self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
||||
# self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
||||
|
||||
self.upsample = self.downsample = None
|
||||
if self.up and kernel == "fir":
|
||||
fir_kernel = (1, 3, 3, 1)
|
||||
self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
|
||||
elif self.up and kernel is None:
|
||||
self.upsample = Upsample(in_channels, use_conv=False, dims=2)
|
||||
elif self.down and kernel == "fir":
|
||||
fir_kernel = (1, 3, 3, 1)
|
||||
self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
|
||||
elif self.down and kernel is None:
|
||||
self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
||||
if self.up:
|
||||
if kernel == "fir":
|
||||
fir_kernel = (1, 3, 3, 1)
|
||||
self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
|
||||
elif kernel == "sde_vp":
|
||||
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
|
||||
else:
|
||||
self.upsample = Upsample(in_channels, use_conv=False, dims=2)
|
||||
elif self.down:
|
||||
if kernel == "fir":
|
||||
fir_kernel = (1, 3, 3, 1)
|
||||
self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
|
||||
elif kernel == "sde_vp":
|
||||
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
|
||||
else:
|
||||
self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
||||
|
||||
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
|
||||
|
||||
|
@ -473,87 +472,6 @@ class Block(torch.nn.Module):
|
|||
)
|
||||
|
||||
|
||||
# unet_score_estimation.py
|
||||
class ResnetBlockBigGANpp(nn.Module):
    """Residual block with optional 2x resampling and a time-embedding bias.

    Main path: GroupNorm_0 -> act -> (up/down-sample) -> Conv_0
    -> (+ Dense_0(act(temb))) -> GroupNorm_1 -> act -> Dropout_0 -> Conv_1.
    Skip path: the input is resampled the same way and, when the channel
    count changes or resampling is active, projected by the 1x1 ``Conv_2``.
    The two paths are summed and, if ``skip_rescale`` is truthy, divided by
    sqrt(2).

    Args:
        act: Callable activation applied before each convolution and to
            ``temb``.
        in_ch: Number of input channels.
        out_ch: Number of output channels; a falsy value means ``in_ch``.
        temb_dim: Size of the time embedding. ``Dense_0`` is only created
            when this is not ``None``.
        up: If truthy, upsample both paths (takes precedence over ``down``).
        down: If truthy (and ``up`` is falsy), downsample both paths.
        dropout: Probability passed to ``nn.Dropout``.
        fir: Select ``upsample_2d`` / ``downsample_2d`` with ``fir_kernel``
            instead of nearest interpolation / 2x2 average pooling.
        fir_kernel: Kernel forwarded as ``k`` to the FIR resamplers.
        skip_rescale: Whether to divide the residual sum by sqrt(2).
        init_scale: Forwarded to ``conv2d`` for ``Conv_1``.

    NOTE(review): ``conv2d``, ``variance_scaling``, ``upsample_2d`` and
    ``downsample_2d`` are helpers defined outside this block; their exact
    semantics are not visible here.
    """

    def __init__(
        self,
        act,
        in_ch,
        out_ch=None,
        temb_dim=None,
        up=False,
        down=False,
        dropout=0.1,
        fir=False,
        fir_kernel=(1, 3, 3, 1),
        skip_rescale=True,
        init_scale=0.0,
    ):
        super().__init__()

        # A falsy out_ch (None or 0) falls back to the input width.
        out_ch = out_ch or in_ch

        # Plain (non-module) configuration values.
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.up = up
        self.down = down
        self.fir = fir
        self.fir_kernel = fir_kernel
        self.skip_rescale = skip_rescale

        # Sub-modules are created in a fixed order so that parameter
        # initialisation and state-dict key order stay stable.
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)

        # Resamplers are plain callables, not registered sub-modules.
        if up:
            self.upsample = (
                partial(upsample_2d, k=fir_kernel, factor=2)
                if fir
                else partial(F.interpolate, scale_factor=2.0, mode="nearest")
            )
        elif down:
            self.downsample = (
                partial(downsample_2d, k=fir_kernel, factor=2)
                if fir
                else partial(F.avg_pool2d, kernel_size=2, stride=2)
            )

        self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)

        if temb_dim is not None:
            dense = nn.Linear(temb_dim, out_ch)
            dense.weight.data = variance_scaling()(dense.weight.shape)
            nn.init.zeros_(dense.bias)
            self.Dense_0 = dense

        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv2d(out_ch, out_ch, init_scale=init_scale, kernel_size=3, padding=1)

        if up or down or in_ch != out_ch:
            # 1x1 convolution on the skip path (DDPM initialization per the
            # original comment).
            self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)

        # Assigned last: if `act` is an nn.Module it is registered after the
        # layers above.
        self.act = act

    def forward(self, x, temb=None):
        """Run the residual block.

        Args:
            x: Input tensor fed to ``GroupNorm_0`` and the skip path.
            temb: Optional time embedding; when given, ``Dense_0(act(temb))``
                is broadcast over the two trailing axes and added after
                ``Conv_0``.

        Returns:
            ``x_skip + h``, divided by sqrt(2) when ``skip_rescale`` is
            truthy.
        """
        h = self.act(self.GroupNorm_0(x))

        # Pick the resampler once; `up` wins over `down`.
        if self.up:
            resample = self.upsample
        elif self.down:
            resample = self.downsample
        else:
            resample = None

        if resample is not None:
            h = resample(h)
            x = resample(x)

        h = self.Conv_0(h)

        # Add bias to each feature map conditioned on the time embedding.
        if temb is not None:
            temb_bias = self.Dense_0(self.act(temb))
            h += temb_bias[:, :, None, None]

        h = self.Conv_1(self.Dropout_0(self.act(self.GroupNorm_1(h))))

        # The skip path needs a projection whenever its shape differs from h.
        if self.up or self.down or self.in_ch != self.out_ch:
            x = self.Conv_2(x)

        out = x + h
        if self.skip_rescale:
            out = out / np.sqrt(2.0)
        return out
|
||||
|
||||
|
||||
# unet_rl.py
|
||||
class ResidualTemporalBlock(nn.Module):
|
||||
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
|
||||
|
|
|
@ -373,7 +373,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
down=True,
|
||||
kernel="fir", # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
|
||||
kernel="fir" if self.fir else "sde_vp",
|
||||
use_nin_shortcut=True,
|
||||
)
|
||||
)
|
||||
|
@ -473,7 +473,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
up=True,
|
||||
kernel="fir", # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
|
||||
kernel="fir" if self.fir else "sde_vp",
|
||||
use_nin_shortcut=True,
|
||||
)
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue