Unet for Grad TTS and pipeline

Unet for Grad TTS and pipeline
This commit is contained in:
Suraj Patil 2022-06-15 11:17:29 +02:00 committed by GitHub
commit f7d91f8b8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 620 additions and 0 deletions

View File

@ -8,6 +8,7 @@ from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel
from .models.unet_grad_tts import UNetGradTTSModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler

View File

@ -19,3 +19,4 @@
from .unet import UNetModel
from .unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .unet_ldm import UNetLDMModel
from .unet_grad_tts import UNetGradTTSModel

View File

@ -0,0 +1,233 @@
import math
import torch
try:
from einops import rearrange, repeat
except:
print("Einops is not installed")
pass
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
class Mish(torch.nn.Module):
def forward(self, x):
return x * torch.tanh(torch.nn.functional.softplus(x))
class Upsample(torch.nn.Module):
def __init__(self, dim):
super(Upsample, self).__init__()
self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Downsample(torch.nn.Module):
def __init__(self, dim):
super(Downsample, self).__init__()
self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Rezero(torch.nn.Module):
def __init__(self, fn):
super(Rezero, self).__init__()
self.fn = fn
self.g = torch.nn.Parameter(torch.zeros(1))
def forward(self, x):
return self.fn(x) * self.g
class Block(torch.nn.Module):
def __init__(self, dim, dim_out, groups=8):
super(Block, self).__init__()
self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3,
padding=1), torch.nn.GroupNorm(
groups, dim_out), Mish())
def forward(self, x, mask):
output = self.block(x * mask)
return output * mask
class ResnetBlock(torch.nn.Module):
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
super(ResnetBlock, self).__init__()
self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim,
dim_out))
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
if dim != dim_out:
self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
else:
self.res_conv = torch.nn.Identity()
def forward(self, x, mask, time_emb):
h = self.block1(x, mask)
h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
h = self.block2(h, mask)
output = h + self.res_conv(x * mask)
return output
class LinearAttention(torch.nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super(LinearAttention, self).__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)',
heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
heads=self.heads, h=h, w=w)
return self.to_out(out)
class Residual(torch.nn.Module):
def __init__(self, fn):
super(Residual, self).__init__()
self.fn = fn
def forward(self, x, *args, **kwargs):
output = self.fn(x, *args, **kwargs) + x
return output
class SinusoidalPosEmb(torch.nn.Module):
def __init__(self, dim):
super(SinusoidalPosEmb, self).__init__()
self.dim = dim
def forward(self, x, scale=1000):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class UNetGradTTSModel(ModelMixin, ConfigMixin):
def __init__(
self,
dim,
dim_mults=(1, 2, 4),
groups=8,
n_spks=None,
spk_emb_dim=64,
n_feats=80,
pe_scale=1000
):
super(UNetGradTTSModel, self).__init__()
self.register(
dim=dim,
dim_mults=dim_mults,
groups=groups,
n_spks=n_spks,
spk_emb_dim=spk_emb_dim,
n_feats=n_feats,
pe_scale=pe_scale
)
self.dim = dim
self.dim_mults = dim_mults
self.groups = groups
self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1
self.spk_emb_dim = spk_emb_dim
self.pe_scale = pe_scale
if n_spks > 1:
self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
torch.nn.Linear(spk_emb_dim * 4, n_feats))
self.time_pos_emb = SinusoidalPosEmb(dim)
self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
torch.nn.Linear(dim * 4, dim))
dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
self.downs = torch.nn.ModuleList([])
self.ups = torch.nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(torch.nn.ModuleList([
ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
Residual(Rezero(LinearAttention(dim_out))),
Downsample(dim_out) if not is_last else torch.nn.Identity()]))
mid_dim = dims[-1]
self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
self.ups.append(torch.nn.ModuleList([
ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
Residual(Rezero(LinearAttention(dim_in))),
Upsample(dim_in)]))
self.final_block = Block(dim, dim)
self.final_conv = torch.nn.Conv2d(dim, 1, 1)
def forward(self, x, mask, mu, t, spk=None):
if not isinstance(spk, type(None)):
s = self.spk_mlp(spk)
t = self.time_pos_emb(t, scale=self.pe_scale)
t = self.mlp(t)
if self.n_spks < 2:
x = torch.stack([mu, x], 1)
else:
s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
x = torch.stack([mu, x, s], 1)
mask = mask.unsqueeze(1)
hiddens = []
masks = [mask]
for resnet1, resnet2, attn, downsample in self.downs:
mask_down = masks[-1]
x = resnet1(x, mask_down, t)
x = resnet2(x, mask_down, t)
x = attn(x)
hiddens.append(x)
x = downsample(x * mask_down)
masks.append(mask_down[:, :, :, ::2])
masks = masks[:-1]
mask_mid = masks[-1]
x = self.mid_block1(x, mask_mid, t)
x = self.mid_attn(x)
x = self.mid_block2(x, mask_mid, t)
for resnet1, resnet2, attn, upsample in self.ups:
mask_up = masks.pop()
x = torch.cat((x, hiddens.pop()), dim=1)
x = resnet1(x, mask_up, t)
x = resnet2(x, mask_up, t)
x = attn(x)
x = upsample(x * mask_up)
x = self.final_block(x, mask)
output = self.final_conv(x * mask)
return (output * mask).squeeze(1)

