begin pipeline
parent 986cc9b2f4
commit 7b55d334d5
@@ -7,6 +7,9 @@ from torch import nn
 
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin
+from diffusers import DiffusionPipeline
+
+from .grad_tts_utils import text_to_sequence
 
 
 def sequence_mask(length, max_length=None):
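For reference, the body of sequence_mask is not shown in this hunk; in Glow-TTS/Grad-TTS style code it conventionally builds a boolean padding mask from per-item lengths. A minimal sketch, assuming that conventional implementation (not taken from this commit):

import torch

def sequence_mask(length, max_length=None):
    # Boolean mask of shape (batch, max_length): True where a position holds a
    # real token/frame, False where it is padding.
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)

# Example: lengths 3 and 5 give a (2, 5) mask with the last two entries of row 0 False.
mask = sequence_mask(torch.tensor([3, 5]))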
@@ -383,3 +386,18 @@ class TextEncoder(ModelMixin, ConfigMixin):
         logw = self.proj_w(x_dp, x_mask)
 
         return mu, logw, x_mask
+
+
+class GradTTS(DiffusionPipeline):
+    def __init__(self, unet, noise_scheduler):
+        super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
+        self.register_modules(diffwave=unet, noise_scheduler=noise_scheduler)
+
+    @torch.no_grad()
+    def __call__(self, text, speaker_id, num_inference_steps, generator, torch_device=None):
+        if torch_device is None:
+            torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        pass
+
+
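The __call__ body is still a stub (pass), so this commit only wires the UNet and noise scheduler into the pipeline. A rough, hypothetical sketch of the generation flow these pieces point toward: text_to_sequence would produce phoneme IDs, the TextEncoder's mu would act as the "average voice" target, and the registered unet and noise_scheduler would run the reverse diffusion. Every module below is a stand-in so the snippet runs; none of it is the project's actual API:

import torch

def fake_text_encoder(token_ids):
    # Stand-in for TextEncoder: returns mu of shape (batch, n_mels, n_frames).
    return torch.zeros(token_ids.shape[0], 80, 100)

def fake_denoiser(x_t, mu, t):
    # Stand-in for the registered UNet: a noise/score estimate shaped like x_t.
    return torch.zeros_like(x_t)

token_ids = torch.tensor([[12, 47, 3]])      # would come from text_to_sequence(text)
mu = fake_text_encoder(token_ids)            # encoder prediction
x_t = mu + torch.randn(mu.shape)             # start the reverse process from noise around mu
num_inference_steps = 50
for step in range(num_inference_steps):
    t = torch.tensor([1.0 - step / num_inference_steps])
    # Placeholder update; the real reverse-diffusion step would come from the
    # registered noise_scheduler and is not part of this commit.
    x_t = x_t - fake_denoiser(x_t, mu, t) / num_inference_steps
mel = x_t                                    # schematic mel-spectrogram output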