Merge pull request #59 from huggingface/fuse_final_resnets
[Resnet] Merge final 2D resnet
This commit is contained in:
commit
11667d08d3
|
@ -174,9 +174,7 @@ class Downsample(nn.Module):
|
|||
# return self.conv(x)
|
||||
|
||||
|
||||
# RESNETS
|
||||
|
||||
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
|
||||
# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -187,15 +185,20 @@ class ResnetBlock(nn.Module):
|
|||
dropout=0.0,
|
||||
temb_channels=512,
|
||||
groups=32,
|
||||
groups_out=None,
|
||||
pre_norm=True,
|
||||
eps=1e-6,
|
||||
non_linearity="swish",
|
||||
time_embedding_norm="default",
|
||||
kernel=None,
|
||||
output_scale_factor=1.0,
|
||||
use_nin_shortcut=None,
|
||||
up=False,
|
||||
down=False,
|
||||
overwrite_for_grad_tts=False,
|
||||
overwrite_for_ldm=False,
|
||||
overwrite_for_glide=False,
|
||||
overwrite_for_score_vde=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.pre_norm = pre_norm
|
||||
|
@ -206,6 +209,10 @@ class ResnetBlock(nn.Module):
|
|||
self.time_embedding_norm = time_embedding_norm
|
||||
self.up = up
|
||||
self.down = down
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
if self.pre_norm:
|
||||
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
|
||||
|
@ -219,7 +226,7 @@ class ResnetBlock(nn.Module):
|
|||
elif time_embedding_norm == "scale_shift":
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
|
||||
|
||||
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
|
||||
self.norm2 = Normalize(out_channels, num_groups=groups_out, eps=eps)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
|
@ -230,14 +237,29 @@ class ResnetBlock(nn.Module):
|
|||
elif non_linearity == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
|
||||
self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
|
||||
elif down:
|
||||
self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
||||
self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
||||
# if up:
|
||||
# self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
|
||||
# self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
|
||||
# elif down:
|
||||
# self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
||||
# self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
self.upsample = self.downsample = None
|
||||
if self.up and kernel == "fir":
|
||||
fir_kernel = (1, 3, 3, 1)
|
||||
self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
|
||||
elif self.up and kernel is None:
|
||||
self.upsample = Upsample(in_channels, use_conv=False, dims=2)
|
||||
elif self.down and kernel == "fir":
|
||||
fir_kernel = (1, 3, 3, 1)
|
||||
self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
|
||||
elif self.down and kernel is None:
|
||||
self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
|
||||
|
||||
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
|
||||
|
||||
self.nin_shortcut = None
|
||||
if self.use_nin_shortcut:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
# TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
|
||||
|
@ -245,6 +267,7 @@ class ResnetBlock(nn.Module):
|
|||
self.overwrite_for_glide = overwrite_for_glide
|
||||
self.overwrite_for_grad_tts = overwrite_for_grad_tts
|
||||
self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
|
||||
self.overwrite_for_score_vde = overwrite_for_score_vde
|
||||
if self.overwrite_for_grad_tts:
|
||||
dim = in_channels
|
||||
dim_out = out_channels
|
||||
|
@ -260,12 +283,10 @@ class ResnetBlock(nn.Module):
|
|||
self.res_conv = torch.nn.Identity()
|
||||
elif self.overwrite_for_ldm:
|
||||
dims = 2
|
||||
# eps = 1e-5
|
||||
# non_linearity = "silu"
|
||||
# overwrite_for_ldm
|
||||
channels = in_channels
|
||||
emb_channels = temb_channels
|
||||
use_scale_shift_norm = False
|
||||
non_linearity = "silu"
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels, swish=1.0),
|
||||
|
@ -289,6 +310,40 @@ class ResnetBlock(nn.Module):
|
|||
self.skip_connection = nn.Identity()
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
||||
elif self.overwrite_for_score_vde:
|
||||
in_ch = in_channels
|
||||
out_ch = out_channels
|
||||
|
||||
eps = 1e-6
|
||||
num_groups = min(in_ch // 4, 32)
|
||||
num_groups_out = min(out_ch // 4, 32)
|
||||
temb_dim = temb_channels
|
||||
# output_scale_factor = np.sqrt(2.0)
|
||||
# non_linearity = "silu"
|
||||
# use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True
|
||||
|
||||
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
|
||||
self.up = up
|
||||
self.down = down
|
||||
self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
|
||||
if temb_dim is not None:
|
||||
self.Dense_0 = nn.Linear(temb_dim, out_ch)
|
||||
self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
|
||||
nn.init.zeros_(self.Dense_0.bias)
|
||||
|
||||
self.GroupNorm_1 = nn.GroupNorm(num_groups=num_groups_out, num_channels=out_ch, eps=eps)
|
||||
self.Dropout_0 = nn.Dropout(dropout)
|
||||
self.Conv_1 = conv2d(out_ch, out_ch, init_scale=0.0, kernel_size=3, padding=1)
|
||||
if in_ch != out_ch or up or down:
|
||||
# 1x1 convolution with DDPM initialization.
|
||||
self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
|
||||
|
||||
# self.skip_rescale = skip_rescale
|
||||
self.in_ch = in_ch
|
||||
self.out_ch = out_ch
|
||||
|
||||
# TODO(Patrick) - move to main init
|
||||
self.is_overwritten = False
|
||||
|
||||
def set_weights_grad_tts(self):
|
||||
self.conv1.weight.data = self.block1.block[0].weight.data
|
||||
|
@ -328,6 +383,24 @@ class ResnetBlock(nn.Module):
|
|||
self.nin_shortcut.weight.data = self.skip_connection.weight.data
|
||||
self.nin_shortcut.bias.data = self.skip_connection.bias.data
|
||||
|
||||
def set_weights_score_vde(self):
|
||||
self.conv1.weight.data = self.Conv_0.weight.data
|
||||
self.conv1.bias.data = self.Conv_0.bias.data
|
||||
self.norm1.weight.data = self.GroupNorm_0.weight.data
|
||||
self.norm1.bias.data = self.GroupNorm_0.bias.data
|
||||
|
||||
self.conv2.weight.data = self.Conv_1.weight.data
|
||||
self.conv2.bias.data = self.Conv_1.bias.data
|
||||
self.norm2.weight.data = self.GroupNorm_1.weight.data
|
||||
self.norm2.bias.data = self.GroupNorm_1.bias.data
|
||||
|
||||
self.temb_proj.weight.data = self.Dense_0.weight.data
|
||||
self.temb_proj.bias.data = self.Dense_0.bias.data
|
||||
|
||||
if self.in_channels != self.out_channels or self.up or self.down:
|
||||
self.nin_shortcut.weight.data = self.Conv_2.weight.data
|
||||
self.nin_shortcut.bias.data = self.Conv_2.bias.data
|
||||
|
||||
def forward(self, x, temb, mask=1.0):
|
||||
# TODO(Patrick) eventually this class should be split into multiple classes
|
||||
# too many if else statements
|
||||
|
@ -337,6 +410,9 @@ class ResnetBlock(nn.Module):
|
|||
elif self.overwrite_for_ldm and not self.is_overwritten:
|
||||
self.set_weights_ldm()
|
||||
self.is_overwritten = True
|
||||
elif self.overwrite_for_score_vde and not self.is_overwritten:
|
||||
self.set_weights_score_vde()
|
||||
self.is_overwritten = True
|
||||
|
||||
h = x
|
||||
h = h * mask
|
||||
|
@ -344,9 +420,12 @@ class ResnetBlock(nn.Module):
|
|||
h = self.norm1(h)
|
||||
h = self.nonlinearity(h)
|
||||
|
||||
if self.up or self.down:
|
||||
x = self.x_upd(x)
|
||||
h = self.h_upd(h)
|
||||
if self.upsample is not None:
|
||||
x = self.upsample(x)
|
||||
h = self.upsample(h)
|
||||
elif self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
h = self.downsample(h)
|
||||
|
||||
h = self.conv1(h)
|
||||
|
||||
|
@ -379,10 +458,10 @@ class ResnetBlock(nn.Module):
|
|||
h = h * mask
|
||||
|
||||
x = x * mask
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.nin_shortcut is not None:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
return (x + h) / self.output_scale_factor
|
||||
|
||||
|
||||
# TODO(Patrick) - just there to convert the weights; can delete afterward
|
||||
|
|
|
@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
|
|||
from ..modeling_utils import ModelMixin
|
||||
from .attention import AttentionBlock
|
||||
from .embeddings import GaussianFourierProjection, get_timestep_embedding
|
||||
from .resnet import Downsample, ResnetBlockBigGANpp, Upsample, downsample_2d, upfirdn2d, upsample_2d
|
||||
from .resnet import Downsample, ResnetBlock, Upsample, downsample_2d, upfirdn2d, upsample_2d
|
||||
|
||||
|
||||
def _setup_kernel(k):
|
||||
|
@ -276,8 +276,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
skip_rescale=skip_rescale,
|
||||
continuous=continuous,
|
||||
)
|
||||
self.act = act = nn.SiLU()
|
||||
|
||||
self.act = nn.SiLU()
|
||||
self.nf = nf
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_resolutions = attn_resolutions
|
||||
|
@ -333,19 +332,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
elif progressive_input == "residual":
|
||||
pyramid_downsample = functools.partial(Down_sample, use_conv=True)
|
||||
|
||||
ResnetBlock = functools.partial(
|
||||
ResnetBlockBigGANpp,
|
||||
act=act,
|
||||
dropout=dropout,
|
||||
fir=fir,
|
||||
fir_kernel=fir_kernel,
|
||||
init_scale=init_scale,
|
||||
skip_rescale=skip_rescale,
|
||||
temb_dim=nf * 4,
|
||||
)
|
||||
|
||||
# Downsampling block
|
||||
|
||||
channels = num_channels
|
||||
if progressive_input != "none":
|
||||
input_pyramid_ch = channels
|
||||
|
@ -358,7 +344,18 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
# Residual blocks for this resolution
|
||||
for i_block in range(num_res_blocks):
|
||||
out_ch = nf * ch_mult[i_level]
|
||||
modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
|
||||
modules.append(
|
||||
ResnetBlock(
|
||||
in_channels=in_ch,
|
||||
out_channels=out_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
)
|
||||
)
|
||||
in_ch = out_ch
|
||||
|
||||
if all_resolutions[i_level] in attn_resolutions:
|
||||
|
@ -366,7 +363,20 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
hs_c.append(in_ch)
|
||||
|
||||
if i_level != self.num_resolutions - 1:
|
||||
modules.append(ResnetBlock(down=True, in_ch=in_ch))
|
||||
modules.append(
|
||||
ResnetBlock(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
down=True,
|
||||
kernel="fir", # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
|
||||
use_nin_shortcut=True,
|
||||
)
|
||||
)
|
||||
|
||||
if progressive_input == "input_skip":
|
||||
modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
|
||||
|
@ -380,16 +390,48 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
hs_c.append(in_ch)
|
||||
|
||||
in_ch = hs_c[-1]
|
||||
modules.append(ResnetBlock(in_ch=in_ch))
|
||||
modules.append(
|
||||
ResnetBlock(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
)
|
||||
)
|
||||
modules.append(AttnBlock(channels=in_ch))
|
||||
modules.append(ResnetBlock(in_ch=in_ch))
|
||||
modules.append(
|
||||
ResnetBlock(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
)
|
||||
)
|
||||
|
||||
pyramid_ch = 0
|
||||
# Upsampling block
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(num_res_blocks + 1):
|
||||
out_ch = nf * ch_mult[i_level]
|
||||
modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
|
||||
in_ch = in_ch + hs_c.pop()
|
||||
modules.append(
|
||||
ResnetBlock(
|
||||
in_channels=in_ch,
|
||||
out_channels=out_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
)
|
||||
)
|
||||
in_ch = out_ch
|
||||
|
||||
if all_resolutions[i_level] in attn_resolutions:
|
||||
|
@ -421,7 +463,20 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
|||
raise ValueError(f"{progressive} is not a valid name")
|
||||
|
||||
if i_level != 0:
|
||||
modules.append(ResnetBlock(in_ch=in_ch, up=True))
|
||||
modules.append(
|
||||
ResnetBlock(
|
||||
in_channels=in_ch,
|
||||
temb_channels=4 * nf,
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
non_linearity="silu",
|
||||
groups=min(in_ch // 4, 32),
|
||||
groups_out=min(out_ch // 4, 32),
|
||||
overwrite_for_score_vde=True,
|
||||
up=True,
|
||||
kernel="fir", # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
|
||||
use_nin_shortcut=True,
|
||||
)
|
||||
)
|
||||
|
||||
assert not hs_c
|
||||
|
||||
|
|
Loading…
Reference in New Issue