Merge pull request #59 from huggingface/fuse_final_resnets

[Resnet] Merge final 2D resnet
Patrick von Platen, 2022-07-01 19:32:36 +02:00, committed by GitHub
commit 11667d08d3
2 changed files with 175 additions and 41 deletions

View File

@@ -174,9 +174,7 @@ class Downsample(nn.Module):
        # return self.conv(x)

# RESNETS
-# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py
+# unet.py, unet_grad_tts.py, unet_ldm.py, unet_glide.py, unet_score_vde.py
class ResnetBlock(nn.Module):
    def __init__(
        self,
@@ -187,15 +185,20 @@ class ResnetBlock(nn.Module):
        dropout=0.0,
        temb_channels=512,
        groups=32,
+       groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
+       kernel=None,
+       output_scale_factor=1.0,
+       use_nin_shortcut=None,
        up=False,
        down=False,
        overwrite_for_grad_tts=False,
        overwrite_for_ldm=False,
        overwrite_for_glide=False,
+       overwrite_for_score_vde=False,
    ):
        super().__init__()
        self.pre_norm = pre_norm
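
Taken together, the widened signature lets one class cover all of the fused UNet variants. A minimal sketch of the two extremes (the import path is hypothetical; channel counts are illustrative):

    import numpy as np
    from diffusers.models.resnet import ResnetBlock  # assumed import path

    # Plain DDPM-style block: all of the new kwargs keep their defaults.
    ddpm_block = ResnetBlock(in_channels=64, out_channels=64, temb_channels=512)

    # Score-VDE-style block, mirroring the NCSNpp call sites in the second file below.
    sde_block = ResnetBlock(
        in_channels=64,
        out_channels=128,
        temb_channels=256,
        output_scale_factor=np.sqrt(2.0),
        non_linearity="silu",
        groups=min(64 // 4, 32),
        groups_out=min(128 // 4, 32),
        overwrite_for_score_vde=True,
    )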
@@ -206,6 +209,10 @@ class ResnetBlock(nn.Module):
        self.time_embedding_norm = time_embedding_norm
        self.up = up
        self.down = down
+       self.output_scale_factor = output_scale_factor

+       if groups_out is None:
+           groups_out = groups

        if self.pre_norm:
            self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
@@ -219,7 +226,7 @@ class ResnetBlock(nn.Module):
        elif time_embedding_norm == "scale_shift":
            self.temb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)

-       self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
+       self.norm2 = Normalize(out_channels, num_groups=groups_out, eps=eps)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
@@ -230,14 +237,29 @@ class ResnetBlock(nn.Module):
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

-       if up:
-           self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
-           self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
-       elif down:
-           self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
-           self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+       # if up:
+       #     self.h_upd = Upsample(in_channels, use_conv=False, dims=2)
+       #     self.x_upd = Upsample(in_channels, use_conv=False, dims=2)
+       # elif down:
+       #     self.h_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
+       #     self.x_upd = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")

-       if self.in_channels != self.out_channels:
+       self.upsample = self.downsample = None
+       if self.up and kernel == "fir":
+           fir_kernel = (1, 3, 3, 1)
+           self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
+       elif self.up and kernel is None:
+           self.upsample = Upsample(in_channels, use_conv=False, dims=2)
+       elif self.down and kernel == "fir":
+           fir_kernel = (1, 3, 3, 1)
+           self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
+       elif self.down and kernel is None:
+           self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")

+       self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
+       self.nin_shortcut = None
+       if self.use_nin_shortcut:
            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

        # TODO(SURAJ, PATRICK): ALL OF THE FOLLOWING OF THE INIT METHOD CAN BE DELETED ONCE WEIGHTS ARE CONVERTED
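
For intuition, the four new branches map each (up/down, kernel) combination onto a concrete resampling op. A sketch of what each configuration ends up with, inferred from the branches in this hunk:

    # down=True, kernel="fir"  -> self.downsample = lambda x: downsample_2d(x, k=(1, 3, 3, 1))
    # down=True, kernel=None   -> self.downsample = Downsample(in_channels, use_conv=False, dims=2, padding=1, name="op")
    # up=True,   kernel="fir"  -> self.upsample   = lambda x: upsample_2d(x, k=(1, 3, 3, 1))
    # up=True,   kernel=None   -> self.upsample   = Upsample(in_channels, use_conv=False, dims=2)
    # neither                  -> both stay None and forward() skips resampling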
@@ -245,6 +267,7 @@ class ResnetBlock(nn.Module):
        self.overwrite_for_glide = overwrite_for_glide
        self.overwrite_for_grad_tts = overwrite_for_grad_tts
        self.overwrite_for_ldm = overwrite_for_ldm or overwrite_for_glide
+       self.overwrite_for_score_vde = overwrite_for_score_vde

        if self.overwrite_for_grad_tts:
            dim = in_channels
            dim_out = out_channels
@@ -260,12 +283,10 @@ class ResnetBlock(nn.Module):
                self.res_conv = torch.nn.Identity()
        elif self.overwrite_for_ldm:
            dims = 2
+           # eps = 1e-5
+           # non_linearity = "silu"
+           # overwrite_for_ldm
            channels = in_channels
            emb_channels = temb_channels
            use_scale_shift_norm = False
-           non_linearity = "silu"

            self.in_layers = nn.Sequential(
                normalization(channels, swish=1.0),
@@ -289,6 +310,40 @@ class ResnetBlock(nn.Module):
                self.skip_connection = nn.Identity()
            else:
                self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+       elif self.overwrite_for_score_vde:
+           in_ch = in_channels
+           out_ch = out_channels
+           eps = 1e-6
+           num_groups = min(in_ch // 4, 32)
+           num_groups_out = min(out_ch // 4, 32)
+           temb_dim = temb_channels
+           # output_scale_factor = np.sqrt(2.0)
+           # non_linearity = "silu"
+           # use_nin_shortcut = in_channels != out_channels or use_nin_shortcut = True
+
+           self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=in_ch, eps=eps)
+           self.up = up
+           self.down = down
+           self.Conv_0 = conv2d(in_ch, out_ch, kernel_size=3, padding=1)
+           if temb_dim is not None:
+               self.Dense_0 = nn.Linear(temb_dim, out_ch)
+               self.Dense_0.weight.data = variance_scaling()(self.Dense_0.weight.shape)
+               nn.init.zeros_(self.Dense_0.bias)
+
+           self.GroupNorm_1 = nn.GroupNorm(num_groups=num_groups_out, num_channels=out_ch, eps=eps)
+           self.Dropout_0 = nn.Dropout(dropout)
+           self.Conv_1 = conv2d(out_ch, out_ch, init_scale=0.0, kernel_size=3, padding=1)
+           if in_ch != out_ch or up or down:
+               # 1x1 convolution with DDPM initialization.
+               self.Conv_2 = conv2d(in_ch, out_ch, kernel_size=1, padding=0)
+
+           # self.skip_rescale = skip_rescale
+           self.in_ch = in_ch
+           self.out_ch = out_ch

        # TODO(Patrick) - move to main init
        self.is_overwritten = False

    def set_weights_grad_tts(self):
        self.conv1.weight.data = self.block1.block[0].weight.data
@@ -328,6 +383,24 @@ class ResnetBlock(nn.Module):
            self.nin_shortcut.weight.data = self.skip_connection.weight.data
            self.nin_shortcut.bias.data = self.skip_connection.bias.data

+   def set_weights_score_vde(self):
+       self.conv1.weight.data = self.Conv_0.weight.data
+       self.conv1.bias.data = self.Conv_0.bias.data
+       self.norm1.weight.data = self.GroupNorm_0.weight.data
+       self.norm1.bias.data = self.GroupNorm_0.bias.data
+
+       self.conv2.weight.data = self.Conv_1.weight.data
+       self.conv2.bias.data = self.Conv_1.bias.data
+       self.norm2.weight.data = self.GroupNorm_1.weight.data
+       self.norm2.bias.data = self.GroupNorm_1.bias.data
+
+       self.temb_proj.weight.data = self.Dense_0.weight.data
+       self.temb_proj.bias.data = self.Dense_0.bias.data
+
+       if self.in_channels != self.out_channels or self.up or self.down:
+           self.nin_shortcut.weight.data = self.Conv_2.weight.data
+           self.nin_shortcut.bias.data = self.Conv_2.bias.data

    def forward(self, x, temb, mask=1.0):
        # TODO(Patrick) eventually this class should be split into multiple classes
        # too many if else statements
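
All of the set_weights_* helpers follow the same tensor-copy pattern. A self-contained sketch of why assigning through .data works for this kind of one-time migration (plain PyTorch, not diffusers-specific; the Linear layers stand in for arbitrary submodules):

    import torch.nn as nn

    src = nn.Linear(4, 4)  # stands in for a legacy submodule, e.g. Conv_0
    dst = nn.Linear(4, 4)  # stands in for the merged submodule, e.g. conv1

    # Overwriting .data swaps the values in place without re-registering the
    # parameter, so dst keeps its identity inside the module tree.
    dst.weight.data = src.weight.data
    dst.bias.data = src.bias.data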
@@ -337,6 +410,9 @@ class ResnetBlock(nn.Module):
        elif self.overwrite_for_ldm and not self.is_overwritten:
            self.set_weights_ldm()
            self.is_overwritten = True
+       elif self.overwrite_for_score_vde and not self.is_overwritten:
+           self.set_weights_score_vde()
+           self.is_overwritten = True

        h = x
        h = h * mask
@@ -344,9 +420,12 @@ class ResnetBlock(nn.Module):
            h = self.norm1(h)
        h = self.nonlinearity(h)

-       if self.up or self.down:
-           x = self.x_upd(x)
-           h = self.h_upd(h)
+       if self.upsample is not None:
+           x = self.upsample(x)
+           h = self.upsample(h)
+       elif self.downsample is not None:
+           x = self.downsample(x)
+           h = self.downsample(h)

        h = self.conv1(h)
@@ -379,10 +458,10 @@ class ResnetBlock(nn.Module):
        h = h * mask

        x = x * mask
-       if self.in_channels != self.out_channels:
+       if self.nin_shortcut is not None:
            x = self.nin_shortcut(x)

-       return x + h
+       return (x + h) / self.output_scale_factor


# TODO(Patrick) - just there to convert the weights; can delete afterward
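
The new (x + h) / self.output_scale_factor return reproduces the score-SDE skip rescaling: with output_scale_factor=np.sqrt(2.0), summing two roughly unit-variance branches keeps the output at unit variance. A quick numerical sanity check (illustrative only):

    import torch

    x, h = torch.randn(100_000), torch.randn(100_000)
    print((x + h).var())                 # ~2.0: an unscaled skip doubles the variance
    print(((x + h) / 2.0 ** 0.5).var())  # ~1.0: rescaling restores it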

View File

@@ -27,7 +27,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import AttentionBlock
from .embeddings import GaussianFourierProjection, get_timestep_embedding
-from .resnet import Downsample, ResnetBlockBigGANpp, Upsample, downsample_2d, upfirdn2d, upsample_2d
+from .resnet import Downsample, ResnetBlock, Upsample, downsample_2d, upfirdn2d, upsample_2d
def _setup_kernel(k):
@@ -276,8 +276,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
            skip_rescale=skip_rescale,
            continuous=continuous,
        )
-       self.act = act = nn.SiLU()
+       self.act = nn.SiLU()
        self.nf = nf
        self.num_res_blocks = num_res_blocks
        self.attn_resolutions = attn_resolutions
@@ -333,19 +332,6 @@ class NCSNpp(ModelMixin, ConfigMixin):
        elif progressive_input == "residual":
            pyramid_downsample = functools.partial(Downsample, use_conv=True)

-       ResnetBlock = functools.partial(
-           ResnetBlockBigGANpp,
-           act=act,
-           dropout=dropout,
-           fir=fir,
-           fir_kernel=fir_kernel,
-           init_scale=init_scale,
-           skip_rescale=skip_rescale,
-           temb_dim=nf * 4,
-       )

        # Downsampling block
        channels = num_channels
        if progressive_input != "none":
            input_pyramid_ch = channels
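
The deleted factory and the new explicit call sites carry the same constant kwargs; what the partial could never hold are groups/groups_out, which depend on each block's channel counts. A rough sketch of the correspondence against the merged block (nf as in this constructor; names from the hunks below):

    import functools
    import numpy as np

    # Roughly what a partial over the merged ResnetBlock would look like...
    make_block = functools.partial(
        ResnetBlock,
        temb_channels=4 * nf,
        output_scale_factor=np.sqrt(2.0),
        non_linearity="silu",
        overwrite_for_score_vde=True,
    )
    # ...but groups=min(in_ch // 4, 32) varies per call, so the diff below
    # inlines the full kwarg list at every call site instead.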
@@ -358,7 +344,18 @@ class NCSNpp(ModelMixin, ConfigMixin):
            # Residual blocks for this resolution
            for i_block in range(num_res_blocks):
                out_ch = nf * ch_mult[i_level]
-               modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
+               modules.append(
+                   ResnetBlock(
+                       in_channels=in_ch,
+                       out_channels=out_ch,
+                       temb_channels=4 * nf,
+                       output_scale_factor=np.sqrt(2.0),
+                       non_linearity="silu",
+                       groups=min(in_ch // 4, 32),
+                       groups_out=min(out_ch // 4, 32),
+                       overwrite_for_score_vde=True,
+                   )
+               )
                in_ch = out_ch

                if all_resolutions[i_level] in attn_resolutions:
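
min(ch // 4, 32) reproduces the score-SDE GroupNorm convention: one group per four channels, capped at 32. For a few illustrative channel counts:

    for ch in (32, 128, 256, 512):
        print(ch, "channels ->", min(ch // 4, 32), "groups")
    # 32 -> 8, 128 -> 32, 256 -> 32, 512 -> 32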
@@ -366,7 +363,20 @@ class NCSNpp(ModelMixin, ConfigMixin):
                hs_c.append(in_ch)

            if i_level != self.num_resolutions - 1:
-               modules.append(ResnetBlock(down=True, in_ch=in_ch))
+               modules.append(
+                   ResnetBlock(
+                       in_channels=in_ch,
+                       temb_channels=4 * nf,
+                       output_scale_factor=np.sqrt(2.0),
+                       non_linearity="silu",
+                       groups=min(in_ch // 4, 32),
+                       groups_out=min(out_ch // 4, 32),
+                       overwrite_for_score_vde=True,
+                       down=True,
+                       kernel="fir",  # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
+                       use_nin_shortcut=True,
+                   )
+               )

                if progressive_input == "input_skip":
                    modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
@@ -380,16 +390,48 @@ class NCSNpp(ModelMixin, ConfigMixin):
                hs_c.append(in_ch)

        in_ch = hs_c[-1]
-       modules.append(ResnetBlock(in_ch=in_ch))
+       modules.append(
+           ResnetBlock(
+               in_channels=in_ch,
+               temb_channels=4 * nf,
+               output_scale_factor=np.sqrt(2.0),
+               non_linearity="silu",
+               groups=min(in_ch // 4, 32),
+               groups_out=min(out_ch // 4, 32),
+               overwrite_for_score_vde=True,
+           )
+       )
        modules.append(AttnBlock(channels=in_ch))
-       modules.append(ResnetBlock(in_ch=in_ch))
+       modules.append(
+           ResnetBlock(
+               in_channels=in_ch,
+               temb_channels=4 * nf,
+               output_scale_factor=np.sqrt(2.0),
+               non_linearity="silu",
+               groups=min(in_ch // 4, 32),
+               groups_out=min(out_ch // 4, 32),
+               overwrite_for_score_vde=True,
+           )
+       )

        pyramid_ch = 0
        # Upsampling block
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(num_res_blocks + 1):
                out_ch = nf * ch_mult[i_level]
-               modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
+               in_ch = in_ch + hs_c.pop()
+               modules.append(
+                   ResnetBlock(
+                       in_channels=in_ch,
+                       out_channels=out_ch,
+                       temb_channels=4 * nf,
+                       output_scale_factor=np.sqrt(2.0),
+                       non_linearity="silu",
+                       groups=min(in_ch // 4, 32),
+                       groups_out=min(out_ch // 4, 32),
+                       overwrite_for_score_vde=True,
+                   )
+               )
                in_ch = out_ch

                if all_resolutions[i_level] in attn_resolutions:
@@ -421,7 +463,20 @@ class NCSNpp(ModelMixin, ConfigMixin):
                    raise ValueError(f"{progressive} is not a valid name")

            if i_level != 0:
-               modules.append(ResnetBlock(in_ch=in_ch, up=True))
+               modules.append(
+                   ResnetBlock(
+                       in_channels=in_ch,
+                       temb_channels=4 * nf,
+                       output_scale_factor=np.sqrt(2.0),
+                       non_linearity="silu",
+                       groups=min(in_ch // 4, 32),
+                       groups_out=min(out_ch // 4, 32),
+                       overwrite_for_score_vde=True,
+                       up=True,
+                       kernel="fir",  # TODO(Patrick) - it seems like both fir and non-fir kernels are fine
+                       use_nin_shortcut=True,
+                   )
+               )

        assert not hs_c