make work with first resnet

Patrick von Platen 2022-07-01 15:24:26 +00:00
parent 9da575d63c
commit 8d7771d8b0
2 changed files with 17 additions and 18 deletions

@@ -436,10 +436,9 @@ class ResnetBlock(nn.Module):
         self.upsample = Upsample(in_channels, use_conv=False, dims=2) if self.up else None
         self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") if self.down else None
-        self.nin_shortcut = use_nin_shortcut
-        if self.use_nin_shortcut is None:
-            self.use_nin_shortcut = self.in_channels != self.out_channels
+        self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
+        self.nin_shortcut = None
         if self.use_nin_shortcut:
             self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
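
The refactor above also fixes a latent bug: the old code assigned self.nin_shortcut but then read self.use_nin_shortcut, which was never set. After the change, use_nin_shortcut defaults to "project only when the channel count changes". A minimal standalone sketch of the resulting logic, with hypothetical channel sizes:

import torch

# Hypothetical values; in the real block these come from the constructor arguments.
in_channels, out_channels, use_nin_shortcut = 64, 128, None

# Default: use a 1x1 ("network-in-network") shortcut exactly when channels change.
use_nin_shortcut = in_channels != out_channels if use_nin_shortcut is None else use_nin_shortcut

nin_shortcut = None
if use_nin_shortcut:
    # 1x1 convolution that projects the residual input to the new channel count
    nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
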
@@ -526,8 +525,10 @@ class ResnetBlock(nn.Module):
         self.out_ch = out_ch
         # TODO(Patrick) - move to main init
-        self.upsample = functools.partial(upsample_2d, k=self.fir_kernel)
-        self.downsample = functools.partial(downsample_2d, k=self.fir_kernel)
+        if self.up:
+            self.upsample = functools.partial(upsample_2d, k=self.fir_kernel)
+        if self.down:
+            self.downsample = functools.partial(downsample_2d, k=self.fir_kernel)
         self.is_overwritten = False
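
With this change the functools.partial resamplers are only constructed when the block actually up- or downsamples. A sketch of the guarded construction, using stand-ins for the module's upsample_2d/downsample_2d FIR helpers and its fir_kernel attribute (signatures and kernel value assumed, not the real implementations):

import functools

def upsample_2d(x, k=None):
    # Placeholder for the module's FIR upsampler; real signature assumed.
    raise NotImplementedError

def downsample_2d(x, k=None):
    # Placeholder for the module's FIR downsampler; real signature assumed.
    raise NotImplementedError

fir_kernel = (1, 3, 3, 1)  # assumed kernel value for illustration
up, down = True, False

# Bind the kernel only for the directions the block will actually use.
upsample = functools.partial(upsample_2d, k=fir_kernel) if up else None
downsample = functools.partial(downsample_2d, k=fir_kernel) if down else None
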
@@ -647,7 +648,7 @@ class ResnetBlock(nn.Module):
         if self.nin_shortcut is not None:
             x = self.nin_shortcut(x)
-        return x + h
+        return (x + h) / self.output_scale_factor

    # TODO(Patrick) - just there to convert the weights; can delete afterward
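
Dividing the skip-plus-residual sum by output_scale_factor (the caller below passes np.sqrt(2.0)) keeps the output variance roughly constant: if x and h are independent unit-variance signals, x + h has variance 2, and dividing by sqrt(2) restores unit variance. A quick numerical check:

import numpy as np
import torch

output_scale_factor = np.sqrt(2.0)
x = torch.randn(100_000)  # stands in for the shortcut branch
h = torch.randn(100_000)  # stands in for the residual branch

out = (x + h) / output_scale_factor
print((x + h).var().item())  # ~2.0 without rescaling
print(out.var().item())      # ~1.0 after rescaling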

@@ -348,18 +348,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
             for i_block in range(num_res_blocks):
                 out_ch = nf * ch_mult[i_level]
                 # modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
-                modules.append(
-                    ResnetNew(
-                        in_channels=in_ch,
-                        out_channels=out_ch,
-                        temb_channels=4 * nf,
-                        output_scale_factor=np.sqrt(2.0),
-                        non_linearity="silu",
-                        groups=min(in_ch // 4, 32),
-                        groups_out=min(out_ch // 4, 32),
-                        overwrite_for_score_vde=True,
-                    )
-                )
+                modules.append(ResnetNew(
+                    in_channels=in_ch,
+                    out_channels=out_ch,
+                    temb_channels=4 * nf,
+                    output_scale_factor=np.sqrt(2.0),
+                    non_linearity="silu",
+                    groups=min(in_ch // 4, 32),
+                    groups_out=min(out_ch // 4, 32),
+                    overwrite_for_score_vde=True,
+                ))
                 in_ch = out_ch
                 if all_resolutions[i_level] in attn_resolutions:
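
The groups/groups_out arguments scale the normalization group count with the channel width: one group per four channels, capped at 32 (presumably for the block's GroupNorm layers). A sketch of that arithmetic with illustrative configuration values:

# Illustrative values; nf and ch_mult are the model's configuration parameters.
nf = 128
ch_mult = (1, 2, 2, 2)

in_ch = nf * ch_mult[0]        # 128 channels
groups = min(in_ch // 4, 32)   # min(32, 32) -> 32 groups

narrow_ch = 32
narrow_groups = min(narrow_ch // 4, 32)  # 8 groups for a 32-channel block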