This commit is contained in:
patil-suraj 2022-06-16 18:26:00 +02:00
parent 988369a01c
commit ace07110c1
5 changed files with 197 additions and 114 deletions

View File

@ -145,8 +145,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
if n_spks > 1:
self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
torch.nn.Linear(spk_emb_dim * 4, n_feats))
self.spk_mlp = torch.nn.Sequential(
torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
)
self.time_pos_emb = SinusoidalPosEmb(dim)
self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))

View File

@ -1,11 +1,12 @@
# tokenizer
import re
import os
import re
from shutil import copyfile
import torch
try:
from transformers import PreTrainedTokenizer
except:
@ -25,17 +26,95 @@ except:
valid_symbols = [
'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
"AA",
"AA0",
"AA1",
"AA2",
"AE",
"AE0",
"AE1",
"AE2",
"AH",
"AH0",
"AH1",
"AH2",
"AO",
"AO0",
"AO1",
"AO2",
"AW",
"AW0",
"AW1",
"AW2",
"AY",
"AY0",
"AY1",
"AY2",
"B",
"CH",
"D",
"DH",
"EH",
"EH0",
"EH1",
"EH2",
"ER",
"ER0",
"ER1",
"ER2",
"EY",
"EY0",
"EY1",
"EY2",
"F",
"G",
"HH",
"IH",
"IH0",
"IH1",
"IH2",
"IY",
"IY0",
"IY1",
"IY2",
"JH",
"K",
"L",
"M",
"N",
"NG",
"OW",
"OW0",
"OW1",
"OW2",
"OY",
"OY0",
"OY1",
"OY2",
"P",
"R",
"S",
"SH",
"T",
"TH",
"UH",
"UH0",
"UH1",
"UH2",
"UW",
"UW0",
"UW1",
"UW2",
"V",
"W",
"Y",
"Z",
"ZH",
]
_valid_symbol_set = set(valid_symbols)
def intersperse(lst, item):
# Adds blank symbol
result = [item] * (len(lst) * 2 + 1)
@ -46,7 +125,7 @@ def intersperse(lst, item):
class CMUDict:
def __init__(self, file_or_path, keep_ambiguous=True):
if isinstance(file_or_path, str):
with open(file_or_path, encoding='latin-1') as f:
with open(file_or_path, encoding="latin-1") as f:
entries = _parse_cmudict(f)
else:
entries = _parse_cmudict(file_or_path)
@ -61,15 +140,15 @@ class CMUDict:
return self._entries.get(word.upper())
_alt_re = re.compile(r'\([0-9]+\)')
_alt_re = re.compile(r"\([0-9]+\)")
def _parse_cmudict(file):
cmudict = {}
for line in file:
if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
parts = line.split(' ')
word = re.sub(_alt_re, '', parts[0])
if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
parts = line.split(" ")
word = re.sub(_alt_re, "", parts[0])
pronunciation = _get_pronunciation(parts[1])
if pronunciation:
if word in cmudict:
@ -80,36 +159,38 @@ def _parse_cmudict(file):
def _get_pronunciation(s):
parts = s.strip().split(' ')
parts = s.strip().split(" ")
for part in parts:
if part not in _valid_symbol_set:
return None
return ' '.join(parts)
return " ".join(parts)
_whitespace_re = re.compile(r"\s+")
_whitespace_re = re.compile(r'\s+')
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
_abbreviations = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
]
def expand_abbreviations(text):
@ -127,7 +208,7 @@ def lowercase(text):
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text)
return re.sub(_whitespace_re, " ", text)
def convert_to_ascii(text):
@ -156,46 +237,42 @@ def english_cleaners(text):
return text
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")
def _remove_commas(m):
return m.group(1).replace(',', '')
return m.group(1).replace(",", "")
def _expand_decimal_point(m):
return m.group(1).replace('.', ' point ')
return m.group(1).replace(".", " point ")
def _expand_dollars(m):
match = m.group(1)
parts = match.split('.')
parts = match.split(".")
if len(parts) > 2:
return match + ' dollars'
return match + " dollars"
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
dollar_unit = "dollar" if dollars == 1 else "dollars"
return "%s %s" % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s" % (cents, cent_unit)
else:
return 'zero dollars'
return "zero dollars"
def _expand_ordinal(m):
@ -206,37 +283,37 @@ def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return 'two thousand'
return "two thousand"
elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
return "two thousand " + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
return _inflect.number_to_words(num // 100) + " hundred"
else:
return _inflect.number_to_words(num, andword='', zero='oh',
group=2).replace(', ', ' ')
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
else:
return _inflect.number_to_words(num, andword='')
return _inflect.number_to_words(num, andword="")
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r'\1 pounds', text)
text = re.sub(_pounds_re, r"\1 pounds", text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
""" from https://github.com/keithito/tacotron """
_pad = '_'
_punctuation = '!\'(),.:;? '
_special = '-'
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_pad = "_"
_punctuation = "!'(),.:;? "
_special = "-"
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
# Prepend "@" to ARPAbet symbols to ensure uniqueness:
_arpabet = ['@' + s for s in valid_symbols]
_arpabet = ["@" + s for s in valid_symbols]
# Export all symbols:
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
@ -245,7 +322,7 @@ symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpab
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
def get_arpabet(word, dictionary):
@ -257,7 +334,7 @@ def get_arpabet(word, dictionary):
def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
@ -269,9 +346,9 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
Returns:
List of integers corresponding to the symbols in the text
'''
"""
sequence = []
space = _symbols_to_sequence(' ')
space = _symbols_to_sequence(" ")
# Check for curly braces and treat their contents as ARPAbet:
while len(text):
m = _curly_re.match(text)
@ -292,7 +369,7 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
# remove trailing space
if dictionary is not None:
sequence = sequence[:-1] if sequence[-1] == space[0] else sequence
@ -300,16 +377,16 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
def sequence_to_text(sequence):
'''Converts a sequence of IDs back to a string'''
result = ''
"""Converts a sequence of IDs back to a string"""
result = ""
for symbol_id in sequence:
if symbol_id in _id_to_symbol:
s = _id_to_symbol[symbol_id]
# Enclose ARPAbet back in curly braces:
if len(s) > 1 and s[0] == '@':
s = '{%s}' % s[1:]
if len(s) > 1 and s[0] == "@":
s = "{%s}" % s[1:]
result += s
return result.replace('}{', ' ')
return result.replace("}{", " ")
def _clean_text(text, cleaner_names):
@ -323,17 +400,18 @@ def _symbols_to_sequence(symbols):
def _arpabet_to_sequence(text):
return _symbols_to_sequence(['@' + s for s in text.split()])
return _symbols_to_sequence(["@" + s for s in text.split()])
def _should_keep_symbol(s):
return s in _symbol_to_id and s != '_' and s != '~'
return s in _symbol_to_id and s != "_" and s != "~"
VOCAB_FILES_NAMES = {
"dict_file": "dict_file.txt",
}
class GradTTSTokenizer(PreTrainedTokenizer):
vocab_files_names = VOCAB_FILES_NAMES
@ -341,17 +419,17 @@ class GradTTSTokenizer(PreTrainedTokenizer):
super().__init__(**kwargs)
self.cmu = CMUDict(dict_file)
self.dict_file = dict_file
def __call__(self, text):
x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=self.cmu), len(symbols)))[None]
x_lengths = torch.LongTensor([x.shape[-1]])
return x, x_lengths
def save_vocabulary(self, save_directory: str, filename_prefix = None):
def save_vocabulary(self, save_directory: str, filename_prefix=None):
dict_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"]
)
copyfile(self.dict_file, dict_file)
return (dict_file, )
return (dict_file,)

View File

@ -4,13 +4,13 @@ import math
import torch
from torch import nn
import tqdm
import tqdm
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin
from diffusers import DiffusionPipeline
from .grad_tts_utils import GradTTSTokenizer # flake8: noqa
from .grad_tts_utils import GradTTSTokenizer # flake8: noqa
def sequence_mask(length, max_length=None):
@ -382,7 +382,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
self.window_size = window_size
self.spk_emb_dim = spk_emb_dim
self.n_spks = n_spks
self.emb = torch.nn.Embedding(n_vocab, n_channels)
torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)
@ -403,7 +403,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout
)
def forward(self, x, x_lengths, spk=None):
def forward(self, x, x_lengths, spk=None):
x = self.emb(x) * math.sqrt(self.n_channels)
x = torch.transpose(x, 1, -1)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
@ -424,26 +424,30 @@ class GradTTS(DiffusionPipeline):
def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
self.register_modules(unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer)
self.register_modules(
unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer
)
@torch.no_grad()
def __call__(self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None):
def __call__(
self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None
):
if torch_device is None:
torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.unet.to(torch_device)
self.text_encoder.to(torch_device)
x, x_lengths = self.tokenizer(text)
x = x.to(torch_device)
x_lengths = x_lengths.to(torch_device)
if speaker_id is not None:
speaker_id= torch.LongTensor([speaker_id]).to(torch_device)
speaker_id = torch.LongTensor([speaker_id]).to(torch_device)
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
mu_x, logw, x_mask = self.text_encoder(x, x_lengths)
w = torch.exp(logw) * x_mask
w_ceil = torch.ceil(w) * length_scale
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
@ -461,16 +465,16 @@ class GradTTS(DiffusionPipeline):
# Sample latent representation from terminal distribution N(mu_y, I)
z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
xt = z * y_mask
h = 1.0 / num_inference_steps
for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
t = (1.0 - (t + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
time = t.unsqueeze(-1).unsqueeze(-1)
residual = self.unet(xt, y_mask, mu_y, t, speaker_id)
xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
xt = xt * y_mask
return xt[:, :, :y_max_length]
return xt[:, :, :y_max_length]

View File

@ -19,6 +19,6 @@
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_grad_tts import GradTTSScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_utils import SchedulerMixin

View File

@ -36,11 +36,11 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = int(timesteps)
self.set_format(tensor_format=tensor_format)
def sample_noise(self, timestep):
noise = self.beta_start + (self.beta_end - self.beta_start) * timestep
return noise
def step(self, xt, residual, mu, h, timestep):
noise_t = self.sample_noise(timestep)
dxt = 0.5 * (mu - xt - residual)