style
commit ace07110c1 (parent 988369a01c)
@@ -145,8 +145,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         if n_spks > 1:
             self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
-            self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
-                                               torch.nn.Linear(spk_emb_dim * 4, n_feats))
+            self.spk_mlp = torch.nn.Sequential(
+                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
+            )

         self.time_pos_emb = SinusoidalPosEmb(dim)
         self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))
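For multi-speaker checkpoints, the embedding/MLP pair above maps a speaker id to an n_feats-dimensional conditioning vector. A minimal standalone sketch (sizes are hypothetical, and Mish is re-declared here because its definition sits outside this hunk):

    import torch

    class Mish(torch.nn.Module):
        def forward(self, x):
            return x * torch.tanh(torch.nn.functional.softplus(x))

    n_spks, spk_emb_dim, n_feats = 247, 64, 80  # hypothetical sizes

    spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
    spk_mlp = torch.nn.Sequential(
        torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
    )

    spk = torch.LongTensor([15])  # a speaker id
    cond = spk_mlp(spk_emb(spk))  # shape (1, 80): per-speaker bias over the mel features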
@@ -1,11 +1,12 @@
 # tokenizer

-import re
 import os
+import re
+from shutil import copyfile

 import torch

 try:
     from transformers import PreTrainedTokenizer
 except:
@@ -25,17 +26,95 @@ except:

 valid_symbols = [
-    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
-    'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
-    'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
-    'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
-    'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
-    'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
-    'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
+    "AA",
+    "AA0",
+    "AA1",
+    "AA2",
+    "AE",
+    "AE0",
+    "AE1",
+    "AE2",
+    "AH",
+    "AH0",
+    "AH1",
+    "AH2",
+    "AO",
+    "AO0",
+    "AO1",
+    "AO2",
+    "AW",
+    "AW0",
+    "AW1",
+    "AW2",
+    "AY",
+    "AY0",
+    "AY1",
+    "AY2",
+    "B",
+    "CH",
+    "D",
+    "DH",
+    "EH",
+    "EH0",
+    "EH1",
+    "EH2",
+    "ER",
+    "ER0",
+    "ER1",
+    "ER2",
+    "EY",
+    "EY0",
+    "EY1",
+    "EY2",
+    "F",
+    "G",
+    "HH",
+    "IH",
+    "IH0",
+    "IH1",
+    "IH2",
+    "IY",
+    "IY0",
+    "IY1",
+    "IY2",
+    "JH",
+    "K",
+    "L",
+    "M",
+    "N",
+    "NG",
+    "OW",
+    "OW0",
+    "OW1",
+    "OW2",
+    "OY",
+    "OY0",
+    "OY1",
+    "OY2",
+    "P",
+    "R",
+    "S",
+    "SH",
+    "T",
+    "TH",
+    "UH",
+    "UH0",
+    "UH1",
+    "UH2",
+    "UW",
+    "UW0",
+    "UW1",
+    "UW2",
+    "V",
+    "W",
+    "Y",
+    "Z",
+    "ZH",
 ]

 _valid_symbol_set = set(valid_symbols)


 def intersperse(lst, item):
     # Adds blank symbol
     result = [item] * (len(lst) * 2 + 1)
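The hunk cuts off inside intersperse; upstream Grad-TTS fills the odd indices with result[1::2] = lst and returns the list, so a blank id lands between every symbol and at both ends. Assuming that completion:

    def intersperse(lst, item):
        # Adds blank symbol between all elements and at both ends
        result = [item] * (len(lst) * 2 + 1)
        result[1::2] = lst
        return result

    intersperse([13, 5, 22], 148)  # -> [148, 13, 148, 5, 148, 22, 148]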
@@ -46,7 +125,7 @@ def intersperse(lst, item):
 class CMUDict:
     def __init__(self, file_or_path, keep_ambiguous=True):
         if isinstance(file_or_path, str):
-            with open(file_or_path, encoding='latin-1') as f:
+            with open(file_or_path, encoding="latin-1") as f:
                 entries = _parse_cmudict(f)
         else:
             entries = _parse_cmudict(file_or_path)
@@ -61,15 +140,15 @@ class CMUDict:
         return self._entries.get(word.upper())


-_alt_re = re.compile(r'\([0-9]+\)')
+_alt_re = re.compile(r"\([0-9]+\)")


 def _parse_cmudict(file):
     cmudict = {}
     for line in file:
-        if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
-            parts = line.split('  ')
-            word = re.sub(_alt_re, '', parts[0])
+        if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
+            parts = line.split("  ")
+            word = re.sub(_alt_re, "", parts[0])
             pronunciation = _get_pronunciation(parts[1])
             if pronunciation:
                 if word in cmudict:
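For reference, each CMUdict line is "WORD  PRONUNCIATION" with a two-space separator, and _alt_re folds "(1)"-style variant markers back onto the headword. Assuming the truncated branch appends to the existing entry list, as the upstream code does:

    import io

    sample = io.StringIO(
        "HOUSTON  HH Y UW1 S T AH0 N\n"
        "HOUSTON(1)  HH AW1 S T AH0 N\n"
    )
    entries = _parse_cmudict(sample)
    # entries["HOUSTON"] -> ["HH Y UW1 S T AH0 N", "HH AW1 S T AH0 N"]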
@@ -80,36 +159,38 @@ def _parse_cmudict(file):


 def _get_pronunciation(s):
-    parts = s.strip().split(' ')
+    parts = s.strip().split(" ")
     for part in parts:
         if part not in _valid_symbol_set:
             return None
-    return ' '.join(parts)
+    return " ".join(parts)


-_whitespace_re = re.compile(r'\s+')
+_whitespace_re = re.compile(r"\s+")

-_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
-    ('mrs', 'misess'),
-    ('mr', 'mister'),
-    ('dr', 'doctor'),
-    ('st', 'saint'),
-    ('co', 'company'),
-    ('jr', 'junior'),
-    ('maj', 'major'),
-    ('gen', 'general'),
-    ('drs', 'doctors'),
-    ('rev', 'reverend'),
-    ('lt', 'lieutenant'),
-    ('hon', 'honorable'),
-    ('sgt', 'sergeant'),
-    ('capt', 'captain'),
-    ('esq', 'esquire'),
-    ('ltd', 'limited'),
-    ('col', 'colonel'),
-    ('ft', 'fort'),
-]]
+_abbreviations = [
+    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+    for x in [
+        ("mrs", "misess"),
+        ("mr", "mister"),
+        ("dr", "doctor"),
+        ("st", "saint"),
+        ("co", "company"),
+        ("jr", "junior"),
+        ("maj", "major"),
+        ("gen", "general"),
+        ("drs", "doctors"),
+        ("rev", "reverend"),
+        ("lt", "lieutenant"),
+        ("hon", "honorable"),
+        ("sgt", "sergeant"),
+        ("capt", "captain"),
+        ("esq", "esquire"),
+        ("ltd", "limited"),
+        ("col", "colonel"),
+        ("ft", "fort"),
+    ]
+]


 def expand_abbreviations(text):
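The body of expand_abbreviations falls outside this hunk; in the upstream code it re.subs each compiled pattern in turn. Assuming that body:

    def expand_abbreviations(text):
        for regex, replacement in _abbreviations:
            text = re.sub(regex, replacement, text)
        return text

    expand_abbreviations("Dr. Smith met Mrs. Jones.")  # -> 'doctor Smith met misess Jones.'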
@@ -127,7 +208,7 @@ def lowercase(text):


 def collapse_whitespace(text):
-    return re.sub(_whitespace_re, ' ', text)
+    return re.sub(_whitespace_re, " ", text)


 def convert_to_ascii(text):
@@ -156,46 +237,42 @@ def english_cleaners(text):
     return text


 _inflect = inflect.engine()
-_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
-_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
-_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
-_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
-_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
-_number_re = re.compile(r'[0-9]+')
+_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
+_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
+_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
+_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
+_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
+_number_re = re.compile(r"[0-9]+")


 def _remove_commas(m):
-    return m.group(1).replace(',', '')
+    return m.group(1).replace(",", "")


 def _expand_decimal_point(m):
-    return m.group(1).replace('.', ' point ')
+    return m.group(1).replace(".", " point ")


 def _expand_dollars(m):
     match = m.group(1)
-    parts = match.split('.')
+    parts = match.split(".")
     if len(parts) > 2:
-        return match + ' dollars'
+        return match + " dollars"
     dollars = int(parts[0]) if parts[0] else 0
     cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
     if dollars and cents:
-        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-        cent_unit = 'cent' if cents == 1 else 'cents'
-        return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
+        dollar_unit = "dollar" if dollars == 1 else "dollars"
+        cent_unit = "cent" if cents == 1 else "cents"
+        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
     elif dollars:
-        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-        return '%s %s' % (dollars, dollar_unit)
+        dollar_unit = "dollar" if dollars == 1 else "dollars"
+        return "%s %s" % (dollars, dollar_unit)
     elif cents:
-        cent_unit = 'cent' if cents == 1 else 'cents'
-        return '%s %s' % (cents, cent_unit)
+        cent_unit = "cent" if cents == 1 else "cents"
+        return "%s %s" % (cents, cent_unit)
     else:
-        return 'zero dollars'
+        return "zero dollars"


 def _expand_ordinal(m):
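The branches above cover the singular/plural dollar-and-cent combinations, e.g.:

    re.sub(_dollars_re, _expand_dollars, "$3.50")  # -> '3 dollars, 50 cents'
    re.sub(_dollars_re, _expand_dollars, "$1")     # -> '1 dollar'
    re.sub(_dollars_re, _expand_dollars, "$0.01")  # -> '1 cent'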
@@ -206,37 +283,37 @@ def _expand_number(m):
     num = int(m.group(0))
     if num > 1000 and num < 3000:
         if num == 2000:
-            return 'two thousand'
+            return "two thousand"
         elif num > 2000 and num < 2010:
-            return 'two thousand ' + _inflect.number_to_words(num % 100)
+            return "two thousand " + _inflect.number_to_words(num % 100)
         elif num % 100 == 0:
-            return _inflect.number_to_words(num // 100) + ' hundred'
+            return _inflect.number_to_words(num // 100) + " hundred"
         else:
-            return _inflect.number_to_words(num, andword='', zero='oh',
-                                            group=2).replace(', ', ' ')
+            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
     else:
-        return _inflect.number_to_words(num, andword='')
+        return _inflect.number_to_words(num, andword="")


 def normalize_numbers(text):
     text = re.sub(_comma_number_re, _remove_commas, text)
-    text = re.sub(_pounds_re, r'\1 pounds', text)
+    text = re.sub(_pounds_re, r"\1 pounds", text)
     text = re.sub(_dollars_re, _expand_dollars, text)
     text = re.sub(_decimal_number_re, _expand_decimal_point, text)
     text = re.sub(_ordinal_re, _expand_ordinal, text)
     text = re.sub(_number_re, _expand_number, text)
     return text


 """ from https://github.com/keithito/tacotron """


-_pad = '_'
-_punctuation = '!\'(),.:;? '
-_special = '-'
-_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+_pad = "_"
+_punctuation = "!'(),.:;? "
+_special = "-"
+_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

 # Prepend "@" to ARPAbet symbols to ensure uniqueness:
-_arpabet = ['@' + s for s in valid_symbols]
+_arpabet = ["@" + s for s in valid_symbols]

 # Export all symbols:
 symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
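Chained together, the substitutions in normalize_numbers verbalize digits before text is mapped to symbol ids, e.g.:

    normalize_numbers("I paid $2500 for 16 books.")
    # -> 'I paid twenty-five hundred dollars for sixteen books.'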
@@ -245,7 +322,7 @@ symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpab
 _symbol_to_id = {s: i for i, s in enumerate(symbols)}
 _id_to_symbol = {i: s for i, s in enumerate(symbols)}

-_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
+_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")


 def get_arpabet(word, dictionary):
@@ -257,7 +334,7 @@ def get_arpabet(word, dictionary):


 def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
-    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

     The text can optionally have ARPAbet sequences enclosed in curly braces embedded
     in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
@@ -269,9 +346,9 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
     Returns:
       List of integers corresponding to the symbols in the text
-    '''
+    """
     sequence = []
-    space = _symbols_to_sequence(' ')
+    space = _symbols_to_sequence(" ")
     # Check for curly braces and treat their contents as ARPAbet:
     while len(text):
         m = _curly_re.match(text)
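On the docstring's own example, the curly-brace regex splits each pass into cleaned text and a raw ARPAbet run:

    text_to_sequence("Turn left on {HH AW1 S S T AH0 N} Street.", [english_cleaners])
    # m.group(1) = 'Turn left on ' is cleaned and mapped symbol by symbol,
    # m.group(2) = 'HH AW1 S S T AH0 N' goes through _arpabet_to_sequence,
    # m.group(3) = ' Street.' is handled on the next loop iteration.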
@@ -292,7 +369,7 @@
         sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
         sequence += _arpabet_to_sequence(m.group(2))
         text = m.group(3)

     # remove trailing space
     if dictionary is not None:
         sequence = sequence[:-1] if sequence[-1] == space[0] else sequence
@@ -300,16 +377,16 @@


 def sequence_to_text(sequence):
-    '''Converts a sequence of IDs back to a string'''
-    result = ''
+    """Converts a sequence of IDs back to a string"""
+    result = ""
     for symbol_id in sequence:
         if symbol_id in _id_to_symbol:
             s = _id_to_symbol[symbol_id]
             # Enclose ARPAbet back in curly braces:
-            if len(s) > 1 and s[0] == '@':
-                s = '{%s}' % s[1:]
+            if len(s) > 1 and s[0] == "@":
+                s = "{%s}" % s[1:]
             result += s
-    return result.replace('}{', ' ')
+    return result.replace("}{", " ")


 def _clean_text(text, cleaner_names):
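Because ARPAbet ids map back to "@"-prefixed symbols, sequence_to_text re-wraps them in braces, and the final '}{' -> ' ' replacement restores the space between adjacent ARPAbet words. A round trip (with the default cleaner, which lowercases) looks like:

    seq = text_to_sequence("hi {HH AY1}.", [english_cleaners])
    sequence_to_text(seq)  # -> 'hi {HH AY1}.'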
@@ -323,17 +400,18 @@ def _symbols_to_sequence(symbols):


 def _arpabet_to_sequence(text):
-    return _symbols_to_sequence(['@' + s for s in text.split()])
+    return _symbols_to_sequence(["@" + s for s in text.split()])


 def _should_keep_symbol(s):
-    return s in _symbol_to_id and s != '_' and s != '~'
+    return s in _symbol_to_id and s != "_" and s != "~"


 VOCAB_FILES_NAMES = {
     "dict_file": "dict_file.txt",
 }


 class GradTTSTokenizer(PreTrainedTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
@@ -341,17 +419,17 @@ class GradTTSTokenizer(PreTrainedTokenizer):
         super().__init__(**kwargs)
         self.cmu = CMUDict(dict_file)
         self.dict_file = dict_file

     def __call__(self, text):
         x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=self.cmu), len(symbols)))[None]
         x_lengths = torch.LongTensor([x.shape[-1]])
         return x, x_lengths

-    def save_vocabulary(self, save_directory: str, filename_prefix = None):
+    def save_vocabulary(self, save_directory: str, filename_prefix=None):
         dict_file = os.path.join(
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"]
         )

         copyfile(self.dict_file, dict_file)

-        return (dict_file, )
+        return (dict_file,)
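In use, the tokenizer returns the interspersed id tensor plus its length, ready for GradTTS.__call__ below; a sketch (the constructor signature and dictionary path are assumed):

    tokenizer = GradTTSTokenizer(dict_file="cmu_dictionary.txt")  # hypothetical path
    x, x_lengths = tokenizer("Hello world")
    # x: LongTensor of shape (1, 2 * len(sequence) + 1) after intersperse
    # x_lengths: LongTensor([x.shape[-1]])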
@@ -4,13 +4,13 @@ import math

 import torch
 from torch import nn
-import tqdm

+import tqdm
+from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin
-from diffusers import DiffusionPipeline

-from .grad_tts_utils import GradTTSTokenizer # flake8: noqa
+from .grad_tts_utils import GradTTSTokenizer  # flake8: noqa


 def sequence_mask(length, max_length=None):
@@ -382,7 +382,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
         self.window_size = window_size
         self.spk_emb_dim = spk_emb_dim
         self.n_spks = n_spks

         self.emb = torch.nn.Embedding(n_vocab, n_channels)
         torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)
@@ -403,7 +403,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
             n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout
         )

     def forward(self, x, x_lengths, spk=None):
         x = self.emb(x) * math.sqrt(self.n_channels)
         x = torch.transpose(x, 1, -1)
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
@@ -424,26 +424,30 @@ class GradTTS(DiffusionPipeline):
     def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
         super().__init__()
         noise_scheduler = noise_scheduler.set_format("pt")
-        self.register_modules(unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer)
+        self.register_modules(
+            unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer
+        )

     @torch.no_grad()
-    def __call__(self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None):
+    def __call__(
+        self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None
+    ):
         if torch_device is None:
             torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

         self.unet.to(torch_device)
         self.text_encoder.to(torch_device)

         x, x_lengths = self.tokenizer(text)
         x = x.to(torch_device)
         x_lengths = x_lengths.to(torch_device)

         if speaker_id is not None:
-            speaker_id= torch.LongTensor([speaker_id]).to(torch_device)
+            speaker_id = torch.LongTensor([speaker_id]).to(torch_device)

         # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
         mu_x, logw, x_mask = self.text_encoder(x, x_lengths)

         w = torch.exp(logw) * x_mask
         w_ceil = torch.ceil(w) * length_scale
         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
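The text encoder's logw holds log frame counts per input token; exponentiating, ceiling, and scaling by length_scale fixes the spectrogram length. A small numeric sketch (mask omitted, values hypothetical):

    logw = torch.tensor([[[0.0, 1.0, 2.0]]])     # log frames per token
    w_ceil = torch.ceil(torch.exp(logw)) * 0.91  # length_scale < 1 speaks faster
    y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
    # ceil([1.00, 2.72, 7.39]) -> [1, 3, 8]; * 0.91 -> sum 10.92 -> 10 frames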
@@ -461,16 +465,16 @@ class GradTTS(DiffusionPipeline):

         # Sample latent representation from terminal distribution N(mu_y, I)
         z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature

         xt = z * y_mask
         h = 1.0 / num_inference_steps
         for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
-            t = (1.0 - (t + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
+            t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
             time = t.unsqueeze(-1).unsqueeze(-1)

             residual = self.unet(xt, y_mask, mu_y, t, speaker_id)

             xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
             xt = xt * y_mask

         return xt[:, :, :y_max_length]
@@ -19,6 +19,6 @@
 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .scheduling_ddim import DDIMScheduler
 from .scheduling_ddpm import DDPMScheduler
-from .scheduling_pndm import PNDMScheduler
 from .scheduling_grad_tts import GradTTSScheduler
+from .scheduling_pndm import PNDMScheduler
 from .scheduling_utils import SchedulerMixin
@@ -36,11 +36,11 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = int(timesteps)

         self.set_format(tensor_format=tensor_format)

     def sample_noise(self, timestep):
         noise = self.beta_start + (self.beta_end - self.beta_start) * timestep
         return noise

     def step(self, xt, residual, mu, h, timestep):
         noise_t = self.sample_noise(timestep)
         dxt = 0.5 * (mu - xt - residual)
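sample_noise linearly interpolates the noise schedule beta_t between beta_start and beta_end at continuous time t, and step (truncated above) then takes one explicit-Euler step of the reverse ODE. A sketch of the remainder, assuming it follows the original Grad-TTS update where dxt is scaled by beta_t * h and subtracted:

    def step(xt, residual, mu, h, timestep, beta_start=0.05, beta_end=20.0):
        # beta values as in the Grad-TTS paper; the scheduler reads them from its config
        noise_t = beta_start + (beta_end - beta_start) * timestep
        dxt = 0.5 * (mu - xt - residual)  # drift toward the encoder mean mu
        dxt = dxt * noise_t * h           # assumed continuation of the truncated body
        return xt - dxt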