finish GradTTS pipeline

patil-suraj 2022-06-16 18:08:33 +02:00
parent 8007393614
commit 1d2551d716
1 changed file with 21 additions and 6 deletions


@@ -4,12 +4,13 @@ import math
 import torch
 from torch import nn
+import tqdm

 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin
 from diffusers import DiffusionPipeline
-from .grad_tts_utils import text_to_sequence
+from .grad_tts_utils import GradTTSTokenizer # flake8: noqa


 def sequence_mask(length, max_length=None):
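Note: the import swap above replaces the bare text_to_sequence helper with a GradTTSTokenizer object. Its internals are not part of this diff; the toy sketch below only illustrates the call contract the pipeline relies on further down (text in, token ids plus lengths out), and every implementation detail in it is an assumption.

import torch

class ToyGradTTSTokenizer:
    # Hypothetical stand-in for GradTTSTokenizer: a character-level lookup
    # table instead of the real phoneme frontend.
    def __init__(self, symbols="abcdefghijklmnopqrstuvwxyz '"):
        self.symbol_to_id = {s: i + 1 for i, s in enumerate(symbols)}  # 0 = padding

    def __call__(self, text):
        ids = [self.symbol_to_id[c] for c in text.lower() if c in self.symbol_to_id]
        x = torch.LongTensor(ids).unsqueeze(0)       # (batch=1, seq_len) token ids
        x_lengths = torch.LongTensor([x.shape[-1]])  # true length of each sequence
        return x, x_lengths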
@@ -385,24 +386,29 @@ class TextEncoder(ModelMixin, ConfigMixin):
         x_dp = torch.detach(x)
         logw = self.proj_w(x_dp, x_mask)

-        return mu, logw, x_mask, spk
+        return mu, logw, x_mask


 class GradTTS(DiffusionPipeline):
     def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
         super().__init__()
         noise_scheduler = noise_scheduler.set_format("pt")
-        self.register_modules(diffwave=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer)
+        self.register_modules(unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer)

     @torch.no_grad()
     def __call__(self, text, num_inference_steps, generator, temperature, length_scale, speaker_id=None, torch_device=None):
         if torch_device is None:
             torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

+        self.unet.to(torch_device)
+        self.text_encoder.to(torch_device)
+
         x, x_lengths = self.tokenizer(text)
+        x = x.to(torch_device)
+        x_lengths = x_lengths.to(torch_device)

         if speaker_id is not None:
-            speaker_id= torch.longTensor([speaker_id])
+            speaker_id = torch.LongTensor([speaker_id]).to(torch_device)

         # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
         mu_x, logw, x_mask = self.text_encoder(x, x_lengths)
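Note: the code elided between this hunk and the next (unchanged, so not shown in the diff) converts the log-durations into a frame-level alignment. A hedged sketch of that step, following the original Grad-TTS inference code; the exact rounding/length_scale order is an assumption:

w = torch.exp(logw) * x_mask           # per-token durations in frames
w_ceil = torch.ceil(w) * length_scale  # length_scale > 1.0 slows speech down
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = int(y_lengths.max())    # frames in the longest utterance
# mu_y and y_mask are then built by repeating each mu_x column w_ceil times
# along the time axis, giving the terminal-distribution mean used below.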
@@ -426,6 +432,15 @@ class GradTTS(DiffusionPipeline):
         # Sample latent representation from terminal distribution N(mu_y, I)
         z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature

+        xt = z * y_mask
+        h = 1.0 / num_inference_steps
+        for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
+            t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
+            time = t.unsqueeze(-1).unsqueeze(-1)
+            residual = self.unet(xt, y_mask, mu_y, t, speaker_id)
+            xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
+            xt = xt * y_mask
+        return xt[:, :, :y_max_length]
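Note: the sampling loop delegates the per-step update to noise_scheduler.step, whose body is not in this diff. In the original Grad-TTS sampler this is a single Euler step of the reverse ODE; the function below is a sketch of what step plausibly computes here, with the linear beta schedule and its beta_min/beta_max values taken from the Grad-TTS paper as assumptions:

def grad_tts_step(xt, residual, mu_y, h, time, beta_min=0.05, beta_max=20.0):
    # Non-cumulative linear noise schedule evaluated at the current time
    # (`time` runs from ~1.0 down to ~0.0 across the loop above).
    beta_t = beta_min + (beta_max - beta_min) * time
    # Reverse-ODE drift: pull xt toward the encoder mean mu_y, corrected by
    # the score predicted by the unet (`residual`), over a step of size h.
    dxt = 0.5 * (mu_y - xt - residual) * beta_t * h
    return xt - dxt

End to end, the finished pipeline can then be driven roughly like this (the checkpoint id is hypothetical):

import torch
pipeline = GradTTS.from_pretrained("fusing/grad-tts")  # hypothetical checkpoint id
mel = pipeline("Hello world.", num_inference_steps=50,
               generator=torch.manual_seed(0), temperature=1.5, length_scale=1.0)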