finish GradTTS pipeline
This commit is contained in:
parent
8007393614
commit
1d2551d716
|
@ -4,12 +4,13 @@ import math
|
|||
|
||||
import torch
|
||||
from torch import nn
|
||||
import tqdm
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.modeling_utils import ModelMixin
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
from .grad_tts_utils import text_to_sequence
|
||||
from .grad_tts_utils import GradTTSTokenizer # flake8: noqa
|
||||
|
||||
|
||||
def sequence_mask(length, max_length=None):
|
||||
|
@ -385,24 +386,29 @@ class TextEncoder(ModelMixin, ConfigMixin):
|
|||
x_dp = torch.detach(x)
|
||||
logw = self.proj_w(x_dp, x_mask)
|
||||
|
||||
return mu, logw, x_mask, spk
|
||||
return mu, logw, x_mask
|
||||
|
||||
|
||||
class GradTTS(DiffusionPipeline):
|
||||
def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
|
||||
super().__init__()
|
||||
noise_scheduler = noise_scheduler.set_format("pt")
|
||||
self.register_modules(diffwave=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer)
|
||||
self.register_modules(unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, text, num_inference_steps, generator, temperature, length_scale, speaker_id=None, torch_device=None):
|
||||
if torch_device is None:
|
||||
torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
self.unet.to(torch_device)
|
||||
self.text_encoder.to(torch_device)
|
||||
|
||||
x, x_lengths = self.tokenizer(text)
|
||||
x = x.to(torch_device)
|
||||
x_lengths = x_lengths.to(torch_device)
|
||||
|
||||
if speaker_id is not None:
|
||||
speaker_id= torch.longTensor([speaker_id])
|
||||
speaker_id= torch.LongTensor([speaker_id]).to(torch_device)
|
||||
|
||||
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
|
||||
mu_x, logw, x_mask = self.text_encoder(x, x_lengths)
|
||||
|
@ -426,6 +432,15 @@ class GradTTS(DiffusionPipeline):
|
|||
# Sample latent representation from terminal distribution N(mu_y, I)
|
||||
z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
|
||||
|
||||
xt = z * y_mask
|
||||
h = 1.0 / num_inference_steps
|
||||
for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
|
||||
t = (1.0 - (t + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
|
||||
time = t.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
residual = self.unet(xt, y_mask, mu_y, t, speaker_id)
|
||||
|
||||
xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
|
||||
xt = xt * y_mask
|
||||
|
||||
return xt[:, :, :y_max_length]
|
Loading…
Reference in New Issue