being pipeline

This commit is contained in:
patil-suraj 2022-06-16 14:08:53 +02:00
parent 986cc9b2f4
commit 7b55d334d5
1 changed files with 18 additions and 0 deletions

View File

@ -7,6 +7,9 @@ from torch import nn
from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin
from diffusers import DiffusionPipeline
from .grad_tts_utils import text_to_sequence
def sequence_mask(length, max_length=None):
@ -383,3 +386,18 @@ class TextEncoder(ModelMixin, ConfigMixin):
logw = self.proj_w(x_dp, x_mask)
return mu, logw, x_mask
class GradTTS(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
self.register_modules(diffwave=unet, noise_scheduler=noise_scheduler)
@torch.no_grad()
def __call__(self, text, speaker_id, num_inference_steps, generator, torch_device=None):
if torch_device is None:
torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pass