View File

@ -0,0 +1,385 @@
""" from https://github.com/jaywalnut310/glow-tts """
import math
import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
while True:
if length % (2**num_downsamplings_in_unet) == 0:
return length
length += 1
def convert_pad_shape(pad_shape):
l = pad_shape[::-1]
pad_shape = [item for sublist in l for item in sublist]
return pad_shape
def generate_path(duration, mask):
device = duration.device
b, t_x, t_y = mask.shape
cum_duration = torch.cumsum(duration, 1)
path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0],
[1, 0], [0, 0]]))[:, :-1]
path = path * mask
return path
def duration_loss(logw, logw_, lengths):
loss = torch.sum((logw - logw_)**2) / torch.sum(lengths)
return loss
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-4):
super(LayerNorm, self).__init__()
self.channels = channels
self.eps = eps
self.gamma = torch.nn.Parameter(torch.ones(channels))
self.beta = torch.nn.Parameter(torch.zeros(channels))
def forward(self, x):
n_dims = len(x.shape)
mean = torch.mean(x, 1, keepdim=True)
variance = torch.mean((x - mean)**2, 1, keepdim=True)
x = (x - mean) * torch.rsqrt(variance + self.eps)
shape = [1, -1] + [1] * (n_dims - 2)
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
return x
class ConvReluNorm(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
n_layers, p_dropout):
super(ConvReluNorm, self).__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_layers = n_layers
self.p_dropout = p_dropout
self.conv_layers = torch.nn.ModuleList()
self.norm_layers = torch.nn.ModuleList()
self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels,
kernel_size, padding=kernel_size//2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels,
kernel_size, padding=kernel_size//2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
def forward(self, x, x_mask):
x_org = x
for i in range(self.n_layers):
x = self.conv_layers[i](x * x_mask)
x = self.norm_layers[i](x)
x = self.relu_drop(x)
x = x_org + self.proj(x)
return x * x_mask
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super(DurationPredictor, self).__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.p_dropout = p_dropout
self.drop = torch.nn.Dropout(p_dropout)
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels,
kernel_size, padding=kernel_size//2)
self.norm_1 = LayerNorm(filter_channels)
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels,
kernel_size, padding=kernel_size//2)
self.norm_2 = LayerNorm(filter_channels)
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class MultiHeadAttention(nn.Module):
def __init__(self, channels, out_channels, n_heads, window_size=None,
heads_share=True, p_dropout=0.0, proximal_bias=False,
proximal_init=False):
super(MultiHeadAttention, self).__init__()
assert channels % n_heads == 0
self.channels = channels
self.out_channels = out_channels
self.n_heads = n_heads
self.window_size = window_size
self.heads_share = heads_share
self.proximal_bias = proximal_bias
self.p_dropout = p_dropout
self.attn = None
self.k_channels = channels // n_heads
self.conv_q = torch.nn.Conv1d(channels, channels, 1)
self.conv_k = torch.nn.Conv1d(channels, channels, 1)
self.conv_v = torch.nn.Conv1d(channels, channels, 1)
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel,
window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel,
window_size * 2 + 1, self.k_channels) * rel_stddev)
self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
self.drop = torch.nn.Dropout(p_dropout)
torch.nn.init.xavier_uniform_(self.conv_q.weight)
torch.nn.init.xavier_uniform_(self.conv_k.weight)
if proximal_init:
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
torch.nn.init.xavier_uniform_(self.conv_v.weight)
def forward(self, x, c, attn_mask=None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
x, self.attn = self.attention(q, k, v, mask=attn_mask)
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask=None):
b, d, t_s, t_t = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
if self.window_size is not None:
assert t_s == t_t, "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
rel_logits = self._relative_position_to_absolute_position(rel_logits)
scores_local = rel_logits / math.sqrt(self.k_channels)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device,
dtype=scores.dtype)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
p_attn = torch.nn.functional.softmax(scores, dim=-1)
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights,
value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(b, d, t_t)
return output, p_attn
def _matmul_with_relative_values(self, x, y):
ret = torch.matmul(x, y.unsqueeze(0))
return ret
def _matmul_with_relative_keys(self, x, y):
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
return ret
def _get_relative_embeddings(self, relative_embeddings, length):
pad_length = max(length - (self.window_size + 1), 0)
slice_start_position = max((self.window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = torch.nn.functional.pad(
relative_embeddings, convert_pad_shape([[0, 0],
[pad_length, pad_length], [0, 0]]))
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[:,
slice_start_position:slice_end_position]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
batch, heads, length, _ = x.size()
x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]]))
x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
return x_final
def _absolute_position_to_relative_position(self, x):
batch, heads, length, _ = x.size()
x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
x_flat = x.view([batch, heads, length**2 + length*(length - 1)])
x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
return x_final
def _attention_bias_proximal(self, length):
r = torch.arange(length, dtype=torch.float32)
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
class FFN(nn.Module):
def __init__(self, in_channels, out_channels, filter_channels, kernel_size,
p_dropout=0.0):
super(FFN, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size,
padding=kernel_size//2)
self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size,
padding=kernel_size//2)
self.drop = torch.nn.Dropout(p_dropout)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
return x * x_mask
class Encoder(nn.Module):
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers,
kernel_size=1, p_dropout=0.0, window_size=None, **kwargs):
super(Encoder, self).__init__()
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size
self.drop = torch.nn.Dropout(p_dropout)
self.attn_layers = torch.nn.ModuleList()
self.norm_layers_1 = torch.nn.ModuleList()
self.ffn_layers = torch.nn.ModuleList()
self.norm_layers_2 = torch.nn.ModuleList()
for _ in range(self.n_layers):
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
n_heads, window_size=window_size, p_dropout=p_dropout))
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
filter_channels, kernel_size, p_dropout=p_dropout))
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
for i in range(self.n_layers):
x = x * x_mask
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, x_mask)
y = self.drop(y)
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x
class TextEncoder(ModelMixin, ConfigMixin):
def __init__(self, n_vocab, n_feats, n_channels, filter_channels,
filter_channels_dp, n_heads, n_layers, kernel_size,
p_dropout, window_size=None, spk_emb_dim=64, n_spks=1):
super(TextEncoder, self).__init__()
self.register(
n_vocab=n_vocab,
n_feats=n_feats,
n_channels=n_channels,
filter_channels=filter_channels,
filter_channels_dp=filter_channels_dp,
n_heads=n_heads,
n_layers=n_layers,
kernel_size=kernel_size,
p_dropout=p_dropout,
window_size=window_size,
spk_emb_dim=spk_emb_dim,
n_spks=n_spks
)
self.n_vocab = n_vocab
self.n_feats = n_feats
self.n_channels = n_channels
self.filter_channels = filter_channels
self.filter_channels_dp = filter_channels_dp
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size
self.spk_emb_dim = spk_emb_dim
self.n_spks = n_spks
self.emb = torch.nn.Embedding(n_vocab, n_channels)
torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)
self.prenet = ConvReluNorm(n_channels, n_channels, n_channels,
kernel_size=5, n_layers=3, p_dropout=0.5)
self.encoder = Encoder(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels, n_heads, n_layers,
kernel_size, p_dropout, window_size=window_size)
self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp,
kernel_size, p_dropout)
def forward(self, x, x_lengths, spk=None):
x = self.emb(x) * math.sqrt(self.n_channels)
x = torch.transpose(x, 1, -1)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.prenet(x, x_mask)
if self.n_spks > 1:
x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
x = self.encoder(x, x_mask)
mu = self.proj_m(x) * x_mask
x_dp = torch.detach(x)
logw = self.proj_w(x_dp, x_mask)
return mu, logw, x_mask