Merge pull request #44 from huggingface/unify_resnet

Unify resnet [GradTTS & Unet.py]
This commit is contained in:
Patrick von Platen 2022-06-29 16:37:13 +02:00 committed by GitHub
commit eb90d3be13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 203 additions and 93 deletions

View File

@ -46,8 +46,8 @@ def conv_transpose_nd(dims, *args, **kwargs):
raise ValueError(f"unsupported dimensions: {dims}") raise ValueError(f"unsupported dimensions: {dims}")
def Normalize(in_channels): def Normalize(in_channels, num_groups=32, eps=1e-6):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
def nonlinearity(x, swish=1.0): def nonlinearity(x, swish=1.0):
@ -166,8 +166,8 @@ class Downsample(nn.Module):
# #
# class GlideUpsample(nn.Module): # class GlideUpsample(nn.Module):
# """ # """
# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param # An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param #
# use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
# 3D, then # upsampling occurs in the inner-two dimensions. #""" # 3D, then # upsampling occurs in the inner-two dimensions. #"""
# #
# def __init__(self, channels, use_conv, dims=2, out_channels=None): # def __init__(self, channels, use_conv, dims=2, out_channels=None):
@ -192,8 +192,8 @@ class Downsample(nn.Module):
# #
# class LDMUpsample(nn.Module): # class LDMUpsample(nn.Module):
# """ # """
# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param # # An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param # #
# use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
# 3D, then # upsampling occurs in the inner-two dimensions. #""" # 3D, then # upsampling occurs in the inner-two dimensions. #"""
# #
# def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): # def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
@ -340,40 +340,118 @@ class ResBlock(TimestepBlock):
return self.skip_connection(x) + h return self.skip_connection(x) + h
# unet.py # unet.py and unet_grad_tts.py
class ResnetBlock(nn.Module): class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
overwrite_for_grad_tts=False,
):
super().__init__() super().__init__()
self.pre_norm = pre_norm
self.in_channels = in_channels self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels) if self.pre_norm:
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
else:
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels) self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
self.dropout = torch.nn.Dropout(dropout) self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nonlinearity
elif non_linearity == "mish":
self.nonlinearity = Mish()
if self.in_channels != self.out_channels: if self.in_channels != self.out_channels:
if self.use_conv_shortcut: if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else: else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb): self.is_overwritten = False
self.overwrite_for_grad_tts = overwrite_for_grad_tts
if self.overwrite_for_grad_tts:
dim = in_channels
dim_out = out_channels
time_emb_dim = temb_channels
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
self.pre_norm = pre_norm
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
if dim != dim_out:
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
else:
self.res_conv = torch.nn.Identity()
def set_weights_grad_tts(self):
self.conv1.weight.data = self.block1.block[0].weight.data
self.conv1.bias.data = self.block1.block[0].bias.data
self.norm1.weight.data = self.block1.block[1].weight.data
self.norm1.bias.data = self.block1.block[1].bias.data
self.conv2.weight.data = self.block2.block[0].weight.data
self.conv2.bias.data = self.block2.block[0].bias.data
self.norm2.weight.data = self.block2.block[1].weight.data
self.norm2.bias.data = self.block2.block[1].bias.data
self.temb_proj.weight.data = self.mlp[1].weight.data
self.temb_proj.bias.data = self.mlp[1].bias.data
if self.in_channels != self.out_channels:
self.nin_shortcut.weight.data = self.res_conv.weight.data
self.nin_shortcut.bias.data = self.res_conv.bias.data
def forward(self, x, temb, mask=None):
if self.overwrite_for_grad_tts and not self.is_overwritten:
self.set_weights_grad_tts()
self.is_overwritten = True
h = x h = x
h = self.norm1(h) h = h * mask if mask is not None else h
h = nonlinearity(h) if self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = self.conv1(h) h = self.conv1(h)
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] if not self.pre_norm:
h = self.norm1(h)
h = self.nonlinearity(h)
h = h * mask if mask is not None else h
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
h = h * mask if mask is not None else h
if self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h) h = self.dropout(h)
h = self.conv2(h) h = self.conv2(h)
if not self.pre_norm:
h = self.norm2(h)
h = self.nonlinearity(h)
h = h * mask if mask is not None else h
x = x * mask if mask is not None else x
if self.in_channels != self.out_channels: if self.in_channels != self.out_channels:
if self.use_conv_shortcut: if self.use_conv_shortcut:
x = self.conv_shortcut(x) x = self.conv_shortcut(x)
@ -383,58 +461,17 @@ class ResnetBlock(nn.Module):
return x + h return x + h
# unet_grad_tts.py # TODO(Patrick) - just there to convert the weights; can delete afterward
class ResnetBlockGradTTS(torch.nn.Module): class Block(torch.nn.Module):
def __init__(self, dim, dim_out, time_emb_dim, groups=8): def __init__(self, dim, dim_out, groups=8):
super(ResnetBlockGradTTS, self).__init__() super(Block, self).__init__()
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out)) self.block = torch.nn.Sequential(
torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
if dim != dim_out:
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
else:
self.res_conv = torch.nn.Identity()
def forward(self, x, mask, time_emb):
h = self.block1(x, mask)
h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
h = self.block2(h, mask)
output = h + self.res_conv(x * mask)
return output
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
]
) )
self.time_mlp = nn.Sequential( def forward(self, x, mask):
nn.Mish(), output = self.block(x * mask)
nn.Linear(embed_dim, out_channels), return output * mask
RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
)
self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t):
"""
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
out_channels x horizon ]
"""
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
# unet_score_estimation.py # unet_score_estimation.py
@ -570,6 +607,39 @@ class ResnetBlockDDPMpp(nn.Module):
return (x + h) / np.sqrt(2.0) return (x + h) / np.sqrt(2.0)
# unet_rl.py
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
super().__init__()
self.blocks = nn.ModuleList(
[
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size),
]
)
self.time_mlp = nn.Sequential(
nn.Mish(),
nn.Linear(embed_dim, out_channels),
RearrangeDim(),
# Rearrange("batch t -> batch t 1"),
)
self.residual_conv = (
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
)
def forward(self, x, t):
"""
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
out_channels x horizon ]
"""
out = self.blocks[0](x) + self.time_mlp(t)
out = self.blocks[1](out)
return out + self.residual_conv(x)
# HELPER Modules # HELPER Modules
@ -617,18 +687,6 @@ class Mish(torch.nn.Module):
return x * torch.tanh(torch.nn.functional.softplus(x)) return x * torch.tanh(torch.nn.functional.softplus(x))
class Block(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
super(Block, self).__init__()
self.block = torch.nn.Sequential(
torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
)
def forward(self, x, mask):
output = self.block(x * mask)
return output * mask
class Conv1dBlock(nn.Module): class Conv1dBlock(nn.Module):
""" """
Conv1d --> GroupNorm --> Mish Conv1d --> GroupNorm --> Mish

View File

@ -4,9 +4,7 @@ from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from .attention import LinearAttention from .attention import LinearAttention
from .embeddings import get_timestep_embedding from .embeddings import get_timestep_embedding
from .resnet import Downsample from .resnet import Downsample, ResnetBlock, Upsample
from .resnet import ResnetBlockGradTTS as ResnetBlock
from .resnet import Upsample
class Mish(torch.nn.Module): class Mish(torch.nn.Module):
@ -86,8 +84,26 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
self.downs.append( self.downs.append(
torch.nn.ModuleList( torch.nn.ModuleList(
[ [
ResnetBlock(dim_in, dim_out, time_emb_dim=dim), ResnetBlock(
ResnetBlock(dim_out, dim_out, time_emb_dim=dim), in_channels=dim_in,
out_channels=dim_out,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
),
ResnetBlock(
in_channels=dim_out,
out_channels=dim_out,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
),
Residual(Rezero(LinearAttention(dim_out))), Residual(Rezero(LinearAttention(dim_out))),
Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(), Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
] ]
@ -95,16 +111,52 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
) )
mid_dim = dims[-1] mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) self.mid_block1 = ResnetBlock(
in_channels=mid_dim,
out_channels=mid_dim,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
)
self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) self.mid_block2 = ResnetBlock(
in_channels=mid_dim,
out_channels=mid_dim,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
self.ups.append( self.ups.append(
torch.nn.ModuleList( torch.nn.ModuleList(
[ [
ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim), ResnetBlock(
ResnetBlock(dim_in, dim_in, time_emb_dim=dim), in_channels=dim_out * 2,
out_channels=dim_in,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
),
ResnetBlock(
in_channels=dim_in,
out_channels=dim_in,
temb_channels=dim,
groups=8,
pre_norm=False,
eps=1e-5,
non_linearity="mish",
overwrite_for_grad_tts=True,
),
Residual(Rezero(LinearAttention(dim_in))), Residual(Rezero(LinearAttention(dim_in))),
Upsample(dim_in, use_conv_transpose=True), Upsample(dim_in, use_conv_transpose=True),
] ]
@ -135,8 +187,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
masks = [mask] masks = [mask]
for resnet1, resnet2, attn, downsample in self.downs: for resnet1, resnet2, attn, downsample in self.downs:
mask_down = masks[-1] mask_down = masks[-1]
x = resnet1(x, mask_down, t) x = resnet1(x, t, mask_down)
x = resnet2(x, mask_down, t) x = resnet2(x, t, mask_down)
x = attn(x) x = attn(x)
hiddens.append(x) hiddens.append(x)
x = downsample(x * mask_down) x = downsample(x * mask_down)
@ -144,15 +196,15 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
masks = masks[:-1] masks = masks[:-1]
mask_mid = masks[-1] mask_mid = masks[-1]
x = self.mid_block1(x, mask_mid, t) x = self.mid_block1(x, t, mask_mid)
x = self.mid_attn(x) x = self.mid_attn(x)
x = self.mid_block2(x, mask_mid, t) x = self.mid_block2(x, t, mask_mid)
for resnet1, resnet2, attn, upsample in self.ups: for resnet1, resnet2, attn, upsample in self.ups:
mask_up = masks.pop() mask_up = masks.pop()
x = torch.cat((x, hiddens.pop()), dim=1) x = torch.cat((x, hiddens.pop()), dim=1)
x = resnet1(x, mask_up, t) x = resnet1(x, t, mask_up)
x = resnet2(x, mask_up, t) x = resnet2(x, t, mask_up)
x = attn(x) x = attn(x)
x = upsample(x * mask_up) x = upsample(x * mask_up)