finish GradTTS pipeline
This commit is contained in:
parent
8007393614
commit
1d2551d716
|
@ -4,12 +4,13 @@ import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
import tqdm
|
||||||
|
|
||||||
from diffusers.configuration_utils import ConfigMixin
|
from diffusers.configuration_utils import ConfigMixin
|
||||||
from diffusers.modeling_utils import ModelMixin
|
from diffusers.modeling_utils import ModelMixin
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
from .grad_tts_utils import text_to_sequence
|
from .grad_tts_utils import GradTTSTokenizer # flake8: noqa
|
||||||
|
|
||||||
|
|
||||||
def sequence_mask(length, max_length=None):
|
def sequence_mask(length, max_length=None):
|
||||||
|
@ -385,24 +386,29 @@ class TextEncoder(ModelMixin, ConfigMixin):
|
||||||
x_dp = torch.detach(x)
|
x_dp = torch.detach(x)
|
||||||
logw = self.proj_w(x_dp, x_mask)
|
logw = self.proj_w(x_dp, x_mask)
|
||||||
|
|
||||||
return mu, logw, x_mask, spk
|
return mu, logw, x_mask
|
||||||
|
|
||||||
|
|
||||||
class GradTTS(DiffusionPipeline):
|
class GradTTS(DiffusionPipeline):
|
||||||
def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
|
def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
noise_scheduler = noise_scheduler.set_format("pt")
|
noise_scheduler = noise_scheduler.set_format("pt")
|
||||||
self.register_modules(diffwave=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer)
|
self.register_modules(unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(self, text, num_inference_steps, generator, temperature, length_scale, speaker_id=None, torch_device=None):
|
def __call__(self, text, num_inference_steps, generator, temperature, length_scale, speaker_id=None, torch_device=None):
|
||||||
if torch_device is None:
|
if torch_device is None:
|
||||||
torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
self.unet.to(torch_device)
|
||||||
|
self.text_encoder.to(torch_device)
|
||||||
|
|
||||||
x, x_lengths = self.tokenizer(text)
|
x, x_lengths = self.tokenizer(text)
|
||||||
|
x = x.to(torch_device)
|
||||||
|
x_lengths = x_lengths.to(torch_device)
|
||||||
|
|
||||||
if speaker_id is not None:
|
if speaker_id is not None:
|
||||||
speaker_id= torch.longTensor([speaker_id])
|
speaker_id= torch.LongTensor([speaker_id]).to(torch_device)
|
||||||
|
|
||||||
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
|
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
|
||||||
mu_x, logw, x_mask = self.text_encoder(x, x_lengths)
|
mu_x, logw, x_mask = self.text_encoder(x, x_lengths)
|
||||||
|
@ -426,6 +432,15 @@ class GradTTS(DiffusionPipeline):
|
||||||
# Sample latent representation from terminal distribution N(mu_y, I)
|
# Sample latent representation from terminal distribution N(mu_y, I)
|
||||||
z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
|
z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
|
||||||
|
|
||||||
|
xt = z * y_mask
|
||||||
|
h = 1.0 / num_inference_steps
|
||||||
|
for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
|
||||||
|
t = (1.0 - (t + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
|
||||||
|
time = t.unsqueeze(-1).unsqueeze(-1)
|
||||||
|
|
||||||
|
residual = self.unet(xt, y_mask, mu_y, t, speaker_id)
|
||||||
|
|
||||||
|
xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
|
||||||
|
xt = xt * y_mask
|
||||||
|
|
||||||
|
return xt[:, :, :y_max_length]
|
||||||
|
|
Loading…
Reference in New Issue