finish resnet
This commit is contained in:
parent
8d7771d8b0
commit
fa7443c899
|
@ -380,7 +380,7 @@ class ResnetBlock(nn.Module):
|
|||
eps=1e-6,
|
||||
non_linearity="swish",
|
||||
time_embedding_norm="default",
|
||||
fir_kernel=(1, 3, 3, 1),
|
||||
kernel=None,
|
||||
output_scale_factor=1.0,
|
||||
use_nin_shortcut=None,
|
||||
up=False,
|
||||
|
@ -433,8 +433,18 @@ class ResnetBlock(nn.Module):
|
|||
# elif down:
|
||||
# self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
||||
# self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
||||
self.upsample = Upsample(in_channels, use_conv=False, dims=2) if self.up else None
|
||||
self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") if self.down else None
|
||||
|
||||
self.upsample = self.downsample = None
|
||||
if self.up and kernel == "fir":
|
||||
fir_kernel = (1, 3, 3, 1)
|
||||
self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
|
||||
elif self.up and kernel is None:
|
||||
self.upsample = Upsample(in_channels, use_conv=False, dims=2)
|
||||
elif self.down and kernel == "fir":
|
||||
fir_kernel = (1, 3, 3, 1)
|
||||
self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
|
||||
elif self.down and kernel is None:
|
||||
self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
||||
|
||||
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
|
||||
|
||||
|
@ -505,8 +515,6 @@ class ResnetBlock(nn.Module):
|
|||
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
|
||||
self.up = up
|
||||
self.down = down
|
||||
self.fir_kernel = fir_kernel
|
||||
|
||||
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
|
||||
if temb_dim is not None:
|
||||
self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
||||
|
@ -525,11 +533,6 @@ class ResnetBlock(nn.Module):
|
|||
self.out_ch = out_ch
|
||||
|
||||
# TODO(Patrick) - move to main init
|
||||
if self.up:
|
||||
self.upsample = functools.partial(upsample_2d, k=self.fir_kernel)
|
||||
if self.down:
|
||||
self.downsample = functools.partial(downsample_2d, k=self.fir_kernel)
|
||||
|
||||
self.is_overwritten = False
|
||||
|
||||
def set_weights_grad_tts(self):
|
||||
|
|
|
@ -348,7 +348,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
for i_block in range(num_res_blocks):
|
||||
out_ch = nf * ch_mult[i_level]
|
||||
# modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
|
||||
modules.append(ResnetNew(
|
||||
modules.append(
|
||||
ResnetNew(
|
||||
in_channels=in_ch,
|
||||
out_channels=out_ch,
|
||||
temb_channels=4 * nf,
|
||||
|
@ -357,7 +358,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
))
|
||||
)
|
||||
)
|
||||
in_ch = out_ch
|
||||
|
||||
if all_resolutions[i_level] in attn_resolutions:
|
||||
|
@ -365,7 +367,21 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
hs_c.append(in_ch)
|
||||
|
||||
if i_level != self.num_resolutions - 1:
|
||||
modules.append(ResnetBlock(down=True, in_ch=in_ch))
|
||||
# modules.append(ResnetBlock(down=True, in_ch=in_ch))
|
||||
modules.append(
|
||||
ResnetNew(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
down=True,
|
||||
kernel="fir", # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
|
||||
use_nin_shortcut=True,
|
||||
)
|
||||
)
|
||||
|
||||
if progressive_input == "input_skip":
|
||||
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
|
||||
|
@ -379,16 +395,50 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
hs_c.append(in_ch)
|
||||
|
||||
in_ch = hs_c[-1]
|
||||
modules.append(ResnetBlock(in_ch=in_ch))
|
||||
# modules.append(ResnetBlock(in_ch=in_ch))
|
||||
modules.append(
|
||||
ResnetNew(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
)
|
||||
)
|
||||
modules.append(AttnBlock(channels=in_ch))
|
||||
modules.append(ResnetBlock(in_ch=in_ch))
|
||||
# modules.append(ResnetBlock(in_ch=in_ch))
|
||||
modules.append(
|
||||
ResnetNew(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
)
|
||||
)
|
||||
|
||||
pyramid_ch = 0
|
||||
# Upsampling block
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(num_res_blocks + 1):
|
||||
out_ch = nf * ch_mult[i_level]
|
||||
modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
|
||||
# modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
|
||||
modules.append(
|
||||
ResnetNew(
|
||||
in_channels=in_ch + hs_c.pop(),
|
||||
out_channels=out_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
)
|
||||
)
|
||||
in_ch = out_ch
|
||||
|
||||
if all_resolutions[i_level] in attn_resolutions:
|
||||
|
@ -420,7 +470,21 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
raise ValueError(f"{progressive} is not a valid name")
|
||||
|
||||
if i_level != 0:
|
||||
modules.append(ResnetBlock(in_ch=in_ch, up=True))
|
||||
# modules.append(ResnetBlock(in_ch=in_ch, up=True))
|
||||
modules.append(
|
||||
ResnetNew(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
up=True,
|
||||
kernel="fir",
|
||||
use_nin_shortcut=True,
|
||||
)
|
||||
)
|
||||
|
||||
assert not hs_c
|
||||
|
||||
|
|
Loading…
Reference in New Issue