Merge pull request #44 from huggingface/unify_resnet
Unify resnet [GradTTS & Unet.py]
This commit is contained in:
commit
eb90d3be13
|
@ -46,8 +46,8 @@ def conv_transpose_nd(dims, *args, **kwargs):
|
|||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
def Normalize(in_channels, num_groups=32, eps=1e-6):
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
|
||||
|
||||
def nonlinearity(x, swish=1.0):
|
||||
|
@ -166,8 +166,8 @@ class Downsample(nn.Module):
|
|||
#
|
||||
# class GlideUpsample(nn.Module):
|
||||
# """
|
||||
# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param
|
||||
# use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
|
||||
# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param #
|
||||
use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If
|
||||
# 3D, then # upsampling occurs in the inner-two dimensions. #"""
|
||||
#
|
||||
# def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
||||
|
@ -192,8 +192,8 @@ class Downsample(nn.Module):
|
|||
#
|
||||
# class LDMUpsample(nn.Module):
|
||||
# """
|
||||
# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param #
|
||||
# use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
|
||||
# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param # #
|
||||
use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If
|
||||
# 3D, then # upsampling occurs in the inner-two dimensions. #"""
|
||||
#
|
||||
# def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
||||
|
@ -340,40 +340,118 @@ class ResBlock(TimestepBlock):
|
|||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
# unet.py
|
||||
# unet.py and unet_grad_tts.py
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout=0.0,
|
||||
temb_channels=512,
|
||||
groups=32,
|
||||
pre_norm=True,
|
||||
eps=1e-6,
|
||||
non_linearity="swish",
|
||||
overwrite_for_grad_tts=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.pre_norm = pre_norm
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
if self.pre_norm:
|
||||
self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
|
||||
else:
|
||||
self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
|
||||
|
||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if non_linearity == "swish":
|
||||
self.nonlinearity = nonlinearity
|
||||
elif non_linearity == "mish":
|
||||
self.nonlinearity = Mish()
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
self.is_overwritten = False
|
||||
self.overwrite_for_grad_tts = overwrite_for_grad_tts
|
||||
if self.overwrite_for_grad_tts:
|
||||
dim = in_channels
|
||||
dim_out = out_channels
|
||||
time_emb_dim = temb_channels
|
||||
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
|
||||
self.pre_norm = pre_norm
|
||||
|
||||
self.block1 = Block(dim, dim_out, groups=groups)
|
||||
self.block2 = Block(dim_out, dim_out, groups=groups)
|
||||
if dim != dim_out:
|
||||
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
|
||||
else:
|
||||
self.res_conv = torch.nn.Identity()
|
||||
|
||||
def set_weights_grad_tts(self):
|
||||
self.conv1.weight.data = self.block1.block[0].weight.data
|
||||
self.conv1.bias.data = self.block1.block[0].bias.data
|
||||
self.norm1.weight.data = self.block1.block[1].weight.data
|
||||
self.norm1.bias.data = self.block1.block[1].bias.data
|
||||
|
||||
self.conv2.weight.data = self.block2.block[0].weight.data
|
||||
self.conv2.bias.data = self.block2.block[0].bias.data
|
||||
self.norm2.weight.data = self.block2.block[1].weight.data
|
||||
self.norm2.bias.data = self.block2.block[1].bias.data
|
||||
|
||||
self.temb_proj.weight.data = self.mlp[1].weight.data
|
||||
self.temb_proj.bias.data = self.mlp[1].bias.data
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut.weight.data = self.res_conv.weight.data
|
||||
self.nin_shortcut.bias.data = self.res_conv.bias.data
|
||||
|
||||
def forward(self, x, temb, mask=None):
|
||||
if self.overwrite_for_grad_tts and not self.is_overwritten:
|
||||
self.set_weights_grad_tts()
|
||||
self.is_overwritten = True
|
||||
|
||||
h = x
|
||||
h = h * mask if mask is not None else h
|
||||
if self.pre_norm:
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.nonlinearity(h)
|
||||
|
||||
h = self.conv1(h)
|
||||
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
||||
if not self.pre_norm:
|
||||
h = self.norm1(h)
|
||||
h = self.nonlinearity(h)
|
||||
h = h * mask if mask is not None else h
|
||||
|
||||
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
h = h * mask if mask is not None else h
|
||||
if self.pre_norm:
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.nonlinearity(h)
|
||||
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if not self.pre_norm:
|
||||
h = self.norm2(h)
|
||||
h = self.nonlinearity(h)
|
||||
h = h * mask if mask is not None else h
|
||||
|
||||
x = x * mask if mask is not None else x
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
|
@ -383,58 +461,17 @@ class ResnetBlock(nn.Module):
|
|||
return x + h
|
||||
|
||||
|
||||
# unet_grad_tts.py
|
||||
class ResnetBlockGradTTS(torch.nn.Module):
|
||||
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
|
||||
super(ResnetBlockGradTTS, self).__init__()
|
||||
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
|
||||
|
||||
self.block1 = Block(dim, dim_out, groups=groups)
|
||||
self.block2 = Block(dim_out, dim_out, groups=groups)
|
||||
if dim != dim_out:
|
||||
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
|
||||
else:
|
||||
self.res_conv = torch.nn.Identity()
|
||||
|
||||
def forward(self, x, mask, time_emb):
|
||||
h = self.block1(x, mask)
|
||||
h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
|
||||
h = self.block2(h, mask)
|
||||
output = h + self.res_conv(x * mask)
|
||||
return output
|
||||
|
||||
|
||||
# unet_rl.py
|
||||
class ResidualTemporalBlock(nn.Module):
|
||||
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
|
||||
super().__init__()
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Conv1dBlock(inp_channels, out_channels, kernel_size),
|
||||
Conv1dBlock(out_channels, out_channels, kernel_size),
|
||||
]
|
||||
# TODO(Patrick) - just there to convert the weights; can delete afterward
|
||||
class Block(torch.nn.Module):
|
||||
def __init__(self, dim, dim_out, groups=8):
|
||||
super(Block, self).__init__()
|
||||
self.block = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
|
||||
)
|
||||
|
||||
self.time_mlp = nn.Sequential(
|
||||
nn.Mish(),
|
||||
nn.Linear(embed_dim, out_channels),
|
||||
RearrangeDim(),
|
||||
# Rearrange("batch t -> batch t 1"),
|
||||
)
|
||||
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, t):
|
||||
"""
|
||||
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
|
||||
out_channels x horizon ]
|
||||
"""
|
||||
out = self.blocks[0](x) + self.time_mlp(t)
|
||||
out = self.blocks[1](out)
|
||||
return out + self.residual_conv(x)
|
||||
def forward(self, x, mask):
|
||||
output = self.block(x * mask)
|
||||
return output * mask
|
||||
|
||||
|
||||
# unet_score_estimation.py
|
||||
|
@ -570,6 +607,39 @@ class ResnetBlockDDPMpp(nn.Module):
|
|||
return (x + h) / np.sqrt(2.0)
|
||||
|
||||
|
||||
# unet_rl.py
|
||||
class ResidualTemporalBlock(nn.Module):
|
||||
def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
|
||||
super().__init__()
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Conv1dBlock(inp_channels, out_channels, kernel_size),
|
||||
Conv1dBlock(out_channels, out_channels, kernel_size),
|
||||
]
|
||||
)
|
||||
|
||||
self.time_mlp = nn.Sequential(
|
||||
nn.Mish(),
|
||||
nn.Linear(embed_dim, out_channels),
|
||||
RearrangeDim(),
|
||||
# Rearrange("batch t -> batch t 1"),
|
||||
)
|
||||
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, t):
|
||||
"""
|
||||
x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x
|
||||
out_channels x horizon ]
|
||||
"""
|
||||
out = self.blocks[0](x) + self.time_mlp(t)
|
||||
out = self.blocks[1](out)
|
||||
return out + self.residual_conv(x)
|
||||
|
||||
|
||||
# HELPER Modules
|
||||
|
||||
|
||||
|
@ -617,18 +687,6 @@ class Mish(torch.nn.Module):
|
|||
return x * torch.tanh(torch.nn.functional.softplus(x))
|
||||
|
||||
|
||||
class Block(torch.nn.Module):
|
||||
def __init__(self, dim, dim_out, groups=8):
|
||||
super(Block, self).__init__()
|
||||
self.block = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
|
||||
)
|
||||
|
||||
def forward(self, x, mask):
|
||||
output = self.block(x * mask)
|
||||
return output * mask
|
||||
|
||||
|
||||
class Conv1dBlock(nn.Module):
|
||||
"""
|
||||
Conv1d --> GroupNorm --> Mish
|
||||
|
|
|
@ -4,9 +4,7 @@ from ..configuration_utils import ConfigMixin
|
|||
from ..modeling_utils import ModelMixin
|
||||
from .attention import LinearAttention
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .resnet import Downsample
|
||||
from .resnet import ResnetBlockGradTTS as ResnetBlock
|
||||
from .resnet import Upsample
|
||||
from .resnet import Downsample, ResnetBlock, Upsample
|
||||
|
||||
|
||||
class Mish(torch.nn.Module):
|
||||
|
@ -86,8 +84,26 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
|||
self.downs.append(
|
||||
torch.nn.ModuleList(
|
||||
[
|
||||
ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
|
||||
ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
|
||||
ResnetBlock(
|
||||
in_channels=dim_in,
|
||||
out_channels=dim_out,
|
||||
temb_channels=dim,
|
||||
groups=8,
|
||||
pre_norm=False,
|
||||
eps=1e-5,
|
||||
non_linearity="mish",
|
||||
overwrite_for_grad_tts=True,
|
||||
),
|
||||
ResnetBlock(
|
||||
in_channels=dim_out,
|
||||
out_channels=dim_out,
|
||||
temb_channels=dim,
|
||||
groups=8,
|
||||
pre_norm=False,
|
||||
eps=1e-5,
|
||||
non_linearity="mish",
|
||||
overwrite_for_grad_tts=True,
|
||||
),
|
||||
Residual(Rezero(LinearAttention(dim_out))),
|
||||
Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
|
||||
]
|
||||
|
@ -95,16 +111,52 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
|||
)
|
||||
|
||||
mid_dim = dims[-1]
|
||||
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
|
||||
self.mid_block1 = ResnetBlock(
|
||||
in_channels=mid_dim,
|
||||
out_channels=mid_dim,
|
||||
temb_channels=dim,
|
||||
groups=8,
|
||||
pre_norm=False,
|
||||
eps=1e-5,
|
||||
non_linearity="mish",
|
||||
overwrite_for_grad_tts=True,
|
||||
)
|
||||
self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
|
||||
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
|
||||
self.mid_block2 = ResnetBlock(
|
||||
in_channels=mid_dim,
|
||||
out_channels=mid_dim,
|
||||
temb_channels=dim,
|
||||
groups=8,
|
||||
pre_norm=False,
|
||||
eps=1e-5,
|
||||
non_linearity="mish",
|
||||
overwrite_for_grad_tts=True,
|
||||
)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
self.ups.append(
|
||||
torch.nn.ModuleList(
|
||||
[
|
||||
ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
|
||||
ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
|
||||
ResnetBlock(
|
||||
in_channels=dim_out * 2,
|
||||
out_channels=dim_in,
|
||||
temb_channels=dim,
|
||||
groups=8,
|
||||
pre_norm=False,
|
||||
eps=1e-5,
|
||||
non_linearity="mish",
|
||||
overwrite_for_grad_tts=True,
|
||||
),
|
||||
ResnetBlock(
|
||||
in_channels=dim_in,
|
||||
out_channels=dim_in,
|
||||
temb_channels=dim,
|
||||
groups=8,
|
||||
pre_norm=False,
|
||||
eps=1e-5,
|
||||
non_linearity="mish",
|
||||
overwrite_for_grad_tts=True,
|
||||
),
|
||||
Residual(Rezero(LinearAttention(dim_in))),
|
||||
Upsample(dim_in, use_conv_transpose=True),
|
||||
]
|
||||
|
@ -135,8 +187,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
|||
masks = [mask]
|
||||
for resnet1, resnet2, attn, downsample in self.downs:
|
||||
mask_down = masks[-1]
|
||||
x = resnet1(x, mask_down, t)
|
||||
x = resnet2(x, mask_down, t)
|
||||
x = resnet1(x, t, mask_down)
|
||||
x = resnet2(x, t, mask_down)
|
||||
x = attn(x)
|
||||
hiddens.append(x)
|
||||
x = downsample(x * mask_down)
|
||||
|
@ -144,15 +196,15 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
|||
|
||||
masks = masks[:-1]
|
||||
mask_mid = masks[-1]
|
||||
x = self.mid_block1(x, mask_mid, t)
|
||||
x = self.mid_block1(x, t, mask_mid)
|
||||
x = self.mid_attn(x)
|
||||
x = self.mid_block2(x, mask_mid, t)
|
||||
x = self.mid_block2(x, t, mask_mid)
|
||||
|
||||
for resnet1, resnet2, attn, upsample in self.ups:
|
||||
mask_up = masks.pop()
|
||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
||||
x = resnet1(x, mask_up, t)
|
||||
x = resnet2(x, mask_up, t)
|
||||
x = resnet1(x, t, mask_up)
|
||||
x = resnet2(x, t, mask_up)
|
||||
x = attn(x)
|
||||
x = upsample(x * mask_up)
|
||||
|
||||
|
|
Loading…
Reference in New Issue