finish resnet

This commit is contained in:
Patrick von Platen 2022-07-01 15:39:57 +00:00
parent 8d7771d8b0
commit fa7443c899
2 changed files with 92 additions and 25 deletions

View File

@ -380,7 +380,7 @@ class ResnetBlock(nn.Module):
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
fir_kernel=(1, 3, 3, 1),
kernel=None,
output_scale_factor=1.0,
use_nin_shortcut=None,
up=False,
@ -433,8 +433,18 @@ class ResnetBlock(nn.Module):
# elif down:
# self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
# self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.upsample = Upsample(in_channels, use_conv=False, dims=2) if self.up else None
self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") if self.down else None
self.upsample = self.downsample = None
if self.up and kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
elif self.up and kernel is None:
self.upsample = Upsample(in_channels, use_conv=False, dims=2)
elif self.down and kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
elif self.down and kernel is None:
self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
@ -505,8 +515,6 @@ class ResnetBlock(nn.Module):
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
self.up = up
self.down = down
self.fir_kernel = fir_kernel
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
if temb_dim is not None:
self.Dense_0 = nn.Linear(temb_dim, out_ch)
@ -525,11 +533,6 @@ class ResnetBlock(nn.Module):
self.out_ch = out_ch
# TODO(Patrick) - move to main init
if self.up:
self.upsample = functools.partial(upsample_2d, k=self.fir_kernel)
if self.down:
self.downsample = functools.partial(downsample_2d, k=self.fir_kernel)
self.is_overwritten = False
def set_weights_grad_tts(self):

View File

@ -348,7 +348,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
for i_block in range(num_res_blocks):
out_ch = nf * ch_mult[i_level]
# modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
modules.append(ResnetNew(
modules.append(
ResnetNew(
in_channels=in_ch,
out_channels=out_ch,
temb_channels=4 * nf,
@ -357,7 +358,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
))
)
)
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
@ -365,7 +367,21 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs_c.append(in_ch)
if i_level != self.num_resolutions - 1:
modules.append(ResnetBlock(down=True, in_ch=in_ch))
# modules.append(ResnetBlock(down=True, in_ch=in_ch))
modules.append(
ResnetNew(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
down=True,
kernel="fir", # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
use_nin_shortcut=True,
)
)
if progressive_input == "input_skip":
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
@ -379,16 +395,50 @@ class NCSNpp(ModelMixin, ConfigMixin):
hs_c.append(in_ch)
in_ch = hs_c[-1]
modules.append(ResnetBlock(in_ch=in_ch))
# modules.append(ResnetBlock(in_ch=in_ch))
modules.append(
ResnetNew(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
modules.append(AttnBlock(channels=in_ch))
modules.append(ResnetBlock(in_ch=in_ch))
# modules.append(ResnetBlock(in_ch=in_ch))
modules.append(
ResnetNew(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
pyramid_ch = 0
# Upsampling block
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(num_res_blocks + 1):
out_ch = nf * ch_mult[i_level]
modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
# modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
modules.append(
ResnetNew(
in_channels=in_ch + hs_c.pop(),
out_channels=out_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
)
)
in_ch = out_ch
if all_resolutions[i_level] in attn_resolutions:
@ -420,7 +470,21 @@ class NCSNpp(ModelMixin, ConfigMixin):
raise ValueError(f"{progressive} is not a valid name")
if i_level != 0:
modules.append(ResnetBlock(in_ch=in_ch, up=True))
# modules.append(ResnetBlock(in_ch=in_ch, up=True))
modules.append(
ResnetNew(
in_channels=in_ch,
temb_channels=4 * nf,
output_scale_factor=np.sqrt(2.0),
non_linearity="silu",
groups=min(in_ch // 4, 32),
groups_out=min(out_ch // 4, 32),
overwrite_for_score_vde=True,
up=True,
kernel="fir",
use_nin_shortcut=True,
)
)
assert not hs_c