The big purge -> remove everything except vision for now
parent c8c0c0e846
commit 2a69c0b7b8
@@ -1,40 +0,0 @@
import argparse
import torch

from diffusers.pipelines.bddm import DiffWave, BDDMPipeline
from diffusers import DDPMScheduler


def convert_bddm_original(checkpoint_path, noise_scheduler_checkpoint_path, output_path):
    sd = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
    noise_scheduler_sd = torch.load(noise_scheduler_checkpoint_path, map_location="cpu")

    model = DiffWave()
    model.load_state_dict(sd, strict=False)

    ts, _, betas, _ = noise_scheduler_sd
    ts, betas = list(ts.numpy().tolist()), list(betas.numpy().tolist())

    noise_scheduler = DDPMScheduler(
        timesteps=12,
        trained_betas=betas,
        timestep_values=ts,
        clip_sample=False,
        tensor_format="np",
    )

    pipeline = BDDMPipeline(model, noise_scheduler)
    pipeline.save_pretrained(output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", type=str, required=True)
    parser.add_argument("--noise_scheduler_checkpoint_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    args = parser.parse_args()

    convert_bddm_original(args.checkpoint_path, args.noise_scheduler_checkpoint_path, args.output_path)
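For reference, a checkpoint converted by the script above could then be reloaded through the generic pipeline loading machinery and run end-to-end. A minimal sketch, assuming the pre-purge package layout and a hypothetical local save path (nothing below is verified against a published checkpoint):

import torch
from diffusers.pipelines.bddm import BDDMPipeline

# Reload what pipeline.save_pretrained(output_path) wrote above.
pipeline = BDDMPipeline.from_pretrained("./bddm-converted")  # hypothetical path

mel = torch.randn(1, 80, 100)  # (batch, 80 mel bands, frames)
audio = pipeline(mel, generator=torch.manual_seed(0))  # -> (1, 1, 100 * 256) waveform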
@@ -7,10 +7,9 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
 __version__ = "0.0.4"

 from .modeling_utils import ModelMixin
-from .models import AutoencoderKL, NCSNpp, TemporalUNet, UNetLDMModel, UNetModel, UNetUnconditionalModel, VQModel
+from .models import AutoencoderKL, NCSNpp, UNetLDMModel, UNetModel, UNetUnconditionalModel, VQModel
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import (
-    BDDMPipeline,
     DDIMPipeline,
     DDPMPipeline,
     LatentDiffusionUncondPipeline,
@@ -21,7 +20,6 @@ from .pipelines import (
 from .schedulers import (
     DDIMScheduler,
     DDPMScheduler,
-    GradTTSScheduler,
     PNDMScheduler,
     SchedulerMixin,
     ScoreSdeVeScheduler,
@@ -31,13 +29,6 @@ from .schedulers import (

 if is_transformers_available():
-    from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
-    from .models.unet_grad_tts import UNetGradTTSModel
     from .pipelines import GlidePipeline, LatentDiffusionPipeline
 else:
     from .utils.dummy_transformers_objects import *

-
-if is_transformers_available() and is_inflect_available() and is_unidecode_available():
-    from .pipelines import GradTTSPipeline
-else:
-    from .utils.dummy_transformers_and_inflect_and_unidecode_objects import *
@@ -18,9 +18,7 @@

 from .unet import UNetModel
 from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
-from .unet_grad_tts import UNetGradTTSModel
 from .unet_ldm import UNetLDMModel
-from .unet_rl import TemporalUNet
 from .unet_sde_score_estimation import NCSNpp
 from .unet_unconditional import UNetUnconditionalModel
 from .vae import AutoencoderKL, VQModel
@@ -1,229 +0,0 @@
import torch

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import LinearAttention
from .embeddings import get_timestep_embedding
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2D


class Mish(torch.nn.Module):
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))


class Rezero(torch.nn.Module):
    def __init__(self, fn):
        super(Rezero, self).__init__()
        self.fn = fn
        self.g = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x, encoder_out=None):
        return self.fn(x, encoder_out) * self.g


class Block(torch.nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super(Block, self).__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
        )

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask


class Residual(torch.nn.Module):
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        output = self.fn(x, *args, **kwargs) + x
        return output


class UNetGradTTSModel(ModelMixin, ConfigMixin):
    def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
        super(UNetGradTTSModel, self).__init__()

        self.register_to_config(
            dim=dim,
            dim_mults=dim_mults,
            groups=groups,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
            n_feats=n_feats,
            pe_scale=pe_scale,
        )

        self.dim = dim
        self.dim_mults = dim_mults
        self.groups = groups
        self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1
        self.spk_emb_dim = spk_emb_dim
        self.pe_scale = pe_scale

        if n_spks > 1:
            self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
            self.spk_mlp = torch.nn.Sequential(
                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
            )

        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))

        dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.downs = torch.nn.ModuleList([])
        self.ups = torch.nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(
                torch.nn.ModuleList(
                    [
                        ResnetBlock2D(
                            in_channels=dim_in,
                            out_channels=dim_out,
                            temb_channels=dim,
                            groups=8,
                            pre_norm=False,
                            eps=1e-5,
                            non_linearity="mish",
                            overwrite_for_grad_tts=True,
                        ),
                        ResnetBlock2D(
                            in_channels=dim_out,
                            out_channels=dim_out,
                            temb_channels=dim,
                            groups=8,
                            pre_norm=False,
                            eps=1e-5,
                            non_linearity="mish",
                            overwrite_for_grad_tts=True,
                        ),
                        Residual(Rezero(LinearAttention(dim_out))),
                        Downsample2D(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]

        self.mid = UNetMidBlock2D(
            in_channels=mid_dim,
            temb_channels=dim,
            resnet_groups=8,
            resnet_pre_norm=False,
            resnet_eps=1e-5,
            resnet_act_fn="mish",
            attention_layer_type="linear",
        )

        self.mid_block1 = ResnetBlock2D(
            in_channels=mid_dim,
            out_channels=mid_dim,
            temb_channels=dim,
            groups=8,
            pre_norm=False,
            eps=1e-5,
            non_linearity="mish",
            overwrite_for_grad_tts=True,
        )
        self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock2D(
            in_channels=mid_dim,
            out_channels=mid_dim,
            temb_channels=dim,
            groups=8,
            pre_norm=False,
            eps=1e-5,
            non_linearity="mish",
            overwrite_for_grad_tts=True,
        )
        self.mid.resnets[0] = self.mid_block1
        self.mid.attentions[0] = self.mid_attn
        self.mid.resnets[1] = self.mid_block2

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            self.ups.append(
                torch.nn.ModuleList(
                    [
                        ResnetBlock2D(
                            in_channels=dim_out * 2,
                            out_channels=dim_in,
                            temb_channels=dim,
                            groups=8,
                            pre_norm=False,
                            eps=1e-5,
                            non_linearity="mish",
                            overwrite_for_grad_tts=True,
                        ),
                        ResnetBlock2D(
                            in_channels=dim_in,
                            out_channels=dim_in,
                            temb_channels=dim,
                            groups=8,
                            pre_norm=False,
                            eps=1e-5,
                            non_linearity="mish",
                            overwrite_for_grad_tts=True,
                        ),
                        Residual(Rezero(LinearAttention(dim_in))),
                        Upsample2D(dim_in, use_conv_transpose=True),
                    ]
                )
            )
        self.final_block = Block(dim, dim)
        self.final_conv = torch.nn.Conv2d(dim, 1, 1)

    def forward(self, sample, timesteps, mu, mask, spk=None):
        x = sample
        if self.n_spks > 1:
            # Get speaker embedding
            spk = self.spk_emb(spk)

        if not isinstance(spk, type(None)):
            s = self.spk_mlp(spk)

        t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale)
        t = self.mlp(t)

        if self.n_spks < 2:
            x = torch.stack([mu, x], 1)
        else:
            s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
            x = torch.stack([mu, x, s], 1)
        mask = mask.unsqueeze(1)

        hiddens = []
        masks = [mask]
        for resnet1, resnet2, attn, downsample in self.downs:
            mask_down = masks[-1]
            x = resnet1(x, t, mask_down)
            x = resnet2(x, t, mask_down)
            x = attn(x)
            hiddens.append(x)
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]

        x = self.mid(x, t, mask=mask_mid)

        for resnet1, resnet2, attn, upsample in self.ups:
            mask_up = masks.pop()
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = resnet1(x, t, mask_up)
            x = resnet2(x, t, mask_up)
            x = attn(x)
            x = upsample(x * mask_up)

        x = self.final_block(x, mask)
        output = self.final_conv(x * mask)

        return (output * mask).squeeze(1)
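Shape-wise, the UNetGradTTSModel removed above denoises a mel sample conditioned on the aligned text-encoder output mu, under a frame mask; sample and mu are stacked along a new channel axis before entering the UNet. A minimal sketch with purely illustrative sizes (note the deleted code only handles an explicit speaker count; the n_spks=None default would hit a None > 1 comparison in __init__):

import torch

model = UNetGradTTSModel(dim=64, n_spks=1, n_feats=80)
x = torch.randn(1, 80, 128)    # noisy mel sample (batch, n_feats, frames)
mu = torch.randn(1, 80, 128)   # aligned text-encoder output, same shape
mask = torch.ones(1, 1, 128)   # frame mask; unsqueezed internally to broadcast
t = torch.tensor([0.5])        # continuous diffusion time in [0, 1]
out = model(x, t, mu, mask)    # -> (1, 80, 128) predicted residual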
@@ -1,227 +0,0 @@
# model adapted from diffuser https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Downsample1D, ResidualTemporalBlock, Upsample1D


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return get_timestep_embedding(x, self.dim)


class RearrangeDim(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, tensor):
        if len(tensor.shape) == 2:
            return tensor[:, :, None]
        if len(tensor.shape) == 3:
            return tensor[:, :, None, :]
        elif len(tensor.shape) == 4:
            return tensor[:, :, 0, :]
        else:
            raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
            RearrangeDim(),
            # Rearrange("batch channels horizon -> batch channels 1 horizon"),
            nn.GroupNorm(n_groups, out_channels),
            RearrangeDim(),
            # Rearrange("batch channels 1 horizon -> batch channels horizon"),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
    def __init__(
        self,
        training_horizon=128,
        transition_dim=14,
        cond_dim=3,
        predict_epsilon=False,
        clip_denoised=True,
        dim=32,
        dim_mults=(1, 4, 8),
    ):
        super().__init__()

        self.transition_dim = transition_dim
        self.cond_dim = cond_dim
        self.predict_epsilon = predict_epsilon
        self.clip_denoised = clip_denoised

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        Downsample1D(dim_out, use_conv=True) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                training_horizon = training_horizon // 2

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        Upsample1D(dim_in, use_conv_transpose=True) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                training_horizon = training_horizon * 2

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=5),
            nn.Conv1d(dim, transition_dim, 1),
        )

    def forward(self, sample, timesteps):
        """
        x : [ batch x horizon x transition ]
        """
        x = sample

        x = x.permute(0, 2, 1)

        t = self.time_mlp(timesteps)
        h = []

        for resnet, resnet2, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)

        for resnet, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = upsample(x)

        x = self.final_conv(x)

        x = x.permute(0, 2, 1)
        return x


class TemporalValue(nn.Module):
    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=32,
        time_dim=None,
        out_dim=1,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = time_dim or dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.blocks = nn.ModuleList([])

        print(in_out)
        for dim_in, dim_out in in_out:
            self.blocks.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                        ResidualTemporalBlock(dim_out, dim_out, kernel_size=5, embed_dim=time_dim, horizon=horizon),
                        Downsample1D(dim_out),
                    ]
                )
            )

            horizon = horizon // 2

        fc_dim = dims[-1] * max(horizon, 1)

        self.final_block = nn.Sequential(
            nn.Linear(fc_dim + time_dim, fc_dim // 2),
            nn.Mish(),
            nn.Linear(fc_dim // 2, out_dim),
        )

    def forward(self, x, cond, time, *args):
        """
        x : [ batch x horizon x transition ]
        """
        x = x.permute(0, 2, 1)

        t = self.time_mlp(time)

        for resnet, resnet2, downsample in self.blocks:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = downsample(x)

        x = x.view(len(x), -1)
        out = self.final_block(torch.cat([x, t], dim=-1))
        return out
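Similarly, the TemporalUNet above (the RL trajectory model adapted from the diffuser repo) takes trajectories as [batch, horizon, transition] and permutes to channels-first internally. A minimal shape sketch under the constructor defaults, purely illustrative:

import torch

unet = TemporalUNet()        # training_horizon=128, transition_dim=14, dim_mults=(1, 4, 8)
x = torch.randn(2, 128, 14)  # (batch, horizon, transition)
t = torch.tensor([10, 10])   # one integer diffusion step per batch element
out = unet(x, t)             # -> (2, 128, 14), same layout as the input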
@@ -1,5 +1,4 @@
 from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available
-from .bddm import BDDMPipeline
 from .ddim import DDIMPipeline
 from .ddpm import DDPMPipeline
 from .latent_diffusion_uncond import LatentDiffusionUncondPipeline
@@ -11,7 +10,3 @@ from .score_sde_vp import ScoreSdeVpPipeline
 if is_transformers_available():
     from .glide import GlidePipeline
     from .latent_diffusion import LatentDiffusionPipeline
-
-
-if is_transformers_available() and is_unidecode_available() and is_inflect_available():
-    from .grad_tts import GradTTSPipeline
@@ -1 +0,0 @@
from .pipeline_bddm import BDDMPipeline, DiffWave
@@ -1,311 +0,0 @@
#!/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# DiffWave: A Versatile Diffusion Model for Audio Synthesis
# (https://arxiv.org/abs/2009.09761)
# Modified from https://github.com/philsyn/DiffWave-Vocoder
#
# Author: Max W. Y. Lam (maxwylam@tencent.com)
# Copyright (c) 2021 Tencent. All Rights Reserved
#
########################################################################


import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import tqdm

from ...configuration_utils import ConfigMixin
from ...modeling_utils import ModelMixin
from ...pipeline_utils import DiffusionPipeline


def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
    """
    Embed a diffusion step $t$ into a higher dimensional space
    E.g. the embedding vector in the 128-dimensional space is [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)),
    cos(t * 10^(0*4/63)), ... , cos(t * 10^(63*4/63))]

    Parameters:
        diffusion_steps (torch.long tensor, shape=(batchsize, 1)):
            diffusion steps for batch data
        diffusion_step_embed_dim_in (int, default=128):
            dimensionality of the embedding space for discrete diffusion steps
    Returns:
        the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in)):
    """

    assert diffusion_step_embed_dim_in % 2 == 0

    half_dim = diffusion_step_embed_dim_in // 2
    _embed = np.log(10000) / (half_dim - 1)
    _embed = torch.exp(torch.arange(half_dim) * -_embed).cuda()
    _embed = diffusion_steps * _embed
    diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1)
    return diffusion_step_embed


"""
Below scripts were borrowed from https://github.com/philsyn/DiffWave-Vocoder/blob/master/WaveNet.py
"""


def swish(x):
    return x * torch.sigmoid(x)


# dilated conv layer with kaiming_normal initialization
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py
class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
        super().__init__()
        self.padding = dilation * (kernel_size - 1) // 2
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding)
        self.conv = nn.utils.weight_norm(self.conv)
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, x):
        out = self.conv(x)
        return out


# conv1x1 layer with zero initialization
# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py but the scale parameter is removed
class ZeroConv1d(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0)
        self.conv.weight.data.zero_()
        self.conv.bias.data.zero_()

    def forward(self, x):
        out = self.conv(x)
        return out


# every residual block (named residual layer in paper)
# contains one noncausal dilated conv
class ResidualBlock(nn.Module):
    def __init__(self, res_channels, skip_channels, dilation, diffusion_step_embed_dim_out):
        super().__init__()
        self.res_channels = res_channels

        # Use a FC layer for diffusion step embedding
        self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)

        # Dilated conv layer
        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3, dilation=dilation)

        # Add mel spectrogram upsampler and conditioner conv1x1 layer
        self.upsample_conv2d = nn.ModuleList()
        for s in [16, 16]:
            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s))
            conv_trans2d = nn.utils.weight_norm(conv_trans2d)
            nn.init.kaiming_normal_(conv_trans2d.weight)
            self.upsample_conv2d.append(conv_trans2d)

        # 80 is mel bands
        self.mel_conv = Conv(80, 2 * self.res_channels, kernel_size=1)

        # Residual conv1x1 layer, connect to next residual layer
        self.res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1)
        self.res_conv = nn.utils.weight_norm(self.res_conv)
        nn.init.kaiming_normal_(self.res_conv.weight)

        # Skip conv1x1 layer, add to all skip outputs through skip connections
        self.skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1)
        self.skip_conv = nn.utils.weight_norm(self.skip_conv)
        nn.init.kaiming_normal_(self.skip_conv.weight)

    def forward(self, input_data):
        x, mel_spec, diffusion_step_embed = input_data
        h = x
        batch_size, n_channels, seq_len = x.shape
        assert n_channels == self.res_channels

        # Add in diffusion step embedding
        part_t = self.fc_t(diffusion_step_embed)
        part_t = part_t.view([batch_size, self.res_channels, 1])
        h += part_t

        # Dilated conv layer
        h = self.dilated_conv_layer(h)

        # Upsample2D spectrogram to size of audio
        mel_spec = torch.unsqueeze(mel_spec, dim=1)
        mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4, inplace=False)
        mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4, inplace=False)
        mel_spec = torch.squeeze(mel_spec, dim=1)

        assert mel_spec.size(2) >= seq_len
        if mel_spec.size(2) > seq_len:
            mel_spec = mel_spec[:, :, :seq_len]

        mel_spec = self.mel_conv(mel_spec)
        h += mel_spec

        # Gated-tanh nonlinearity
        out = torch.tanh(h[:, : self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels :, :])

        # Residual and skip outputs
        res = self.res_conv(out)
        assert x.shape == res.shape
        skip = self.skip_conv(out)

        # Normalize for training stability
        return (x + res) * math.sqrt(0.5), skip


class ResidualGroup(nn.Module):
    def __init__(
        self,
        res_channels,
        skip_channels,
        num_res_layers,
        dilation_cycle,
        diffusion_step_embed_dim_in,
        diffusion_step_embed_dim_mid,
        diffusion_step_embed_dim_out,
    ):
        super().__init__()
        self.num_res_layers = num_res_layers
        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in

        # Use the shared two FC layers for diffusion step embedding
        self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid)
        self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out)

        # Stack all residual blocks with dilations 1, 2, ... , 512, ... , 1, 2, ..., 512
        self.residual_blocks = nn.ModuleList()
        for n in range(self.num_res_layers):
            self.residual_blocks.append(
                ResidualBlock(
                    res_channels,
                    skip_channels,
                    dilation=2 ** (n % dilation_cycle),
                    diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
                )
            )

    def forward(self, input_data):
        x, mel_spectrogram, diffusion_steps = input_data

        # Embed diffusion step t
        diffusion_step_embed = calc_diffusion_step_embedding(diffusion_steps, self.diffusion_step_embed_dim_in)
        diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
        diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))

        # Pass all residual layers
        h = x
        skip = 0
        for n in range(self.num_res_layers):
            # Use the output from last residual layer
            h, skip_n = self.residual_blocks[n]((h, mel_spectrogram, diffusion_step_embed))
            # Accumulate all skip outputs
            skip += skip_n

        # Normalize for training stability
        return skip * math.sqrt(1.0 / self.num_res_layers)


class DiffWave(ModelMixin, ConfigMixin):
    def __init__(
        self,
        in_channels=1,
        res_channels=128,
        skip_channels=128,
        out_channels=1,
        num_res_layers=30,
        dilation_cycle=10,
        diffusion_step_embed_dim_in=128,
        diffusion_step_embed_dim_mid=512,
        diffusion_step_embed_dim_out=512,
    ):
        super().__init__()

        # register all init arguments with self.register
        self.register_to_config(
            in_channels=in_channels,
            res_channels=res_channels,
            skip_channels=skip_channels,
            out_channels=out_channels,
            num_res_layers=num_res_layers,
            dilation_cycle=dilation_cycle,
            diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
            diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
            diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
        )

        # Initial conv1x1 with relu
        self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))
        # All residual layers
        self.residual_layer = ResidualGroup(
            res_channels,
            skip_channels,
            num_res_layers,
            dilation_cycle,
            diffusion_step_embed_dim_in,
            diffusion_step_embed_dim_mid,
            diffusion_step_embed_dim_out,
        )
        # Final conv1x1 -> relu -> zeroconv1x1
        self.final_conv = nn.Sequential(
            Conv(skip_channels, skip_channels, kernel_size=1),
            nn.ReLU(inplace=False),
            ZeroConv1d(skip_channels, out_channels),
        )

    def forward(self, input_data):
        audio, mel_spectrogram, diffusion_steps = input_data
        x = audio
        x = self.init_conv(x).clone()
        x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
        return self.final_conv(x)


class BDDMPipeline(DiffusionPipeline):
    def __init__(self, diffwave, noise_scheduler):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
        self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)

    @torch.no_grad()
    def __call__(self, mel_spectrogram, generator, torch_device=None):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.diffwave.to(torch_device)

        mel_spectrogram = mel_spectrogram.to(torch_device)
        audio_length = mel_spectrogram.size(-1) * 256
        audio_size = (1, 1, audio_length)

        # Sample gaussian noise to begin loop
        audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)

        timestep_values = self.noise_scheduler.config.timestep_values
        num_prediction_steps = len(self.noise_scheduler)
        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
            # 1. predict noise residual
            ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
            residual = self.diffwave((audio, mel_spectrogram, ts))

            # 2. predict previous mean of audio x_t-1
            pred_prev_audio = self.noise_scheduler.step(residual, audio, t)

            # 3. optionally sample variance
            variance = 0
            if t > 0:
                noise = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
                variance = self.noise_scheduler.get_variance(t).sqrt() * noise

            # 4. set current audio to prev_audio: x_t -> x_t-1
            audio = pred_prev_audio + variance

        return audio
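Before it goes: the BDDMPipeline.__call__ above runs the reverse loop at a fixed 256 audio samples per mel frame, so the output length is mel.size(-1) * 256. A minimal sketch wiring up the two classes deleted in this hunk (note that calc_diffusion_step_embedding above hard-codes .cuda(), so in practice this needed a GPU; the scheduler is the DDPMScheduler built by the conversion script earlier):

import torch

diffwave = DiffWave()  # defaults: 128 res/skip channels, 30 residual layers
pipeline = BDDMPipeline(diffwave, noise_scheduler)  # noise_scheduler constructed as above

mel = torch.randn(1, 80, 100)  # (batch, 80 mel bands, 100 frames)
audio = pipeline(mel, generator=torch.manual_seed(0))  # -> (1, 1, 25600) waveform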
@@ -1,6 +0,0 @@
from ...utils import is_inflect_available, is_transformers_available, is_unidecode_available


if is_transformers_available() and is_unidecode_available() and is_inflect_available():
    from .grad_tts_utils import GradTTSTokenizer
    from .pipeline_grad_tts import GradTTSPipeline, TextEncoder
@@ -1,421 +0,0 @@
# tokenizer

import os
import re
from shutil import copyfile

import torch

import inflect
from transformers import PreTrainedTokenizer
from unidecode import unidecode


valid_symbols = [
    "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
    "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
    "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2",
    "EY", "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2",
    "IY", "IY0", "IY1", "IY2", "JH", "K", "L", "M", "N", "NG",
    "OW", "OW0", "OW1", "OW2", "OY", "OY0", "OY1", "OY2", "P", "R", "S", "SH",
    "T", "TH", "UH", "UH0", "UH1", "UH2", "UW", "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH",
]

_valid_symbol_set = set(valid_symbols)


def intersperse(lst, item):
    # Adds blank symbol
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


class CMUDict:
    def __init__(self, file_or_path, keep_ambiguous=True):
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding="latin-1") as f:
                entries = _parse_cmudict(f)
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

    def __len__(self):
        return len(self._entries)

    def lookup(self, word):
        return self._entries.get(word.upper())


_alt_re = re.compile(r"\([0-9]+\)")


def _parse_cmudict(file):
    cmudict = {}
    for line in file:
        if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
            parts = line.split("  ")
            word = re.sub(_alt_re, "", parts[0])
            pronunciation = _get_pronunciation(parts[1])
            if pronunciation:
                if word in cmudict:
                    cmudict[word].append(pronunciation)
                else:
                    cmudict[word] = [pronunciation]
    return cmudict


def _get_pronunciation(s):
    parts = s.strip().split(" ")
    for part in parts:
        if part not in _valid_symbol_set:
            return None
    return " ".join(parts)


_whitespace_re = re.compile(r"\s+")

_abbreviations = [
    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
    for x in [
        ("mrs", "misess"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    ]
]


def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text


def expand_numbers(text):
    return normalize_numbers(text)


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text)


def convert_to_ascii(text):
    return unidecode(text)


def basic_cleaners(text):
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def transliteration_cleaners(text):
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def english_cleaners(text):
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    return text


_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")


def _remove_commas(m):
    return m.group(1).replace(",", "")


def _expand_decimal_point(m):
    return m.group(1).replace(".", " point ")


def _expand_dollars(m):
    match = m.group(1)
    parts = match.split(".")
    if len(parts) > 2:
        return match + " dollars"
    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    if dollars and cents:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        return "%s %s" % (dollars, dollar_unit)
    elif cents:
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s" % (cents, cent_unit)
    else:
        return "zero dollars"


def _expand_ordinal(m):
    return _inflect.number_to_words(m.group(0))


def _expand_number(m):
    num = int(m.group(0))
    if num > 1000 and num < 3000:
        if num == 2000:
            return "two thousand"
        elif num > 2000 and num < 2010:
            return "two thousand " + _inflect.number_to_words(num % 100)
        elif num % 100 == 0:
            return _inflect.number_to_words(num // 100) + " hundred"
        else:
            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
    else:
        return _inflect.number_to_words(num, andword="")


def normalize_numbers(text):
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_pounds_re, r"\1 pounds", text)
    text = re.sub(_dollars_re, _expand_dollars, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    text = re.sub(_number_re, _expand_number, text)
    return text


""" from https://github.com/keithito/tacotron"""


_pad = "_"
_punctuation = "!'(),.:;? "
_special = "-"
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

# Prepend "@" to ARPAbet symbols to ensure uniqueness:
_arpabet = ["@" + s for s in valid_symbols]

# Export all symbols:
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet


_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")


def get_arpabet(word, dictionary):
    word_arpabet = dictionary.lookup(word)
    if word_arpabet is not None:
        return "{" + word_arpabet[0] + "}"
    else:
        return word


def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded in it. For example, "Turn left on
    {HH AW1 S S T AH0 N} Street."

    Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through
        dictionary: arpabet class with arpabet dictionary

    Returns:
        List of integers corresponding to the symbols in the text
    """
    sequence = []
    space = _symbols_to_sequence(" ")
    # Check for curly braces and treat their contents as ARPAbet:
    while len(text):
        m = _curly_re.match(text)
        if not m:
            clean_text = _clean_text(text, cleaner_names)
            if dictionary is not None:
                clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")]
                for i in range(len(clean_text)):
                    t = clean_text[i]
                    if t.startswith("{"):
                        sequence += _arpabet_to_sequence(t[1:-1])
                    else:
                        sequence += _symbols_to_sequence(t)
                    sequence += space
            else:
                sequence += _symbols_to_sequence(clean_text)
            break
        sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    # remove trailing space
    if dictionary is not None:
        sequence = sequence[:-1] if sequence[-1] == space[0] else sequence
    return sequence


def sequence_to_text(sequence):
    """Converts a sequence of IDs back to a string"""
    result = ""
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            # Enclose ARPAbet back in curly braces:
            if len(s) > 1 and s[0] == "@":
                s = "{%s}" % s[1:]
            result += s
    return result.replace("}{", " ")


def _clean_text(text, cleaner_names):
    for cleaner in cleaner_names:
        text = cleaner(text)
    return text


def _symbols_to_sequence(symbols):
    return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]


def _arpabet_to_sequence(text):
    return _symbols_to_sequence(["@" + s for s in text.split()])


def _should_keep_symbol(s):
    return s in _symbol_to_id and s != "_" and s != "~"


VOCAB_FILES_NAMES = {
    "dict_file": "dict_file.txt",
}


class GradTTSTokenizer(PreTrainedTokenizer):
    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(self, dict_file, **kwargs):
        super().__init__(**kwargs)
        self.cmu = CMUDict(dict_file)
        self.dict_file = dict_file

    def __call__(self, text):
        x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=self.cmu), len(symbols)))[None]
        x_lengths = torch.LongTensor([x.shape[-1]])
        return x, x_lengths

    def save_vocabulary(self, save_directory: str, filename_prefix=None):
        dict_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"]
        )

        copyfile(self.dict_file, dict_file)

        return (dict_file,)
@ -1,489 +0,0 @@
|
|||
""" from https://github.com/jaywalnut310/glow-tts"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import tqdm
|
||||
|
||||
from ...configuration_utils import ConfigMixin
|
||||
from ...modeling_utils import ModelMixin
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from .grad_tts_utils import GradTTSTokenizer # flake8: noqa
|
||||
|
||||
|
||||
def sequence_mask(length, max_length=None):
|
||||
if max_length is None:
|
||||
max_length = length.max()
|
||||
x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
|
||||
return x.unsqueeze(0) < length.unsqueeze(1)
|
||||
|
||||
|
||||
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
|
||||
while True:
|
||||
if length % (2**num_downsamplings_in_unet) == 0:
|
||||
return length
|
||||
length += 1
|
||||
|
||||
|
||||
def convert_pad_shape(pad_shape):
|
||||
l = pad_shape[::-1]
|
||||
pad_shape = [item for sublist in l for item in sublist]
|
||||
return pad_shape
|
||||
|
||||
|
||||
def generate_path(duration, mask):
|
||||
device = duration.device
|
||||
|
||||
b, t_x, t_y = mask.shape
|
||||
cum_duration = torch.cumsum(duration, 1)
|
||||
path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
|
||||
|
||||
cum_duration_flat = cum_duration.view(b * t_x)
|
||||
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
||||
path = path.view(b, t_x, t_y)
|
||||
path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
||||
path = path * mask
|
||||
return path
|
||||
|
||||
|
||||
def duration_loss(logw, logw_, lengths):
|
||||
loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
|
||||
return loss
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, channels, eps=1e-4):
|
||||
super(LayerNorm, self).__init__()
|
||||
self.channels = channels
|
||||
self.eps = eps
|
||||
|
||||
self.gamma = torch.nn.Parameter(torch.ones(channels))
|
||||
self.beta = torch.nn.Parameter(torch.zeros(channels))
|
||||
|
||||
def forward(self, x):
|
||||
n_dims = len(x.shape)
|
||||
mean = torch.mean(x, 1, keepdim=True)
|
||||
variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
|
||||
|
||||
x = (x - mean) * torch.rsqrt(variance + self.eps)
|
||||
|
||||
shape = [1, -1] + [1] * (n_dims - 2)
|
||||
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
|
||||
return x
|
||||
|
||||
|
||||
class ConvReluNorm(nn.Module):
|
||||
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
|
||||
super(ConvReluNorm, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.n_layers = n_layers
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.conv_layers = torch.nn.ModuleList()
|
||||
self.norm_layers = torch.nn.ModuleList()
|
||||
self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
||||
self.norm_layers.append(LayerNorm(hidden_channels))
|
||||
self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
|
||||
for _ in range(n_layers - 1):
|
||||
self.conv_layers.append(
|
||||
torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
|
||||
)
|
||||
self.norm_layers.append(LayerNorm(hidden_channels))
|
||||
self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
|
||||
self.proj.weight.data.zero_()
|
||||
self.proj.bias.data.zero_()
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x_org = x
|
||||
for i in range(self.n_layers):
|
||||
x = self.conv_layers[i](x * x_mask)
|
||||
x = self.norm_layers[i](x)
|
||||
x = self.relu_drop(x)
|
||||
x = x_org + self.proj(x)
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class DurationPredictor(nn.Module):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
|
||||
super(DurationPredictor, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.drop = torch.nn.Dropout(p_dropout)
|
||||
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
||||
self.norm_1 = LayerNorm(filter_channels)
|
||||
self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
||||
self.norm_2 = LayerNorm(filter_channels)
|
||||
self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x = self.conv_1(x * x_mask)
|
||||
x = torch.relu(x)
|
||||
x = self.norm_1(x)
|
||||
x = self.drop(x)
|
||||
x = self.conv_2(x * x_mask)
|
||||
x = torch.relu(x)
|
||||
x = self.norm_2(x)
|
||||
x = self.drop(x)
|
||||
x = self.proj(x * x_mask)
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
out_channels,
|
||||
n_heads,
|
||||
window_size=None,
|
||||
heads_share=True,
|
||||
p_dropout=0.0,
|
||||
proximal_bias=False,
|
||||
proximal_init=False,
|
||||
):
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
assert channels % n_heads == 0
|
||||
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels
|
||||
self.n_heads = n_heads
|
||||
self.window_size = window_size
|
||||
self.heads_share = heads_share
|
||||
self.proximal_bias = proximal_bias
|
||||
self.p_dropout = p_dropout
|
||||
self.attn = None
|
||||
|
||||
self.k_channels = channels // n_heads
|
||||
self.conv_q = torch.nn.Conv1d(channels, channels, 1)
|
||||
self.conv_k = torch.nn.Conv1d(channels, channels, 1)
|
||||
self.conv_v = torch.nn.Conv1d(channels, channels, 1)
|
||||
if window_size is not None:
|
||||
n_heads_rel = 1 if heads_share else n_heads
|
||||
rel_stddev = self.k_channels**-0.5
|
||||
self.emb_rel_k = torch.nn.Parameter(
|
||||
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
|
||||
)
|
||||
self.emb_rel_v = torch.nn.Parameter(
|
||||
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev
|
||||
)
|
||||
self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
|
||||
self.drop = torch.nn.Dropout(p_dropout)
|
||||
|
||||
torch.nn.init.xavier_uniform_(self.conv_q.weight)
|
||||
torch.nn.init.xavier_uniform_(self.conv_k.weight)
|
||||
if proximal_init:
|
||||
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
|
||||
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
|
||||
torch.nn.init.xavier_uniform_(self.conv_v.weight)
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
q = self.conv_q(x)
|
||||
k = self.conv_k(c)
|
||||
v = self.conv_v(c)
|
||||
|
||||
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
||||
|
||||
x = self.conv_o(x)
|
||||
return x
|
||||
|
||||
def attention(self, query, key, value, mask=None):
|
||||
b, d, t_s, t_t = (*key.size(), query.size(2))
|
||||
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
||||
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
||||
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
||||
|
||||
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
|
||||
if self.window_size is not None:
|
||||
assert t_s == t_t, "Relative attention is only available for self-attention."
|
||||
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
||||
rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
|
||||
rel_logits = self._relative_position_to_absolute_position(rel_logits)
|
||||
scores_local = rel_logits / math.sqrt(self.k_channels)
|
||||
scores = scores + scores_local
|
||||
if self.proximal_bias:
|
||||
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
||||
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
|
||||
if mask is not None:
|
||||
scores = scores.masked_fill(mask == 0, -1e4)
|
||||
p_attn = torch.nn.functional.softmax(scores, dim=-1)
|
||||
p_attn = self.drop(p_attn)
|
||||
output = torch.matmul(p_attn, value)
|
||||
if self.window_size is not None:
|
||||
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
||||
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
|
||||
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
|
||||
output = output.transpose(2, 3).contiguous().view(b, d, t_t)
|
||||
return output, p_attn
|
||||
|
||||
def _matmul_with_relative_values(self, x, y):
|
||||
ret = torch.matmul(x, y.unsqueeze(0))
|
||||
return ret
|
||||
|
||||
def _matmul_with_relative_keys(self, x, y):
|
||||
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
||||
return ret
|
||||
|
||||
def _get_relative_embeddings(self, relative_embeddings, length):
|
||||
pad_length = max(length - (self.window_size + 1), 0)
|
||||
slice_start_position = max((self.window_size + 1) - length, 0)
|
||||
slice_end_position = slice_start_position + 2 * length - 1
|
||||
if pad_length > 0:
|
||||
padded_relative_embeddings = torch.nn.functional.pad(
|
||||
relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])
|
||||
)
|
||||
else:
|
||||
padded_relative_embeddings = relative_embeddings
|
||||
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
|
||||
return used_relative_embeddings
|
||||
|
||||
def _relative_position_to_absolute_position(self, x):
|
||||
batch, heads, length, _ = x.size()
|
||||
x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
|
||||
x_flat = x.view([batch, heads, length * 2 * length])
|
||||
x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
|
||||
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
|
||||
return x_final
|
||||
|
||||
def _absolute_position_to_relative_position(self, x):
|
||||
batch, heads, length, _ = x.size()
|
||||
x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
|
||||
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
||||
x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
||||
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
|
||||
return x_final
|
||||
|
||||
def _attention_bias_proximal(self, length):
|
||||
r = torch.arange(length, dtype=torch.float32)
|
||||
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
||||
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
||||
|
||||
|
||||
class FFN(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
|
||||
super(FFN, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
||||
self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
|
||||
self.drop = torch.nn.Dropout(p_dropout)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x = self.conv_1(x * x_mask)
|
||||
x = torch.relu(x)
|
||||
x = self.drop(x)
|
||||
x = self.conv_2(x * x_mask)
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size=1,
|
||||
p_dropout=0.0,
|
||||
window_size=None,
|
||||
**kwargs,
|
||||
):
|
||||
super(Encoder, self).__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.window_size = window_size
|
||||
|
||||
self.drop = torch.nn.Dropout(p_dropout)
|
||||
self.attn_layers = torch.nn.ModuleList()
|
||||
self.norm_layers_1 = torch.nn.ModuleList()
|
||||
self.ffn_layers = torch.nn.ModuleList()
|
||||
self.norm_layers_2 = torch.nn.ModuleList()
|
||||
for _ in range(self.n_layers):
|
||||
self.attn_layers.append(
|
||||
MultiHeadAttention(
|
||||
hidden_channels, hidden_channels, n_heads, window_size=window_size, p_dropout=p_dropout
|
||||
)
|
||||
)
|
||||
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
||||
self.ffn_layers.append(
|
||||
FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
|
||||
)
|
||||
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||
for i in range(self.n_layers):
|
||||
x = x * x_mask
|
||||
y = self.attn_layers[i](x, x, attn_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_1[i](x + y)
|
||||
y = self.ffn_layers[i](x, x_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_2[i](x + y)
|
||||
x = x * x_mask
|
||||
return x
|
||||
|
||||
|
class TextEncoder(ModelMixin, ConfigMixin):
    def __init__(
        self,
        n_vocab,
        n_feats,
        n_channels,
        filter_channels,
        filter_channels_dp,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        window_size=None,
        spk_emb_dim=64,
        n_spks=1,
    ):
        super(TextEncoder, self).__init__()

        self.register_to_config(
            n_vocab=n_vocab,
            n_feats=n_feats,
            n_channels=n_channels,
            filter_channels=filter_channels,
            filter_channels_dp=filter_channels_dp,
            n_heads=n_heads,
            n_layers=n_layers,
            kernel_size=kernel_size,
            p_dropout=p_dropout,
            window_size=window_size,
            spk_emb_dim=spk_emb_dim,
            n_spks=n_spks,
        )

        self.n_vocab = n_vocab
        self.n_feats = n_feats
        self.n_channels = n_channels
        self.filter_channels = filter_channels
        self.filter_channels_dp = filter_channels_dp
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.spk_emb_dim = spk_emb_dim
        self.n_spks = n_spks

        self.emb = torch.nn.Embedding(n_vocab, n_channels)
        torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)

        self.prenet = ConvReluNorm(n_channels, n_channels, n_channels, kernel_size=5, n_layers=3, p_dropout=0.5)

        self.encoder = Encoder(
            n_channels + (spk_emb_dim if n_spks > 1 else 0),
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            window_size=window_size,
        )

        self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
        self.proj_w = DurationPredictor(
            n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout
        )

    def forward(self, x, x_lengths, spk=None):
        x = self.emb(x) * math.sqrt(self.n_channels)
        x = torch.transpose(x, 1, -1)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.prenet(x, x_mask)
        if self.n_spks > 1:
            x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
        x = self.encoder(x, x_mask)
        mu = self.proj_m(x) * x_mask

        x_dp = torch.detach(x)
        logw = self.proj_w(x_dp, x_mask)

        return mu, logw, x_mask

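# Taken together, TextEncoder.forward returns the three tensors Grad-TTS needs
# before any diffusion happens: `mu`, the projected "average" mel frame per
# input token; `logw`, the per-token log durations from the DurationPredictor,
# computed on a detached copy of the encoder output so duration gradients do
# not flow back into the encoder; and `x_mask`, the token-level padding mask.
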
class GradTTSPipeline(DiffusionPipeline):
    def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
        self.register_modules(
            unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer
        )

    @torch.no_grad()
    def __call__(
        self,
        text,
        num_inference_steps=50,
        temperature=1.3,
        length_scale=0.91,
        speaker_id=15,
        torch_device=None,
        generator=None,
    ):
        if torch_device is None:
            torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.unet.to(torch_device)
        self.text_encoder.to(torch_device)

        x, x_lengths = self.tokenizer(text)
        x = x.to(torch_device)
        x_lengths = x_lengths.to(torch_device)

        if speaker_id is not None:
            speaker_id = torch.LongTensor([speaker_id]).to(torch_device)

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        mu_x, logw, x_mask = self.text_encoder(x, x_lengths)

        w = torch.exp(logw) * x_mask
        w_ceil = torch.ceil(w) * length_scale
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = int(y_lengths.max())
        y_max_length_ = fix_len_compatibility(y_max_length)

        # Using obtained durations `w` construct alignment map `attn`
        y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

        # Align encoded text and get mu_y
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)

        # Sample latent representation from terminal distribution N(mu_y, I)
        z = mu_y + torch.randn(mu_y.shape, generator=generator).to(mu_y.device)

        xt = z * y_mask
        h = 1.0 / num_inference_steps
        # (Patrick: TODO)
        for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
            t_new = num_inference_steps - t - 1
            t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)

            residual = self.unet(xt, t, mu_y, y_mask, speaker_id)

            scheduler_residual = residual - mu_y + xt
            xt = self.noise_scheduler.step(scheduler_residual, xt, t_new, num_inference_steps)
            xt = xt * y_mask

        return xt[:, :, :y_max_length]

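# A minimal usage sketch, assuming the "fusing/grad-tts-libri-tts" checkpoint
# exercised by the slow test further down is the intended weight source (the
# output shape below follows that test):
#
#     pipeline = GradTTSPipeline.from_pretrained("fusing/grad-tts-libri-tts")
#     mel_spec = pipeline("Hello world", generator=torch.manual_seed(0))
#     # mel_spec: (1, 80, T) mel spectrogram; a separate vocoder is needed for audio
#
# Note that `temperature` is accepted but never used in this `__call__`; the
# reference Grad-TTS implementation divides the sampled noise by it.
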
@ -18,7 +18,6 @@
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_grad_tts import GradTTSScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler

@ -1,54 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin


class GradTTSScheduler(SchedulerMixin, ConfigMixin):
    def __init__(
        self,
        beta_start=0.05,
        beta_end=20,
        tensor_format="np",
    ):
        super().__init__()
        self.register_to_config(
            beta_start=beta_start,
            beta_end=beta_end,
        )
        self.set_format(tensor_format=tensor_format)
        self.betas = None

    def get_timesteps(self, num_inference_steps):
        return np.array([(t + 0.5) / num_inference_steps for t in range(num_inference_steps)])

    def set_betas(self, num_inference_steps):
        timesteps = self.get_timesteps(num_inference_steps)
        self.betas = np.array([self.beta_start + (self.beta_end - self.beta_start) * t for t in timesteps])

    def step(self, residual, sample, t, num_inference_steps):
        # This is a VE scheduler from https://arxiv.org/pdf/2011.13456.pdf (see Algorithm 2 in Appendix)
        if self.betas is None:
            self.set_betas(num_inference_steps)

        beta_t = self.betas[t]
        beta_t_deriv = beta_t / num_inference_steps

        sample_deriv = residual * beta_t_deriv / 2

        sample = sample + sample_deriv
        return sample

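# Worked example of the schedule above (plain arithmetic, independent of any
# checkpoint): with the defaults beta_start=0.05, beta_end=20 and
# num_inference_steps=4, get_timesteps() yields [0.125, 0.375, 0.625, 0.875],
# so set_betas() produces betas of roughly [2.54, 7.53, 12.52, 17.51]. Each
# step() then nudges the sample by residual * beta_t / (2 * num_inference_steps),
# i.e. a plain Euler update with step size 1 / num_inference_steps.
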
@ -23,7 +23,6 @@ import torch
from diffusers import (
    AutoencoderKL,
    BDDMPipeline,
    DDIMPipeline,
    DDIMScheduler,
    DDPMPipeline,
@ -31,8 +30,6 @@ from diffusers import (
    GlidePipeline,
    GlideSuperResUNetModel,
    GlideTextToImageUNetModel,
    GradTTSPipeline,
    GradTTSScheduler,
    LatentDiffusionPipeline,
    LatentDiffusionUncondPipeline,
    NCSNpp,
@ -42,8 +39,6 @@ from diffusers import (
    ScoreSdeVeScheduler,
    ScoreSdeVpPipeline,
    ScoreSdeVpScheduler,
    TemporalUNet,
    UNetGradTTSModel,
    UNetLDMModel,
    UNetModel,
    UNetUnconditionalModel,
@ -51,7 +46,6 @@ from diffusers import (
)
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.bddm.pipeline_bddm import DiffWave
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel

@ -556,149 +550,6 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))


class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNetGradTTSModel

    @property
    def dummy_input(self):
        batch_size = 4
        num_features = 32
        seq_len = 16

        noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
        condition = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
        mask = floats_tensor((batch_size, 1, seq_len)).to(torch_device)
        time_step = torch.tensor([10] * batch_size).to(torch_device)

        return {"sample": noise, "timesteps": time_step, "mu": condition, "mask": mask}

    @property
    def input_shape(self):
        return (4, 32, 16)

    @property
    def output_shape(self):
        return (4, 32, 16)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "dim": 64,
            "groups": 4,
            "dim_mults": (1, 2),
            "n_feats": 32,
            "pe_scale": 1000,
            "n_spks": 1,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        model, loading_info = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy")
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        num_features = model.config.n_feats
        seq_len = 16
        noise = torch.randn((1, num_features, seq_len))
        condition = torch.randn((1, num_features, seq_len))
        mask = torch.randn((1, 1, seq_len))
        time_step = torch.tensor([10])

        with torch.no_grad():
            output = model(noise, time_step, condition, mask)

        output_slice = output[0, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))

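# Both test classes here follow the same ModelTesterMixin contract (an
# assumption about the shared harness, which is not shown in this diff): the
# mixin builds model_class(**init_dict) from prepare_init_args_and_inputs_for_common()
# and runs the common shape and save/load checks with inputs_dict, so each
# class only declares its dummy shapes plus a reference "fusing/..." checkpoint.
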
class TemporalUNetModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = TemporalUNet

    @property
    def dummy_input(self):
        batch_size = 4
        num_features = 14
        seq_len = 16

        noise = floats_tensor((batch_size, seq_len, num_features)).to(torch_device)
        time_step = torch.tensor([10] * batch_size).to(torch_device)

        return {"sample": noise, "timesteps": time_step}

    @property
    def input_shape(self):
        return (4, 16, 14)

    @property
    def output_shape(self):
        return (4, 16, 14)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "training_horizon": 128,
            "dim": 32,
            "dim_mults": [1, 4, 8],
            "predict_epsilon": False,
            "clip_denoised": True,
            "transition_dim": 14,
            "cond_dim": 3,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_from_pretrained_hub(self):
        model, loading_info = TemporalUNet.from_pretrained(
            "fusing/ddpm-unet-rl-hopper-hor128", output_loading_info=True
        )
        self.assertIsNotNone(model)
        self.assertEqual(len(loading_info["missing_keys"]), 0)

        model.to(torch_device)
        image = model(**self.dummy_input)

        assert image is not None, "Make sure output is not None"

    def test_output_pretrained(self):
        model = TemporalUNet.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128")
        model.eval()

        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

        num_features = model.transition_dim
        seq_len = 16
        noise = torch.randn((1, seq_len, num_features))
        time_step = torch.full((num_features,), 0)

        with torch.no_grad():
            output = model(noise, time_step)

        output_slice = output[0, -3:, -3:].flatten()
        # fmt: off
        expected_output_slice = torch.tensor([-0.2714, 0.1042, -0.0794, -0.2820, 0.0803, -0.0811, -0.2345, 0.0580, -0.0584])
        # fmt: on

        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))


class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = NCSNpp

@ -1116,25 +967,6 @@ class PipelineTesterMixin(unittest.TestCase):
        expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

    @slow
    def test_grad_tts(self):
        model_id = "fusing/grad-tts-libri-tts"
        grad_tts = GradTTSPipeline.from_pretrained(model_id)
        noise_scheduler = GradTTSScheduler()
        grad_tts.noise_scheduler = noise_scheduler

        text = "Hello world, I missed you so much."
        generator = torch.manual_seed(0)

        # generate mel spectrograms using text
        mel_spec = grad_tts(text, generator=generator)

        assert mel_spec.shape == (1, 80, 143)
        expected_slice = torch.tensor(
            [-6.7584, -6.8347, -6.3293, -6.6437, -6.7233, -6.4684, -6.1187, -6.3172, -6.6890]
        )
        assert (mel_spec[0, :3, :3].cpu().flatten() - expected_slice).abs().max() < 1e-2

    @slow
    def test_score_sde_ve_pipeline(self):
        model = NCSNpp.from_pretrained("fusing/ffhq_ncsnpp")
@ -1181,21 +1013,3 @@ class PipelineTesterMixin(unittest.TestCase):
            [-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106]
        )
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

    def test_module_from_pipeline(self):
        model = DiffWave(num_res_layers=4)
        noise_scheduler = DDPMScheduler(timesteps=12)

        bddm = BDDMPipeline(model, noise_scheduler)

        # check if the library name for the diffwave module is set to pipeline module
        self.assertTrue(bddm.config["diffwave"][0] == "bddm")

        # check if we can save and load the pipeline
        with tempfile.TemporaryDirectory() as tmpdirname:
            bddm.save_pretrained(tmpdirname)
            _ = BDDMPipeline.from_pretrained(tmpdirname)
            # check if the same works using the DiffusionPipeline class
            bddm = DiffusionPipeline.from_pretrained(tmpdirname)

        self.assertTrue(bddm.config["diffwave"][0] == "bddm")