make work with first resnet
This commit is contained in:
parent
9da575d63c
commit
8d7771d8b0
|
@ -436,10 +436,9 @@ class ResnetBlock(nn.Module):
|
|||
self.upsample = Upsample(in_channels, use_conv=False, dims=2) if self.up else None
|
||||
self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op") if self.down else None
|
||||
|
||||
self.nin_shortcut = use_nin_shortcut
|
||||
if self.use_nin_shortcut is None:
|
||||
self.use_nin_shortcut = self.in_channels != self.out_channels
|
||||
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
|
||||
|
||||
self.nin_shortcut = None
|
||||
if self.use_nin_shortcut:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
|
@ -526,8 +525,10 @@ class ResnetBlock(nn.Module):
|
|||
self.out_ch = out_ch
|
||||
|
||||
# TODO(Patrick) - move to main init
|
||||
self.upsample = functools.partial(upsample_2d, k=self.fir_kernel)
|
||||
self.downsample = functools.partial(downsample_2d, k=self.fir_kernel)
|
||||
if self.up:
|
||||
self.upsample = functools.partial(upsample_2d, k=self.fir_kernel)
|
||||
if self.down:
|
||||
self.downsample = functools.partial(downsample_2d, k=self.fir_kernel)
|
||||
|
||||
self.is_overwritten = False
|
||||
|
||||
|
@ -647,7 +648,7 @@ class ResnetBlock(nn.Module):
|
|||
if self.nin_shortcut is not None:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
return (x + h) / self.output_scale_factor
|
||||
|
||||
|
||||
# TODO(Patrick) - just there to convert the weights; can delete afterward
|
||||
|
|
|
@ -348,18 +348,16 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
for i_block in range(num_res_blocks):
|
||||
out_ch = nf * ch_mult[i_level]
|
||||
# modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
|
||||
modules.append(
|
||||
ResnetNew(
|
||||
in_channels=in_ch,
|
||||
out_channels=out_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
)
|
||||
)
|
||||
modules.append(ResnetNew(
|
||||
in_channels=in_ch,
|
||||
out_channels=out_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
))
|
||||
in_ch = out_ch
|
||||
|
||||
if all_resolutions[i_level] in attn_resolutions:
|
||||
|
|
Loading…
Reference in New Issue